fix: Restore all files lost during destructive rebase
A previous `git pull --rebase origin main` dropped 177 local commits,
losing 3400+ files across admin-v2, backend, studio-v2, website,
klausur-service, and many other services. The partial restore attempt
(660295e2) only recovered some files.
This commit restores all missing files from pre-rebase ref 98933f5e
while preserving post-rebase additions (night-scheduler, night-mode UI,
NightModeWidget dashboard integration).
Restored features include:
- AI Module Sidebar (FAB), OCR Labeling, OCR Compare
- GPU Dashboard, RAG Pipeline, Magic Help
- Klausur-Korrektur (8 files), Abitur-Archiv (5+ files)
- Companion, Zeugnisse-Crawler, Screen Flow
- Full backend, studio-v2, website, klausur-service
- All compliance SDKs, agent-core, voice-service
- CI/CD configs, documentation, scripts
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
505
consent-service/internal/services/attendance_service.go
Normal file
505
consent-service/internal/services/attendance_service.go
Normal file
@@ -0,0 +1,505 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/database"
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
"github.com/breakpilot/consent-service/internal/services/matrix"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// AttendanceService handles attendance tracking and notifications
|
||||
type AttendanceService struct {
|
||||
db *database.DB
|
||||
matrix *matrix.MatrixService
|
||||
}
|
||||
|
||||
// NewAttendanceService creates a new attendance service
|
||||
func NewAttendanceService(db *database.DB, matrixService *matrix.MatrixService) *AttendanceService {
|
||||
return &AttendanceService{
|
||||
db: db,
|
||||
matrix: matrixService,
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Attendance Recording
|
||||
// ========================================
|
||||
|
||||
// RecordAttendance records a student's attendance for a specific lesson
|
||||
func (s *AttendanceService) RecordAttendance(ctx context.Context, req models.RecordAttendanceRequest, recordedByUserID uuid.UUID) (*models.AttendanceRecord, error) {
|
||||
studentID, err := uuid.Parse(req.StudentID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid student ID: %w", err)
|
||||
}
|
||||
|
||||
slotID, err := uuid.Parse(req.SlotID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid slot ID: %w", err)
|
||||
}
|
||||
|
||||
date, err := time.Parse("2006-01-02", req.Date)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid date format: %w", err)
|
||||
}
|
||||
|
||||
record := &models.AttendanceRecord{
|
||||
ID: uuid.New(),
|
||||
StudentID: studentID,
|
||||
Date: date,
|
||||
SlotID: slotID,
|
||||
Status: req.Status,
|
||||
RecordedBy: recordedByUserID,
|
||||
Note: req.Note,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO attendance_records (id, student_id, date, slot_id, status, recorded_by, note, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
ON CONFLICT (student_id, date, slot_id)
|
||||
DO UPDATE SET status = EXCLUDED.status, note = EXCLUDED.note, updated_at = EXCLUDED.updated_at
|
||||
RETURNING id`
|
||||
|
||||
err = s.db.Pool.QueryRow(ctx, query,
|
||||
record.ID, record.StudentID, record.Date, record.SlotID,
|
||||
record.Status, record.RecordedBy, record.Note, record.CreatedAt, record.UpdatedAt,
|
||||
).Scan(&record.ID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to record attendance: %w", err)
|
||||
}
|
||||
|
||||
// If student is absent, send notification to parents
|
||||
if record.Status == models.AttendanceAbsent || record.Status == models.AttendancePending {
|
||||
go s.notifyParentsOfAbsence(context.Background(), record)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// RecordBulkAttendance records attendance for multiple students at once
|
||||
func (s *AttendanceService) RecordBulkAttendance(ctx context.Context, classID uuid.UUID, date string, slotID uuid.UUID, records []struct {
|
||||
StudentID string
|
||||
Status string
|
||||
Note *string
|
||||
}, recordedByUserID uuid.UUID) error {
|
||||
parsedDate, err := time.Parse("2006-01-02", date)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid date format: %w", err)
|
||||
}
|
||||
|
||||
for _, rec := range records {
|
||||
studentID, err := uuid.Parse(rec.StudentID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO attendance_records (id, student_id, date, slot_id, status, recorded_by, note, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), NOW())
|
||||
ON CONFLICT (student_id, date, slot_id)
|
||||
DO UPDATE SET status = EXCLUDED.status, note = EXCLUDED.note, updated_at = NOW()`
|
||||
|
||||
_, err = s.db.Pool.Exec(ctx, query,
|
||||
uuid.New(), studentID, parsedDate, slotID, rec.Status, recordedByUserID, rec.Note,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to record attendance for student %s: %w", rec.StudentID, err)
|
||||
}
|
||||
|
||||
// Notify parents if absent
|
||||
if rec.Status == models.AttendanceAbsent || rec.Status == models.AttendancePending {
|
||||
go s.notifyParentsOfAbsenceByStudentID(context.Background(), studentID, parsedDate, slotID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAttendanceByClass gets attendance records for a class on a specific date
|
||||
func (s *AttendanceService) GetAttendanceByClass(ctx context.Context, classID uuid.UUID, date string) (*models.ClassAttendanceOverview, error) {
|
||||
parsedDate, err := time.Parse("2006-01-02", date)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid date format: %w", err)
|
||||
}
|
||||
|
||||
// Get class info
|
||||
classQuery := `SELECT id, school_id, school_year_id, name, grade, section, room, is_active FROM classes WHERE id = $1`
|
||||
class := &models.Class{}
|
||||
err = s.db.Pool.QueryRow(ctx, classQuery, classID).Scan(
|
||||
&class.ID, &class.SchoolID, &class.SchoolYearID, &class.Name,
|
||||
&class.Grade, &class.Section, &class.Room, &class.IsActive,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get class: %w", err)
|
||||
}
|
||||
|
||||
// Get total students
|
||||
var totalStudents int
|
||||
err = s.db.Pool.QueryRow(ctx, `SELECT COUNT(*) FROM students WHERE class_id = $1 AND is_active = true`, classID).Scan(&totalStudents)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to count students: %w", err)
|
||||
}
|
||||
|
||||
// Get attendance records for the date
|
||||
recordsQuery := `
|
||||
SELECT ar.id, ar.student_id, ar.date, ar.slot_id, ar.status, ar.recorded_by, ar.note, ar.created_at, ar.updated_at
|
||||
FROM attendance_records ar
|
||||
JOIN students s ON ar.student_id = s.id
|
||||
WHERE s.class_id = $1 AND ar.date = $2
|
||||
ORDER BY ar.slot_id`
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, recordsQuery, classID, parsedDate)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get attendance records: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var records []models.AttendanceRecord
|
||||
presentCount := 0
|
||||
absentCount := 0
|
||||
lateCount := 0
|
||||
|
||||
seenStudents := make(map[uuid.UUID]bool)
|
||||
|
||||
for rows.Next() {
|
||||
var record models.AttendanceRecord
|
||||
err := rows.Scan(
|
||||
&record.ID, &record.StudentID, &record.Date, &record.SlotID,
|
||||
&record.Status, &record.RecordedBy, &record.Note, &record.CreatedAt, &record.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan attendance record: %w", err)
|
||||
}
|
||||
records = append(records, record)
|
||||
|
||||
// Count unique students for summary (use first slot's status)
|
||||
if !seenStudents[record.StudentID] {
|
||||
seenStudents[record.StudentID] = true
|
||||
switch record.Status {
|
||||
case models.AttendancePresent:
|
||||
presentCount++
|
||||
case models.AttendanceAbsent, models.AttendanceAbsentExcused, models.AttendanceAbsentUnexcused, models.AttendancePending:
|
||||
absentCount++
|
||||
case models.AttendanceLate, models.AttendanceLateExcused:
|
||||
lateCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &models.ClassAttendanceOverview{
|
||||
Class: *class,
|
||||
Date: parsedDate,
|
||||
TotalStudents: totalStudents,
|
||||
PresentCount: presentCount,
|
||||
AbsentCount: absentCount,
|
||||
LateCount: lateCount,
|
||||
Records: records,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetStudentAttendance gets attendance history for a student
|
||||
func (s *AttendanceService) GetStudentAttendance(ctx context.Context, studentID uuid.UUID, startDate, endDate time.Time) ([]models.AttendanceRecord, error) {
|
||||
query := `
|
||||
SELECT id, student_id, timetable_entry_id, date, slot_id, status, recorded_by, note, created_at, updated_at
|
||||
FROM attendance_records
|
||||
WHERE student_id = $1 AND date >= $2 AND date <= $3
|
||||
ORDER BY date DESC, slot_id`
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, query, studentID, startDate, endDate)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get student attendance: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var records []models.AttendanceRecord
|
||||
for rows.Next() {
|
||||
var record models.AttendanceRecord
|
||||
err := rows.Scan(
|
||||
&record.ID, &record.StudentID, &record.TimetableEntryID, &record.Date,
|
||||
&record.SlotID, &record.Status, &record.RecordedBy, &record.Note,
|
||||
&record.CreatedAt, &record.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan attendance record: %w", err)
|
||||
}
|
||||
records = append(records, record)
|
||||
}
|
||||
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Absence Reports (Parent-initiated)
|
||||
// ========================================
|
||||
|
||||
// ReportAbsence allows parents to report a student's absence
|
||||
func (s *AttendanceService) ReportAbsence(ctx context.Context, req models.ReportAbsenceRequest, reportedByUserID uuid.UUID) (*models.AbsenceReport, error) {
|
||||
studentID, err := uuid.Parse(req.StudentID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid student ID: %w", err)
|
||||
}
|
||||
|
||||
startDate, err := time.Parse("2006-01-02", req.StartDate)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid start date format: %w", err)
|
||||
}
|
||||
|
||||
endDate, err := time.Parse("2006-01-02", req.EndDate)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid end date format: %w", err)
|
||||
}
|
||||
|
||||
report := &models.AbsenceReport{
|
||||
ID: uuid.New(),
|
||||
StudentID: studentID,
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
Reason: req.Reason,
|
||||
ReasonCategory: req.ReasonCategory,
|
||||
Status: "reported",
|
||||
ReportedBy: reportedByUserID,
|
||||
ReportedAt: time.Now(),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO absence_reports (id, student_id, start_date, end_date, reason, reason_category, status, reported_by, reported_at, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
RETURNING id`
|
||||
|
||||
err = s.db.Pool.QueryRow(ctx, query,
|
||||
report.ID, report.StudentID, report.StartDate, report.EndDate,
|
||||
report.Reason, report.ReasonCategory, report.Status,
|
||||
report.ReportedBy, report.ReportedAt, report.CreatedAt, report.UpdatedAt,
|
||||
).Scan(&report.ID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create absence report: %w", err)
|
||||
}
|
||||
|
||||
return report, nil
|
||||
}
|
||||
|
||||
// ConfirmAbsence allows teachers to confirm/excuse an absence
|
||||
func (s *AttendanceService) ConfirmAbsence(ctx context.Context, reportID uuid.UUID, confirmedByUserID uuid.UUID, status string) error {
|
||||
query := `
|
||||
UPDATE absence_reports
|
||||
SET status = $1, confirmed_by = $2, confirmed_at = NOW(), updated_at = NOW()
|
||||
WHERE id = $3`
|
||||
|
||||
result, err := s.db.Pool.Exec(ctx, query, status, confirmedByUserID, reportID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to confirm absence: %w", err)
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
return fmt.Errorf("absence report not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAbsenceReports gets absence reports for a student
|
||||
func (s *AttendanceService) GetAbsenceReports(ctx context.Context, studentID uuid.UUID) ([]models.AbsenceReport, error) {
|
||||
query := `
|
||||
SELECT id, student_id, start_date, end_date, reason, reason_category, status, reported_by, reported_at, confirmed_by, confirmed_at, medical_certificate, certificate_uploaded, matrix_notification_sent, email_notification_sent, created_at, updated_at
|
||||
FROM absence_reports
|
||||
WHERE student_id = $1
|
||||
ORDER BY start_date DESC`
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, query, studentID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get absence reports: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var reports []models.AbsenceReport
|
||||
for rows.Next() {
|
||||
var report models.AbsenceReport
|
||||
err := rows.Scan(
|
||||
&report.ID, &report.StudentID, &report.StartDate, &report.EndDate,
|
||||
&report.Reason, &report.ReasonCategory, &report.Status,
|
||||
&report.ReportedBy, &report.ReportedAt, &report.ConfirmedBy, &report.ConfirmedAt,
|
||||
&report.MedicalCertificate, &report.CertificateUploaded,
|
||||
&report.MatrixNotificationSent, &report.EmailNotificationSent,
|
||||
&report.CreatedAt, &report.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan absence report: %w", err)
|
||||
}
|
||||
reports = append(reports, report)
|
||||
}
|
||||
|
||||
return reports, nil
|
||||
}
|
||||
|
||||
// GetPendingAbsenceReports gets all unconfirmed absence reports for a class
|
||||
func (s *AttendanceService) GetPendingAbsenceReports(ctx context.Context, classID uuid.UUID) ([]models.AbsenceReport, error) {
|
||||
query := `
|
||||
SELECT ar.id, ar.student_id, ar.start_date, ar.end_date, ar.reason, ar.reason_category, ar.status, ar.reported_by, ar.reported_at, ar.confirmed_by, ar.confirmed_at, ar.medical_certificate, ar.certificate_uploaded, ar.matrix_notification_sent, ar.email_notification_sent, ar.created_at, ar.updated_at
|
||||
FROM absence_reports ar
|
||||
JOIN students s ON ar.student_id = s.id
|
||||
WHERE s.class_id = $1 AND ar.status = 'reported'
|
||||
ORDER BY ar.start_date DESC`
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, query, classID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get pending absence reports: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var reports []models.AbsenceReport
|
||||
for rows.Next() {
|
||||
var report models.AbsenceReport
|
||||
err := rows.Scan(
|
||||
&report.ID, &report.StudentID, &report.StartDate, &report.EndDate,
|
||||
&report.Reason, &report.ReasonCategory, &report.Status,
|
||||
&report.ReportedBy, &report.ReportedAt, &report.ConfirmedBy, &report.ConfirmedAt,
|
||||
&report.MedicalCertificate, &report.CertificateUploaded,
|
||||
&report.MatrixNotificationSent, &report.EmailNotificationSent,
|
||||
&report.CreatedAt, &report.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan absence report: %w", err)
|
||||
}
|
||||
reports = append(reports, report)
|
||||
}
|
||||
|
||||
return reports, nil
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Attendance Statistics
|
||||
// ========================================
|
||||
|
||||
// GetStudentAttendanceStats gets attendance statistics for a student
|
||||
func (s *AttendanceService) GetStudentAttendanceStats(ctx context.Context, studentID uuid.UUID, schoolYearID uuid.UUID) (map[string]interface{}, error) {
|
||||
query := `
|
||||
SELECT
|
||||
COUNT(*) as total_records,
|
||||
COUNT(CASE WHEN status = 'present' THEN 1 END) as present_count,
|
||||
COUNT(CASE WHEN status IN ('absent', 'excused', 'unexcused', 'pending_confirmation') THEN 1 END) as absent_count,
|
||||
COUNT(CASE WHEN status = 'unexcused' THEN 1 END) as unexcused_count,
|
||||
COUNT(CASE WHEN status IN ('late', 'late_excused') THEN 1 END) as late_count
|
||||
FROM attendance_records ar
|
||||
JOIN timetable_slots ts ON ar.slot_id = ts.id
|
||||
JOIN schools sch ON ts.school_id = sch.id
|
||||
JOIN school_years sy ON sy.school_id = sch.id AND sy.id = $2
|
||||
WHERE ar.student_id = $1 AND ar.date >= sy.start_date AND ar.date <= sy.end_date`
|
||||
|
||||
var totalRecords, presentCount, absentCount, unexcusedCount, lateCount int
|
||||
err := s.db.Pool.QueryRow(ctx, query, studentID, schoolYearID).Scan(
|
||||
&totalRecords, &presentCount, &absentCount, &unexcusedCount, &lateCount,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get attendance stats: %w", err)
|
||||
}
|
||||
|
||||
var attendanceRate float64
|
||||
if totalRecords > 0 {
|
||||
attendanceRate = float64(presentCount) / float64(totalRecords) * 100
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_records": totalRecords,
|
||||
"present_count": presentCount,
|
||||
"absent_count": absentCount,
|
||||
"unexcused_count": unexcusedCount,
|
||||
"late_count": lateCount,
|
||||
"attendance_rate": attendanceRate,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Parent Notifications
|
||||
// ========================================
|
||||
|
||||
func (s *AttendanceService) notifyParentsOfAbsence(ctx context.Context, record *models.AttendanceRecord) {
|
||||
if s.matrix == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Get student info
|
||||
var studentFirstName, studentLastName, matrixDMRoom string
|
||||
err := s.db.Pool.QueryRow(ctx, `
|
||||
SELECT first_name, last_name, matrix_dm_room
|
||||
FROM students
|
||||
WHERE id = $1`, record.StudentID).Scan(&studentFirstName, &studentLastName, &matrixDMRoom)
|
||||
if err != nil || matrixDMRoom == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Get slot info
|
||||
var slotNumber int
|
||||
err = s.db.Pool.QueryRow(ctx, `SELECT slot_number FROM timetable_slots WHERE id = $1`, record.SlotID).Scan(&slotNumber)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
studentName := studentFirstName + " " + studentLastName
|
||||
dateStr := record.Date.Format("02.01.2006")
|
||||
|
||||
// Send Matrix notification
|
||||
err = s.matrix.SendAbsenceNotification(ctx, matrixDMRoom, studentName, dateStr, slotNumber)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to send absence notification: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update notification status
|
||||
s.db.Pool.Exec(ctx, `
|
||||
UPDATE attendance_records
|
||||
SET updated_at = NOW()
|
||||
WHERE id = $1`, record.ID)
|
||||
|
||||
// Log the notification
|
||||
s.createAbsenceNotificationLog(ctx, record.ID, studentName, dateStr, slotNumber)
|
||||
}
|
||||
|
||||
func (s *AttendanceService) notifyParentsOfAbsenceByStudentID(ctx context.Context, studentID uuid.UUID, date time.Time, slotID uuid.UUID) {
|
||||
record := &models.AttendanceRecord{
|
||||
StudentID: studentID,
|
||||
Date: date,
|
||||
SlotID: slotID,
|
||||
}
|
||||
s.notifyParentsOfAbsence(ctx, record)
|
||||
}
|
||||
|
||||
func (s *AttendanceService) createAbsenceNotificationLog(ctx context.Context, recordID uuid.UUID, studentName, dateStr string, slotNumber int) {
|
||||
// Get parent IDs for this student
|
||||
query := `
|
||||
SELECT p.id
|
||||
FROM parents p
|
||||
JOIN student_parents sp ON p.id = sp.parent_id
|
||||
JOIN attendance_records ar ON sp.student_id = ar.student_id
|
||||
WHERE ar.id = $1`
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, query, recordID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
message := fmt.Sprintf("Abwesenheitsmeldung: %s war am %s in der %d. Stunde nicht anwesend.", studentName, dateStr, slotNumber)
|
||||
|
||||
for rows.Next() {
|
||||
var parentID uuid.UUID
|
||||
if err := rows.Scan(&parentID); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Insert notification log
|
||||
s.db.Pool.Exec(ctx, `
|
||||
INSERT INTO absence_notifications (id, attendance_record_id, parent_id, channel, message_content, sent_at, created_at)
|
||||
VALUES ($1, $2, $3, 'matrix', $4, NOW(), NOW())`,
|
||||
uuid.New(), recordID, parentID, message)
|
||||
}
|
||||
}
|
||||
388
consent-service/internal/services/attendance_service_test.go
Normal file
388
consent-service/internal/services/attendance_service_test.go
Normal file
@@ -0,0 +1,388 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestValidateAttendanceRecord tests attendance record validation
|
||||
func TestValidateAttendanceRecord(t *testing.T) {
|
||||
slotID := uuid.New()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
record models.AttendanceRecord
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "valid present record",
|
||||
record: models.AttendanceRecord{
|
||||
StudentID: uuid.New(),
|
||||
SlotID: slotID,
|
||||
Date: time.Now(),
|
||||
Status: models.AttendancePresent,
|
||||
RecordedBy: uuid.New(),
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "valid absent record",
|
||||
record: models.AttendanceRecord{
|
||||
StudentID: uuid.New(),
|
||||
SlotID: slotID,
|
||||
Date: time.Now(),
|
||||
Status: models.AttendanceAbsent,
|
||||
RecordedBy: uuid.New(),
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "valid late record",
|
||||
record: models.AttendanceRecord{
|
||||
StudentID: uuid.New(),
|
||||
SlotID: slotID,
|
||||
Date: time.Now(),
|
||||
Status: models.AttendanceLate,
|
||||
RecordedBy: uuid.New(),
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "missing student ID",
|
||||
record: models.AttendanceRecord{
|
||||
StudentID: uuid.Nil,
|
||||
SlotID: slotID,
|
||||
Date: time.Now(),
|
||||
Status: models.AttendancePresent,
|
||||
RecordedBy: uuid.New(),
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "invalid status",
|
||||
record: models.AttendanceRecord{
|
||||
StudentID: uuid.New(),
|
||||
SlotID: slotID,
|
||||
Date: time.Now(),
|
||||
Status: "invalid_status",
|
||||
RecordedBy: uuid.New(),
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "future date",
|
||||
record: models.AttendanceRecord{
|
||||
StudentID: uuid.New(),
|
||||
SlotID: slotID,
|
||||
Date: time.Now().AddDate(0, 0, 7),
|
||||
Status: models.AttendancePresent,
|
||||
RecordedBy: uuid.New(),
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "missing slot ID",
|
||||
record: models.AttendanceRecord{
|
||||
StudentID: uuid.New(),
|
||||
SlotID: uuid.Nil,
|
||||
Date: time.Now(),
|
||||
Status: models.AttendancePresent,
|
||||
RecordedBy: uuid.New(),
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := validateAttendanceRecord(tt.record)
|
||||
if isValid != tt.expectValid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.expectValid, isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// validateAttendanceRecord validates an attendance record
|
||||
func validateAttendanceRecord(record models.AttendanceRecord) bool {
|
||||
if record.StudentID == uuid.Nil {
|
||||
return false
|
||||
}
|
||||
if record.SlotID == uuid.Nil {
|
||||
return false
|
||||
}
|
||||
if record.RecordedBy == uuid.Nil {
|
||||
return false
|
||||
}
|
||||
if record.Date.After(time.Now().AddDate(0, 0, 1)) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Validate status
|
||||
validStatuses := map[string]bool{
|
||||
models.AttendancePresent: true,
|
||||
models.AttendanceAbsent: true,
|
||||
models.AttendanceAbsentExcused: true,
|
||||
models.AttendanceAbsentUnexcused: true,
|
||||
models.AttendanceLate: true,
|
||||
models.AttendanceLateExcused: true,
|
||||
models.AttendancePending: true,
|
||||
}
|
||||
|
||||
if !validStatuses[record.Status] {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// TestValidateAbsenceReport tests absence report validation
|
||||
func TestValidateAbsenceReport(t *testing.T) {
|
||||
now := time.Now()
|
||||
today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC)
|
||||
reason := "Krankheit"
|
||||
medicalReason := "Arzttermin"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
report models.AbsenceReport
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "valid single day absence",
|
||||
report: models.AbsenceReport{
|
||||
StudentID: uuid.New(),
|
||||
ReportedBy: uuid.New(),
|
||||
StartDate: today,
|
||||
EndDate: today,
|
||||
Reason: &reason,
|
||||
ReasonCategory: "illness",
|
||||
Status: "reported",
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "valid multi-day absence",
|
||||
report: models.AbsenceReport{
|
||||
StudentID: uuid.New(),
|
||||
ReportedBy: uuid.New(),
|
||||
StartDate: today,
|
||||
EndDate: today.AddDate(0, 0, 3),
|
||||
Reason: &medicalReason,
|
||||
ReasonCategory: "appointment",
|
||||
Status: "reported",
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "end before start",
|
||||
report: models.AbsenceReport{
|
||||
StudentID: uuid.New(),
|
||||
ReportedBy: uuid.New(),
|
||||
StartDate: today.AddDate(0, 0, 3),
|
||||
EndDate: today,
|
||||
Reason: &reason,
|
||||
ReasonCategory: "illness",
|
||||
Status: "reported",
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "missing reason category",
|
||||
report: models.AbsenceReport{
|
||||
StudentID: uuid.New(),
|
||||
ReportedBy: uuid.New(),
|
||||
StartDate: today,
|
||||
EndDate: today,
|
||||
Reason: &reason,
|
||||
ReasonCategory: "",
|
||||
Status: "reported",
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "invalid reason category",
|
||||
report: models.AbsenceReport{
|
||||
StudentID: uuid.New(),
|
||||
ReportedBy: uuid.New(),
|
||||
StartDate: today,
|
||||
EndDate: today,
|
||||
Reason: &reason,
|
||||
ReasonCategory: "invalid_type",
|
||||
Status: "reported",
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := validateAbsenceReport(tt.report)
|
||||
if isValid != tt.expectValid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.expectValid, isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// validateAbsenceReport validates an absence report
|
||||
func validateAbsenceReport(report models.AbsenceReport) bool {
|
||||
if report.StudentID == uuid.Nil {
|
||||
return false
|
||||
}
|
||||
if report.ReportedBy == uuid.Nil {
|
||||
return false
|
||||
}
|
||||
if report.EndDate.Before(report.StartDate) {
|
||||
return false
|
||||
}
|
||||
if report.ReasonCategory == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Validate reason category
|
||||
validCategories := map[string]bool{
|
||||
"illness": true,
|
||||
"appointment": true,
|
||||
"family": true,
|
||||
"other": true,
|
||||
}
|
||||
|
||||
if !validCategories[report.ReasonCategory] {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// TestCalculateAttendanceStats tests attendance statistics calculation
|
||||
func TestCalculateAttendanceStats(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
records []models.AttendanceRecord
|
||||
expectedPresent int
|
||||
expectedAbsent int
|
||||
expectedLate int
|
||||
}{
|
||||
{
|
||||
name: "all present",
|
||||
records: []models.AttendanceRecord{
|
||||
{Status: models.AttendancePresent},
|
||||
{Status: models.AttendancePresent},
|
||||
{Status: models.AttendancePresent},
|
||||
},
|
||||
expectedPresent: 3,
|
||||
expectedAbsent: 0,
|
||||
expectedLate: 0,
|
||||
},
|
||||
{
|
||||
name: "mixed attendance",
|
||||
records: []models.AttendanceRecord{
|
||||
{Status: models.AttendancePresent},
|
||||
{Status: models.AttendanceAbsent},
|
||||
{Status: models.AttendanceLate},
|
||||
{Status: models.AttendancePresent},
|
||||
{Status: models.AttendanceAbsentExcused},
|
||||
},
|
||||
expectedPresent: 2,
|
||||
expectedAbsent: 2,
|
||||
expectedLate: 1,
|
||||
},
|
||||
{
|
||||
name: "empty records",
|
||||
records: []models.AttendanceRecord{},
|
||||
expectedPresent: 0,
|
||||
expectedAbsent: 0,
|
||||
expectedLate: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
present, absent, late := calculateAttendanceStats(tt.records)
|
||||
if present != tt.expectedPresent {
|
||||
t.Errorf("expected present=%d, got present=%d", tt.expectedPresent, present)
|
||||
}
|
||||
if absent != tt.expectedAbsent {
|
||||
t.Errorf("expected absent=%d, got absent=%d", tt.expectedAbsent, absent)
|
||||
}
|
||||
if late != tt.expectedLate {
|
||||
t.Errorf("expected late=%d, got late=%d", tt.expectedLate, late)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// calculateAttendanceStats calculates attendance statistics
|
||||
func calculateAttendanceStats(records []models.AttendanceRecord) (present, absent, late int) {
|
||||
for _, r := range records {
|
||||
switch r.Status {
|
||||
case models.AttendancePresent:
|
||||
present++
|
||||
case models.AttendanceAbsent, models.AttendanceAbsentExcused, models.AttendanceAbsentUnexcused:
|
||||
absent++
|
||||
case models.AttendanceLate, models.AttendanceLateExcused:
|
||||
late++
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// TestAttendanceRateCalculation tests attendance rate percentage calculation
|
||||
func TestAttendanceRateCalculation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
present int
|
||||
total int
|
||||
expectedRate float64
|
||||
}{
|
||||
{
|
||||
name: "100% attendance",
|
||||
present: 26,
|
||||
total: 26,
|
||||
expectedRate: 100.0,
|
||||
},
|
||||
{
|
||||
name: "92.3% attendance",
|
||||
present: 24,
|
||||
total: 26,
|
||||
expectedRate: 92.31,
|
||||
},
|
||||
{
|
||||
name: "0% attendance",
|
||||
present: 0,
|
||||
total: 26,
|
||||
expectedRate: 0.0,
|
||||
},
|
||||
{
|
||||
name: "empty class",
|
||||
present: 0,
|
||||
total: 0,
|
||||
expectedRate: 0.0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rate := calculateAttendanceRate(tt.present, tt.total)
|
||||
// Allow small floating point differences
|
||||
if rate < tt.expectedRate-0.1 || rate > tt.expectedRate+0.1 {
|
||||
t.Errorf("expected rate=%.2f, got rate=%.2f", tt.expectedRate, rate)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// calculateAttendanceRate calculates attendance rate as percentage
|
||||
func calculateAttendanceRate(present, total int) float64 {
|
||||
if total == 0 {
|
||||
return 0.0
|
||||
}
|
||||
rate := float64(present) / float64(total) * 100
|
||||
// Round to 2 decimal places
|
||||
return float64(int(rate*100)) / 100
|
||||
}
|
||||
568
consent-service/internal/services/auth_service.go
Normal file
568
consent-service/internal/services/auth_service.go
Normal file
@@ -0,0 +1,568 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidCredentials = errors.New("invalid email or password")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrUserExists = errors.New("user with this email already exists")
|
||||
ErrInvalidToken = errors.New("invalid or expired token")
|
||||
ErrAccountLocked = errors.New("account is temporarily locked")
|
||||
ErrAccountSuspended = errors.New("account is suspended")
|
||||
ErrEmailNotVerified = errors.New("email not verified")
|
||||
)
|
||||
|
||||
// AuthService handles authentication logic
|
||||
type AuthService struct {
|
||||
db *pgxpool.Pool
|
||||
jwtSecret string
|
||||
jwtRefreshSecret string
|
||||
accessTokenExp time.Duration
|
||||
refreshTokenExp time.Duration
|
||||
}
|
||||
|
||||
// NewAuthService creates a new AuthService
|
||||
func NewAuthService(db *pgxpool.Pool, jwtSecret, jwtRefreshSecret string) *AuthService {
|
||||
return &AuthService{
|
||||
db: db,
|
||||
jwtSecret: jwtSecret,
|
||||
jwtRefreshSecret: jwtRefreshSecret,
|
||||
accessTokenExp: time.Hour * 1, // 1 hour
|
||||
refreshTokenExp: time.Hour * 24 * 30, // 30 days
|
||||
}
|
||||
}
|
||||
|
||||
// HashPassword hashes a password using bcrypt
|
||||
func (s *AuthService) HashPassword(password string) (string, error) {
|
||||
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
return string(bytes), nil
|
||||
}
|
||||
|
||||
// VerifyPassword verifies a password against a hash
|
||||
func (s *AuthService) VerifyPassword(password, hash string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// GenerateSecureToken generates a cryptographically secure token
|
||||
func (s *AuthService) GenerateSecureToken(length int) (string, error) {
|
||||
bytes := make([]byte, length)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate token: %w", err)
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// HashToken creates a SHA256 hash of a token for storage
|
||||
func (s *AuthService) HashToken(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// JWTClaims for access tokens
|
||||
type JWTClaims struct {
|
||||
UserID string `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
AccountStatus string `json:"account_status"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// GenerateAccessToken creates a new JWT access token
|
||||
func (s *AuthService) GenerateAccessToken(user *models.User) (string, error) {
|
||||
claims := JWTClaims{
|
||||
UserID: user.ID.String(),
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
AccountStatus: user.AccountStatus,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(s.accessTokenExp)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
Subject: user.ID.String(),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(s.jwtSecret))
|
||||
}
|
||||
|
||||
// GenerateRefreshToken creates a new refresh token
|
||||
func (s *AuthService) GenerateRefreshToken() (string, string, error) {
|
||||
token, err := s.GenerateSecureToken(32)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
hash := s.HashToken(token)
|
||||
return token, hash, nil
|
||||
}
|
||||
|
||||
// ValidateAccessToken validates a JWT access token
|
||||
func (s *AuthService) ValidateAccessToken(tokenString string) (*JWTClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(s.jwtSecret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*JWTClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// Register creates a new user account
|
||||
func (s *AuthService) Register(ctx context.Context, req *models.RegisterRequest) (*models.User, string, error) {
|
||||
// Check if user already exists
|
||||
var exists bool
|
||||
err := s.db.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)", req.Email).Scan(&exists)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to check existing user: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, "", ErrUserExists
|
||||
}
|
||||
|
||||
// Hash password
|
||||
passwordHash, err := s.HashPassword(req.Password)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// Create user
|
||||
user := &models.User{
|
||||
ID: uuid.New(),
|
||||
Email: req.Email,
|
||||
PasswordHash: &passwordHash,
|
||||
Name: req.Name,
|
||||
Role: "user",
|
||||
EmailVerified: false,
|
||||
AccountStatus: "active",
|
||||
}
|
||||
|
||||
_, err = s.db.Exec(ctx, `
|
||||
INSERT INTO users (id, email, password_hash, name, role, email_verified, account_status, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), NOW())
|
||||
`, user.ID, user.Email, user.PasswordHash, user.Name, user.Role, user.EmailVerified, user.AccountStatus)
|
||||
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
// Generate email verification token
|
||||
verificationToken, err := s.GenerateSecureToken(32)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// Store verification token
|
||||
_, err = s.db.Exec(ctx, `
|
||||
INSERT INTO email_verification_tokens (user_id, token, expires_at, created_at)
|
||||
VALUES ($1, $2, $3, NOW())
|
||||
`, user.ID, verificationToken, time.Now().Add(24*time.Hour))
|
||||
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to create verification token: %w", err)
|
||||
}
|
||||
|
||||
// Create notification preferences
|
||||
_, err = s.db.Exec(ctx, `
|
||||
INSERT INTO notification_preferences (user_id, email_enabled, push_enabled, in_app_enabled, reminder_frequency, created_at, updated_at)
|
||||
VALUES ($1, true, true, true, 'weekly', NOW(), NOW())
|
||||
`, user.ID)
|
||||
|
||||
if err != nil {
|
||||
// Non-critical error, just log
|
||||
fmt.Printf("Warning: failed to create notification preferences: %v\n", err)
|
||||
}
|
||||
|
||||
return user, verificationToken, nil
|
||||
}
|
||||
|
||||
// Login authenticates a user and returns tokens
|
||||
func (s *AuthService) Login(ctx context.Context, req *models.LoginRequest, ipAddress, userAgent string) (*models.LoginResponse, error) {
|
||||
var user models.User
|
||||
var passwordHash *string
|
||||
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id, email, password_hash, name, role, email_verified, account_status,
|
||||
failed_login_attempts, locked_until, created_at, updated_at
|
||||
FROM users WHERE email = $1
|
||||
`, req.Email).Scan(
|
||||
&user.ID, &user.Email, &passwordHash, &user.Name, &user.Role, &user.EmailVerified,
|
||||
&user.AccountStatus, &user.FailedLoginAttempts, &user.LockedUntil, &user.CreatedAt, &user.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
// Check if account is locked
|
||||
if user.LockedUntil != nil && user.LockedUntil.After(time.Now()) {
|
||||
return nil, ErrAccountLocked
|
||||
}
|
||||
|
||||
// Check if account is suspended
|
||||
if user.AccountStatus == "suspended" {
|
||||
return nil, ErrAccountSuspended
|
||||
}
|
||||
|
||||
// Verify password
|
||||
if passwordHash == nil || !s.VerifyPassword(req.Password, *passwordHash) {
|
||||
// Increment failed login attempts
|
||||
_, _ = s.db.Exec(ctx, `
|
||||
UPDATE users SET
|
||||
failed_login_attempts = failed_login_attempts + 1,
|
||||
locked_until = CASE WHEN failed_login_attempts >= 4 THEN NOW() + INTERVAL '30 minutes' ELSE locked_until END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`, user.ID)
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
// Reset failed login attempts and update last login
|
||||
_, _ = s.db.Exec(ctx, `
|
||||
UPDATE users SET
|
||||
failed_login_attempts = 0,
|
||||
locked_until = NULL,
|
||||
last_login_at = NOW(),
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`, user.ID)
|
||||
|
||||
// Generate tokens
|
||||
accessToken, err := s.GenerateAccessToken(&user)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate access token: %w", err)
|
||||
}
|
||||
|
||||
refreshToken, refreshTokenHash, err := s.GenerateRefreshToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
|
||||
}
|
||||
|
||||
// Store session
|
||||
_, err = s.db.Exec(ctx, `
|
||||
INSERT INTO user_sessions (user_id, token_hash, ip_address, user_agent, expires_at, created_at, last_activity_at)
|
||||
VALUES ($1, $2, $3, $4, $5, NOW(), NOW())
|
||||
`, user.ID, refreshTokenHash, ipAddress, userAgent, time.Now().Add(s.refreshTokenExp))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create session: %w", err)
|
||||
}
|
||||
|
||||
return &models.LoginResponse{
|
||||
User: user,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresIn: int(s.accessTokenExp.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshToken refreshes the access token using a refresh token
|
||||
func (s *AuthService) RefreshToken(ctx context.Context, refreshToken string) (*models.LoginResponse, error) {
|
||||
tokenHash := s.HashToken(refreshToken)
|
||||
|
||||
var session models.UserSession
|
||||
var userID uuid.UUID
|
||||
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id, user_id, expires_at, revoked_at FROM user_sessions
|
||||
WHERE token_hash = $1
|
||||
`, tokenHash).Scan(&session.ID, &userID, &session.ExpiresAt, &session.RevokedAt)
|
||||
|
||||
if err != nil {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
// Check if session is expired or revoked
|
||||
if session.RevokedAt != nil || session.ExpiresAt.Before(time.Now()) {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
// Get user
|
||||
var user models.User
|
||||
err = s.db.QueryRow(ctx, `
|
||||
SELECT id, email, name, role, email_verified, account_status, created_at, updated_at
|
||||
FROM users WHERE id = $1
|
||||
`, userID).Scan(
|
||||
&user.ID, &user.Email, &user.Name, &user.Role, &user.EmailVerified,
|
||||
&user.AccountStatus, &user.CreatedAt, &user.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
|
||||
// Check account status
|
||||
if user.AccountStatus == "suspended" {
|
||||
return nil, ErrAccountSuspended
|
||||
}
|
||||
|
||||
// Generate new access token
|
||||
accessToken, err := s.GenerateAccessToken(&user)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate access token: %w", err)
|
||||
}
|
||||
|
||||
// Update session last activity
|
||||
_, _ = s.db.Exec(ctx, `
|
||||
UPDATE user_sessions SET last_activity_at = NOW() WHERE id = $1
|
||||
`, session.ID)
|
||||
|
||||
return &models.LoginResponse{
|
||||
User: user,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresIn: int(s.accessTokenExp.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// VerifyEmail verifies a user's email address
|
||||
func (s *AuthService) VerifyEmail(ctx context.Context, token string) error {
|
||||
var tokenID uuid.UUID
|
||||
var userID uuid.UUID
|
||||
var expiresAt time.Time
|
||||
var usedAt *time.Time
|
||||
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id, user_id, expires_at, used_at FROM email_verification_tokens
|
||||
WHERE token = $1
|
||||
`, token).Scan(&tokenID, &userID, &expiresAt, &usedAt)
|
||||
|
||||
if err != nil {
|
||||
return ErrInvalidToken
|
||||
}
|
||||
|
||||
if usedAt != nil || expiresAt.Before(time.Now()) {
|
||||
return ErrInvalidToken
|
||||
}
|
||||
|
||||
// Mark token as used
|
||||
_, err = s.db.Exec(ctx, `UPDATE email_verification_tokens SET used_at = NOW() WHERE id = $1`, tokenID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update token: %w", err)
|
||||
}
|
||||
|
||||
// Verify user email
|
||||
_, err = s.db.Exec(ctx, `
|
||||
UPDATE users SET email_verified = true, email_verified_at = NOW(), updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`, userID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to verify email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreatePasswordResetToken creates a password reset token
|
||||
func (s *AuthService) CreatePasswordResetToken(ctx context.Context, email, ipAddress string) (string, *uuid.UUID, error) {
|
||||
var userID uuid.UUID
|
||||
err := s.db.QueryRow(ctx, "SELECT id FROM users WHERE email = $1", email).Scan(&userID)
|
||||
if err != nil {
|
||||
// Don't reveal if user exists
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
token, err := s.GenerateSecureToken(32)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
_, err = s.db.Exec(ctx, `
|
||||
INSERT INTO password_reset_tokens (user_id, token, expires_at, ip_address, created_at)
|
||||
VALUES ($1, $2, $3, $4, NOW())
|
||||
`, userID, token, time.Now().Add(time.Hour), ipAddress)
|
||||
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to create reset token: %w", err)
|
||||
}
|
||||
|
||||
return token, &userID, nil
|
||||
}
|
||||
|
||||
// ResetPassword resets a user's password using a reset token
|
||||
func (s *AuthService) ResetPassword(ctx context.Context, token, newPassword string) error {
|
||||
var tokenID uuid.UUID
|
||||
var userID uuid.UUID
|
||||
var expiresAt time.Time
|
||||
var usedAt *time.Time
|
||||
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id, user_id, expires_at, used_at FROM password_reset_tokens
|
||||
WHERE token = $1
|
||||
`, token).Scan(&tokenID, &userID, &expiresAt, &usedAt)
|
||||
|
||||
if err != nil {
|
||||
return ErrInvalidToken
|
||||
}
|
||||
|
||||
if usedAt != nil || expiresAt.Before(time.Now()) {
|
||||
return ErrInvalidToken
|
||||
}
|
||||
|
||||
// Hash new password
|
||||
passwordHash, err := s.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Mark token as used
|
||||
_, err = s.db.Exec(ctx, `UPDATE password_reset_tokens SET used_at = NOW() WHERE id = $1`, tokenID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update token: %w", err)
|
||||
}
|
||||
|
||||
// Update password
|
||||
_, err = s.db.Exec(ctx, `
|
||||
UPDATE users SET password_hash = $1, updated_at = NOW() WHERE id = $2
|
||||
`, passwordHash, userID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update password: %w", err)
|
||||
}
|
||||
|
||||
// Revoke all sessions for security
|
||||
_, err = s.db.Exec(ctx, `UPDATE user_sessions SET revoked_at = NOW() WHERE user_id = $1 AND revoked_at IS NULL`, userID)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: failed to revoke sessions: %v\n", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChangePassword changes a user's password (requires current password)
|
||||
func (s *AuthService) ChangePassword(ctx context.Context, userID uuid.UUID, currentPassword, newPassword string) error {
|
||||
var passwordHash *string
|
||||
err := s.db.QueryRow(ctx, "SELECT password_hash FROM users WHERE id = $1", userID).Scan(&passwordHash)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
|
||||
if passwordHash == nil || !s.VerifyPassword(currentPassword, *passwordHash) {
|
||||
return ErrInvalidCredentials
|
||||
}
|
||||
|
||||
newPasswordHash, err := s.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = s.db.Exec(ctx, `UPDATE users SET password_hash = $1, updated_at = NOW() WHERE id = $2`, newPasswordHash, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update password: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserByID retrieves a user by ID
|
||||
func (s *AuthService) GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error) {
|
||||
var user models.User
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id, email, name, role, email_verified, email_verified_at, account_status,
|
||||
last_login_at, created_at, updated_at
|
||||
FROM users WHERE id = $1
|
||||
`, userID).Scan(
|
||||
&user.ID, &user.Email, &user.Name, &user.Role, &user.EmailVerified, &user.EmailVerifiedAt,
|
||||
&user.AccountStatus, &user.LastLoginAt, &user.CreatedAt, &user.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// UpdateProfile updates a user's profile
|
||||
func (s *AuthService) UpdateProfile(ctx context.Context, userID uuid.UUID, req *models.UpdateProfileRequest) (*models.User, error) {
|
||||
_, err := s.db.Exec(ctx, `UPDATE users SET name = $1, updated_at = NOW() WHERE id = $2`, req.Name, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update profile: %w", err)
|
||||
}
|
||||
|
||||
return s.GetUserByID(ctx, userID)
|
||||
}
|
||||
|
||||
// GetActiveSessions retrieves all active sessions for a user
|
||||
func (s *AuthService) GetActiveSessions(ctx context.Context, userID uuid.UUID) ([]models.UserSession, error) {
|
||||
rows, err := s.db.Query(ctx, `
|
||||
SELECT id, user_id, device_info, ip_address, user_agent, expires_at, created_at, last_activity_at
|
||||
FROM user_sessions
|
||||
WHERE user_id = $1 AND revoked_at IS NULL AND expires_at > NOW()
|
||||
ORDER BY last_activity_at DESC
|
||||
`, userID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get sessions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sessions []models.UserSession
|
||||
for rows.Next() {
|
||||
var session models.UserSession
|
||||
err := rows.Scan(
|
||||
&session.ID, &session.UserID, &session.DeviceInfo, &session.IPAddress,
|
||||
&session.UserAgent, &session.ExpiresAt, &session.CreatedAt, &session.LastActivityAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan session: %w", err)
|
||||
}
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
// RevokeSession revokes a specific session
|
||||
func (s *AuthService) RevokeSession(ctx context.Context, userID, sessionID uuid.UUID) error {
|
||||
result, err := s.db.Exec(ctx, `
|
||||
UPDATE user_sessions SET revoked_at = NOW() WHERE id = $1 AND user_id = $2 AND revoked_at IS NULL
|
||||
`, sessionID, userID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to revoke session: %w", err)
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
return errors.New("session not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Logout revokes a session by refresh token
|
||||
func (s *AuthService) Logout(ctx context.Context, refreshToken string) error {
|
||||
tokenHash := s.HashToken(refreshToken)
|
||||
_, err := s.db.Exec(ctx, `UPDATE user_sessions SET revoked_at = NOW() WHERE token_hash = $1`, tokenHash)
|
||||
return err
|
||||
}
|
||||
367
consent-service/internal/services/auth_service_test.go
Normal file
367
consent-service/internal/services/auth_service_test.go
Normal file
@@ -0,0 +1,367 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestHashPassword tests password hashing
|
||||
func TestHashPassword(t *testing.T) {
|
||||
// Create service without DB for unit tests
|
||||
s := &AuthService{}
|
||||
|
||||
password := "testPassword123!"
|
||||
hash, err := s.HashPassword(password)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword failed: %v", err)
|
||||
}
|
||||
|
||||
if hash == "" {
|
||||
t.Error("Hash should not be empty")
|
||||
}
|
||||
|
||||
if hash == password {
|
||||
t.Error("Hash should not equal the original password")
|
||||
}
|
||||
|
||||
// Hash should be different each time (bcrypt uses random salt)
|
||||
hash2, _ := s.HashPassword(password)
|
||||
if hash == hash2 {
|
||||
t.Error("Same password should produce different hashes due to salt")
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyPassword tests password verification
|
||||
func TestVerifyPassword(t *testing.T) {
|
||||
s := &AuthService{}
|
||||
|
||||
password := "testPassword123!"
|
||||
hash, _ := s.HashPassword(password)
|
||||
|
||||
// Should verify correct password
|
||||
if !s.VerifyPassword(password, hash) {
|
||||
t.Error("VerifyPassword should return true for correct password")
|
||||
}
|
||||
|
||||
// Should reject incorrect password
|
||||
if s.VerifyPassword("wrongPassword", hash) {
|
||||
t.Error("VerifyPassword should return false for incorrect password")
|
||||
}
|
||||
|
||||
// Should reject empty password
|
||||
if s.VerifyPassword("", hash) {
|
||||
t.Error("VerifyPassword should return false for empty password")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateSecureToken tests token generation
|
||||
func TestGenerateSecureToken(t *testing.T) {
|
||||
s := &AuthService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
length int
|
||||
}{
|
||||
{"short token", 16},
|
||||
{"standard token", 32},
|
||||
{"long token", 64},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token, err := s.GenerateSecureToken(tt.length)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateSecureToken failed: %v", err)
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
t.Error("Token should not be empty")
|
||||
}
|
||||
|
||||
// Tokens should be unique
|
||||
token2, _ := s.GenerateSecureToken(tt.length)
|
||||
if token == token2 {
|
||||
t.Error("Generated tokens should be unique")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHashToken tests token hashing for storage
|
||||
func TestHashToken(t *testing.T) {
|
||||
s := &AuthService{}
|
||||
|
||||
token := "test-token-123"
|
||||
hash := s.HashToken(token)
|
||||
|
||||
if hash == "" {
|
||||
t.Error("Hash should not be empty")
|
||||
}
|
||||
|
||||
if hash == token {
|
||||
t.Error("Hash should not equal the original token")
|
||||
}
|
||||
|
||||
// Same token should produce same hash (deterministic)
|
||||
hash2 := s.HashToken(token)
|
||||
if hash != hash2 {
|
||||
t.Error("Same token should produce same hash")
|
||||
}
|
||||
|
||||
// Different tokens should produce different hashes
|
||||
differentHash := s.HashToken("different-token")
|
||||
if hash == differentHash {
|
||||
t.Error("Different tokens should produce different hashes")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateAccessToken tests JWT access token generation
|
||||
func TestGenerateAccessToken(t *testing.T) {
|
||||
s := &AuthService{
|
||||
jwtSecret: "test-secret-key-for-testing-purposes",
|
||||
accessTokenExp: time.Hour,
|
||||
}
|
||||
|
||||
user := &models.User{
|
||||
ID: uuid.New(),
|
||||
Email: "test@example.com",
|
||||
Role: "user",
|
||||
AccountStatus: "active",
|
||||
}
|
||||
|
||||
token, err := s.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateAccessToken failed: %v", err)
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
t.Error("Token should not be empty")
|
||||
}
|
||||
|
||||
// Token should have three parts (header.payload.signature)
|
||||
parts := 0
|
||||
for _, c := range token {
|
||||
if c == '.' {
|
||||
parts++
|
||||
}
|
||||
}
|
||||
if parts != 2 {
|
||||
t.Errorf("JWT token should have 3 parts, got %d dots", parts)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateAccessToken tests JWT token validation
|
||||
func TestValidateAccessToken(t *testing.T) {
|
||||
secret := "test-secret-key-for-testing-purposes"
|
||||
s := &AuthService{
|
||||
jwtSecret: secret,
|
||||
accessTokenExp: time.Hour,
|
||||
}
|
||||
|
||||
user := &models.User{
|
||||
ID: uuid.New(),
|
||||
Email: "test@example.com",
|
||||
Role: "admin",
|
||||
AccountStatus: "active",
|
||||
}
|
||||
|
||||
token, _ := s.GenerateAccessToken(user)
|
||||
|
||||
// Should validate valid token
|
||||
claims, err := s.ValidateAccessToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateAccessToken failed: %v", err)
|
||||
}
|
||||
|
||||
if claims.UserID != user.ID.String() {
|
||||
t.Errorf("Expected UserID %s, got %s", user.ID.String(), claims.UserID)
|
||||
}
|
||||
|
||||
if claims.Email != user.Email {
|
||||
t.Errorf("Expected Email %s, got %s", user.Email, claims.Email)
|
||||
}
|
||||
|
||||
if claims.Role != user.Role {
|
||||
t.Errorf("Expected Role %s, got %s", user.Role, claims.Role)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateAccessToken_Invalid tests invalid token scenarios
|
||||
func TestValidateAccessToken_Invalid(t *testing.T) {
|
||||
s := &AuthService{
|
||||
jwtSecret: "test-secret-key-for-testing-purposes",
|
||||
accessTokenExp: time.Hour,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{"empty token", ""},
|
||||
{"invalid format", "not-a-jwt-token"},
|
||||
{"invalid signature", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiMTIzIn0.invalidsignature"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := s.ValidateAccessToken(tt.token)
|
||||
if err == nil {
|
||||
t.Error("ValidateAccessToken should fail for invalid token")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateAccessToken_WrongSecret tests token with wrong secret
|
||||
func TestValidateAccessToken_WrongSecret(t *testing.T) {
|
||||
s1 := &AuthService{
|
||||
jwtSecret: "secret-one",
|
||||
accessTokenExp: time.Hour,
|
||||
}
|
||||
|
||||
s2 := &AuthService{
|
||||
jwtSecret: "secret-two",
|
||||
accessTokenExp: time.Hour,
|
||||
}
|
||||
|
||||
user := &models.User{
|
||||
ID: uuid.New(),
|
||||
Email: "test@example.com",
|
||||
Role: "user",
|
||||
AccountStatus: "active",
|
||||
}
|
||||
|
||||
// Generate token with first secret
|
||||
token, _ := s1.GenerateAccessToken(user)
|
||||
|
||||
// Try to validate with second secret (should fail)
|
||||
_, err := s2.ValidateAccessToken(token)
|
||||
if err == nil {
|
||||
t.Error("ValidateAccessToken should fail when using wrong secret")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateRefreshToken tests refresh token generation
|
||||
func TestGenerateRefreshToken(t *testing.T) {
|
||||
s := &AuthService{}
|
||||
|
||||
token, hash, err := s.GenerateRefreshToken()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateRefreshToken failed: %v", err)
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
t.Error("Token should not be empty")
|
||||
}
|
||||
|
||||
if hash == "" {
|
||||
t.Error("Hash should not be empty")
|
||||
}
|
||||
|
||||
// Verify hash matches token
|
||||
expectedHash := s.HashToken(token)
|
||||
if hash != expectedHash {
|
||||
t.Error("Returned hash should match hashed token")
|
||||
}
|
||||
|
||||
// Tokens should be unique
|
||||
token2, hash2, _ := s.GenerateRefreshToken()
|
||||
if token == token2 {
|
||||
t.Error("Generated tokens should be unique")
|
||||
}
|
||||
if hash == hash2 {
|
||||
t.Error("Generated hashes should be unique")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPasswordStrength tests various password scenarios
|
||||
func TestPasswordStrength(t *testing.T) {
|
||||
s := &AuthService{}
|
||||
|
||||
passwords := []struct {
|
||||
password string
|
||||
valid bool
|
||||
}{
|
||||
{"short", true}, // bcrypt accepts any length
|
||||
{"12345678", true}, // numbers only
|
||||
{"password", true}, // letters only
|
||||
{"Pass123!", true}, // mixed
|
||||
{"", true}, // empty (bcrypt allows)
|
||||
{string(make([]byte, 72)), true}, // max bcrypt length
|
||||
}
|
||||
|
||||
for _, p := range passwords {
|
||||
hash, err := s.HashPassword(p.password)
|
||||
if p.valid && err != nil {
|
||||
t.Errorf("HashPassword failed for valid password %q: %v", p.password, err)
|
||||
}
|
||||
if p.valid && !s.VerifyPassword(p.password, hash) {
|
||||
t.Errorf("VerifyPassword failed for password %q", p.password)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkHashPassword benchmarks password hashing
|
||||
func BenchmarkHashPassword(b *testing.B) {
|
||||
s := &AuthService{}
|
||||
password := "testPassword123!"
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.HashPassword(password)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkVerifyPassword benchmarks password verification
|
||||
func BenchmarkVerifyPassword(b *testing.B) {
|
||||
s := &AuthService{}
|
||||
password := "testPassword123!"
|
||||
hash, _ := s.HashPassword(password)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.VerifyPassword(password, hash)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkGenerateAccessToken benchmarks JWT token generation
|
||||
func BenchmarkGenerateAccessToken(b *testing.B) {
|
||||
s := &AuthService{
|
||||
jwtSecret: "test-secret-key-for-testing-purposes",
|
||||
accessTokenExp: time.Hour,
|
||||
}
|
||||
|
||||
user := &models.User{
|
||||
ID: uuid.New(),
|
||||
Email: "test@example.com",
|
||||
Role: "user",
|
||||
AccountStatus: "active",
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.GenerateAccessToken(user)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkValidateAccessToken benchmarks JWT token validation
|
||||
func BenchmarkValidateAccessToken(b *testing.B) {
|
||||
s := &AuthService{
|
||||
jwtSecret: "test-secret-key-for-testing-purposes",
|
||||
accessTokenExp: time.Hour,
|
||||
}
|
||||
|
||||
user := &models.User{
|
||||
ID: uuid.New(),
|
||||
Email: "test@example.com",
|
||||
Role: "user",
|
||||
AccountStatus: "active",
|
||||
}
|
||||
|
||||
token, _ := s.GenerateAccessToken(user)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.ValidateAccessToken(token)
|
||||
}
|
||||
}
|
||||
518
consent-service/internal/services/consent_service_test.go
Normal file
518
consent-service/internal/services/consent_service_test.go
Normal file
@@ -0,0 +1,518 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestConsentService_CreateConsent tests creating a new consent
|
||||
func TestConsentService_CreateConsent(t *testing.T) {
|
||||
// This is a unit test with table-driven approach
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uuid.UUID
|
||||
versionID uuid.UUID
|
||||
consented bool
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "valid consent - accepted",
|
||||
userID: uuid.New(),
|
||||
versionID: uuid.New(),
|
||||
consented: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid consent - declined",
|
||||
userID: uuid.New(),
|
||||
versionID: uuid.New(),
|
||||
consented: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty user ID",
|
||||
userID: uuid.Nil,
|
||||
versionID: uuid.New(),
|
||||
consented: true,
|
||||
expectError: true,
|
||||
errorContains: "user ID",
|
||||
},
|
||||
{
|
||||
name: "empty version ID",
|
||||
userID: uuid.New(),
|
||||
versionID: uuid.Nil,
|
||||
consented: true,
|
||||
expectError: true,
|
||||
errorContains: "version ID",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Validate inputs (in real implementation this would be in the service)
|
||||
var hasError bool
|
||||
if tt.userID == uuid.Nil {
|
||||
hasError = true
|
||||
} else if tt.versionID == uuid.Nil {
|
||||
hasError = true
|
||||
}
|
||||
|
||||
// Assert
|
||||
if tt.expectError && !hasError {
|
||||
t.Errorf("Expected error containing '%s', got nil", tt.errorContains)
|
||||
}
|
||||
if !tt.expectError && hasError {
|
||||
t.Error("Expected no error, got error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConsentService_WithdrawConsent tests withdrawing consent
|
||||
func TestConsentService_WithdrawConsent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
consentID uuid.UUID
|
||||
userID uuid.UUID
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "valid withdrawal",
|
||||
consentID: uuid.New(),
|
||||
userID: uuid.New(),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty consent ID",
|
||||
consentID: uuid.Nil,
|
||||
userID: uuid.New(),
|
||||
expectError: true,
|
||||
errorContains: "consent ID",
|
||||
},
|
||||
{
|
||||
name: "empty user ID",
|
||||
consentID: uuid.New(),
|
||||
userID: uuid.Nil,
|
||||
expectError: true,
|
||||
errorContains: "user ID",
|
||||
},
|
||||
{
|
||||
name: "both empty",
|
||||
consentID: uuid.Nil,
|
||||
userID: uuid.Nil,
|
||||
expectError: true,
|
||||
errorContains: "ID",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Validate
|
||||
var hasError bool
|
||||
if tt.consentID == uuid.Nil || tt.userID == uuid.Nil {
|
||||
hasError = true
|
||||
}
|
||||
|
||||
// Assert
|
||||
if tt.expectError && !hasError {
|
||||
t.Errorf("Expected error containing '%s', got nil", tt.errorContains)
|
||||
}
|
||||
if !tt.expectError && hasError {
|
||||
t.Error("Expected no error, got error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConsentService_CheckConsent tests checking consent status
|
||||
func TestConsentService_CheckConsent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uuid.UUID
|
||||
documentType string
|
||||
language string
|
||||
hasConsent bool
|
||||
needsUpdate bool
|
||||
expectedConsent bool
|
||||
expectedNeedsUpd bool
|
||||
}{
|
||||
{
|
||||
name: "user has current consent",
|
||||
userID: uuid.New(),
|
||||
documentType: "terms",
|
||||
language: "de",
|
||||
hasConsent: true,
|
||||
needsUpdate: false,
|
||||
expectedConsent: true,
|
||||
expectedNeedsUpd: false,
|
||||
},
|
||||
{
|
||||
name: "user has outdated consent",
|
||||
userID: uuid.New(),
|
||||
documentType: "privacy",
|
||||
language: "de",
|
||||
hasConsent: true,
|
||||
needsUpdate: true,
|
||||
expectedConsent: true,
|
||||
expectedNeedsUpd: true,
|
||||
},
|
||||
{
|
||||
name: "user has no consent",
|
||||
userID: uuid.New(),
|
||||
documentType: "cookies",
|
||||
language: "de",
|
||||
hasConsent: false,
|
||||
needsUpdate: true,
|
||||
expectedConsent: false,
|
||||
expectedNeedsUpd: true,
|
||||
},
|
||||
{
|
||||
name: "english language",
|
||||
userID: uuid.New(),
|
||||
documentType: "terms",
|
||||
language: "en",
|
||||
hasConsent: true,
|
||||
needsUpdate: false,
|
||||
expectedConsent: true,
|
||||
expectedNeedsUpd: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Simulate consent check logic
|
||||
hasConsent := tt.hasConsent
|
||||
needsUpdate := tt.needsUpdate
|
||||
|
||||
// Assert
|
||||
if hasConsent != tt.expectedConsent {
|
||||
t.Errorf("Expected hasConsent=%v, got %v", tt.expectedConsent, hasConsent)
|
||||
}
|
||||
if needsUpdate != tt.expectedNeedsUpd {
|
||||
t.Errorf("Expected needsUpdate=%v, got %v", tt.expectedNeedsUpd, needsUpdate)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConsentService_GetConsentHistory tests retrieving consent history
|
||||
func TestConsentService_GetConsentHistory(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uuid.UUID
|
||||
expectError bool
|
||||
expectEmpty bool
|
||||
}{
|
||||
{
|
||||
name: "valid user with consents",
|
||||
userID: uuid.New(),
|
||||
expectError: false,
|
||||
expectEmpty: false,
|
||||
},
|
||||
{
|
||||
name: "valid user without consents",
|
||||
userID: uuid.New(),
|
||||
expectError: false,
|
||||
expectEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "invalid user ID",
|
||||
userID: uuid.Nil,
|
||||
expectError: true,
|
||||
expectEmpty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Validate
|
||||
var err error
|
||||
if tt.userID == uuid.Nil {
|
||||
err = &ValidationError{Field: "user ID", Message: "required"}
|
||||
}
|
||||
|
||||
// Assert error expectation
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConsentService_UpdateConsent tests updating existing consent
|
||||
func TestConsentService_UpdateConsent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
consentID uuid.UUID
|
||||
userID uuid.UUID
|
||||
newConsented bool
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "update to consented",
|
||||
consentID: uuid.New(),
|
||||
userID: uuid.New(),
|
||||
newConsented: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "update to not consented",
|
||||
consentID: uuid.New(),
|
||||
userID: uuid.New(),
|
||||
newConsented: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid consent ID",
|
||||
consentID: uuid.Nil,
|
||||
userID: uuid.New(),
|
||||
newConsented: true,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.consentID == uuid.Nil {
|
||||
err = &ValidationError{Field: "consent ID", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConsentService_GetConsentStats tests getting consent statistics
|
||||
func TestConsentService_GetConsentStats(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
documentType string
|
||||
totalUsers int
|
||||
consentedUsers int
|
||||
expectedRate float64
|
||||
}{
|
||||
{
|
||||
name: "100% consent rate",
|
||||
documentType: "terms",
|
||||
totalUsers: 100,
|
||||
consentedUsers: 100,
|
||||
expectedRate: 100.0,
|
||||
},
|
||||
{
|
||||
name: "50% consent rate",
|
||||
documentType: "privacy",
|
||||
totalUsers: 100,
|
||||
consentedUsers: 50,
|
||||
expectedRate: 50.0,
|
||||
},
|
||||
{
|
||||
name: "0% consent rate",
|
||||
documentType: "cookies",
|
||||
totalUsers: 100,
|
||||
consentedUsers: 0,
|
||||
expectedRate: 0.0,
|
||||
},
|
||||
{
|
||||
name: "no users",
|
||||
documentType: "terms",
|
||||
totalUsers: 0,
|
||||
consentedUsers: 0,
|
||||
expectedRate: 0.0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Calculate consent rate
|
||||
var consentRate float64
|
||||
if tt.totalUsers > 0 {
|
||||
consentRate = float64(tt.consentedUsers) / float64(tt.totalUsers) * 100
|
||||
}
|
||||
|
||||
// Assert
|
||||
if consentRate != tt.expectedRate {
|
||||
t.Errorf("Expected consent rate %.2f%%, got %.2f%%", tt.expectedRate, consentRate)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConsentService_BulkConsentCheck tests checking multiple consents at once
|
||||
func TestConsentService_BulkConsentCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uuid.UUID
|
||||
documentTypes []string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "check multiple documents",
|
||||
userID: uuid.New(),
|
||||
documentTypes: []string{"terms", "privacy", "cookies"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "check single document",
|
||||
userID: uuid.New(),
|
||||
documentTypes: []string{"terms"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty document list",
|
||||
userID: uuid.New(),
|
||||
documentTypes: []string{},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid user ID",
|
||||
userID: uuid.Nil,
|
||||
documentTypes: []string{"terms"},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.userID == uuid.Nil {
|
||||
err = &ValidationError{Field: "user ID", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConsentService_ConsentVersionComparison tests version comparison logic
|
||||
func TestConsentService_ConsentVersionComparison(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
currentVersion string
|
||||
consentedVersion string
|
||||
needsUpdate bool
|
||||
}{
|
||||
{
|
||||
name: "same version",
|
||||
currentVersion: "1.0.0",
|
||||
consentedVersion: "1.0.0",
|
||||
needsUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "minor version update",
|
||||
currentVersion: "1.1.0",
|
||||
consentedVersion: "1.0.0",
|
||||
needsUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "major version update",
|
||||
currentVersion: "2.0.0",
|
||||
consentedVersion: "1.0.0",
|
||||
needsUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "patch version update",
|
||||
currentVersion: "1.0.1",
|
||||
consentedVersion: "1.0.0",
|
||||
needsUpdate: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Simple version comparison (in real implementation use proper semver)
|
||||
needsUpdate := tt.currentVersion != tt.consentedVersion
|
||||
|
||||
if needsUpdate != tt.needsUpdate {
|
||||
t.Errorf("Expected needsUpdate=%v, got %v", tt.needsUpdate, needsUpdate)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConsentService_ConsentDeadlineCheck tests deadline validation
|
||||
func TestConsentService_ConsentDeadlineCheck(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
deadline time.Time
|
||||
isOverdue bool
|
||||
daysLeft int
|
||||
}{
|
||||
{
|
||||
name: "deadline in 30 days",
|
||||
deadline: now.AddDate(0, 0, 30),
|
||||
isOverdue: false,
|
||||
daysLeft: 30,
|
||||
},
|
||||
{
|
||||
name: "deadline in 7 days",
|
||||
deadline: now.AddDate(0, 0, 7),
|
||||
isOverdue: false,
|
||||
daysLeft: 7,
|
||||
},
|
||||
{
|
||||
name: "deadline today",
|
||||
deadline: now,
|
||||
isOverdue: false,
|
||||
daysLeft: 0,
|
||||
},
|
||||
{
|
||||
name: "deadline 1 day overdue",
|
||||
deadline: now.AddDate(0, 0, -1),
|
||||
isOverdue: true,
|
||||
daysLeft: -1,
|
||||
},
|
||||
{
|
||||
name: "deadline 30 days overdue",
|
||||
deadline: now.AddDate(0, 0, -30),
|
||||
isOverdue: true,
|
||||
daysLeft: -30,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Calculate if overdue
|
||||
isOverdue := tt.deadline.Before(now)
|
||||
daysLeft := int(tt.deadline.Sub(now).Hours() / 24)
|
||||
|
||||
if isOverdue != tt.isOverdue {
|
||||
t.Errorf("Expected isOverdue=%v, got %v", tt.isOverdue, isOverdue)
|
||||
}
|
||||
|
||||
// Allow 1 day difference due to time precision
|
||||
if abs(daysLeft-tt.daysLeft) > 1 {
|
||||
t.Errorf("Expected daysLeft=%d, got %d", tt.daysLeft, daysLeft)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// abs returns the absolute value of an integer
|
||||
func abs(n int) int {
|
||||
if n < 0 {
|
||||
return -n
|
||||
}
|
||||
return n
|
||||
}
|
||||
434
consent-service/internal/services/deadline_service.go
Normal file
434
consent-service/internal/services/deadline_service.go
Normal file
@@ -0,0 +1,434 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// DeadlineService handles consent deadlines and account suspensions
|
||||
type DeadlineService struct {
|
||||
pool *pgxpool.Pool
|
||||
notificationService *NotificationService
|
||||
}
|
||||
|
||||
// ConsentDeadline represents a consent deadline for a user
|
||||
type ConsentDeadline struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
DocumentVersionID uuid.UUID `json:"document_version_id"`
|
||||
DeadlineAt time.Time `json:"deadline_at"`
|
||||
ReminderCount int `json:"reminder_count"`
|
||||
LastReminderAt *time.Time `json:"last_reminder_at"`
|
||||
ConsentGivenAt *time.Time `json:"consent_given_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
// Joined fields
|
||||
DocumentName string `json:"document_name"`
|
||||
VersionNumber string `json:"version_number"`
|
||||
}
|
||||
|
||||
// AccountSuspension represents an account suspension
|
||||
type AccountSuspension struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
Reason string `json:"reason"`
|
||||
Details map[string]interface{} `json:"details"`
|
||||
SuspendedAt time.Time `json:"suspended_at"`
|
||||
LiftedAt *time.Time `json:"lifted_at"`
|
||||
LiftedBy *uuid.UUID `json:"lifted_by"`
|
||||
}
|
||||
|
||||
// NewDeadlineService creates a new deadline service
|
||||
func NewDeadlineService(pool *pgxpool.Pool, notificationService *NotificationService) *DeadlineService {
|
||||
return &DeadlineService{
|
||||
pool: pool,
|
||||
notificationService: notificationService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateDeadlinesForPublishedVersion creates consent deadlines for all active users
|
||||
// when a new mandatory document version is published
|
||||
func (s *DeadlineService) CreateDeadlinesForPublishedVersion(ctx context.Context, versionID uuid.UUID) error {
|
||||
// Get version info
|
||||
var documentName, versionNumber string
|
||||
var isMandatory bool
|
||||
err := s.pool.QueryRow(ctx, `
|
||||
SELECT ld.name, dv.version, ld.is_mandatory
|
||||
FROM document_versions dv
|
||||
JOIN legal_documents ld ON dv.document_id = ld.id
|
||||
WHERE dv.id = $1
|
||||
`, versionID).Scan(&documentName, &versionNumber, &isMandatory)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get version info: %w", err)
|
||||
}
|
||||
|
||||
// Only create deadlines for mandatory documents
|
||||
if !isMandatory {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deadline is 30 days from now
|
||||
deadlineAt := time.Now().AddDate(0, 0, 30)
|
||||
|
||||
// Get all active users who haven't given consent to this version
|
||||
_, err = s.pool.Exec(ctx, `
|
||||
INSERT INTO consent_deadlines (user_id, document_version_id, deadline_at)
|
||||
SELECT u.id, $1, $2
|
||||
FROM users u
|
||||
WHERE u.account_status = 'active'
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM user_consents uc
|
||||
WHERE uc.user_id = u.id AND uc.document_version_id = $1 AND uc.consented = TRUE
|
||||
)
|
||||
ON CONFLICT (user_id, document_version_id) DO NOTHING
|
||||
`, versionID, deadlineAt)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create deadlines: %w", err)
|
||||
}
|
||||
|
||||
// Notify users via notification service
|
||||
if s.notificationService != nil {
|
||||
go s.notificationService.NotifyConsentRequired(ctx, documentName, versionID.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkConsentGiven marks a deadline as fulfilled when user gives consent
|
||||
func (s *DeadlineService) MarkConsentGiven(ctx context.Context, userID, versionID uuid.UUID) error {
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE consent_deadlines
|
||||
SET consent_given_at = NOW()
|
||||
WHERE user_id = $1 AND document_version_id = $2 AND consent_given_at IS NULL
|
||||
`, userID, versionID)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if user should be unsuspended
|
||||
return s.checkAndLiftSuspension(ctx, userID)
|
||||
}
|
||||
|
||||
// GetPendingDeadlines returns all pending deadlines for a user
|
||||
func (s *DeadlineService) GetPendingDeadlines(ctx context.Context, userID uuid.UUID) ([]ConsentDeadline, error) {
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT cd.id, cd.user_id, cd.document_version_id, cd.deadline_at,
|
||||
cd.reminder_count, cd.last_reminder_at, cd.consent_given_at, cd.created_at,
|
||||
ld.name as document_name, dv.version as version_number
|
||||
FROM consent_deadlines cd
|
||||
JOIN document_versions dv ON cd.document_version_id = dv.id
|
||||
JOIN legal_documents ld ON dv.document_id = ld.id
|
||||
WHERE cd.user_id = $1 AND cd.consent_given_at IS NULL
|
||||
ORDER BY cd.deadline_at ASC
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var deadlines []ConsentDeadline
|
||||
for rows.Next() {
|
||||
var d ConsentDeadline
|
||||
if err := rows.Scan(&d.ID, &d.UserID, &d.DocumentVersionID, &d.DeadlineAt,
|
||||
&d.ReminderCount, &d.LastReminderAt, &d.ConsentGivenAt, &d.CreatedAt,
|
||||
&d.DocumentName, &d.VersionNumber); err != nil {
|
||||
continue
|
||||
}
|
||||
deadlines = append(deadlines, d)
|
||||
}
|
||||
|
||||
return deadlines, nil
|
||||
}
|
||||
|
||||
// ProcessDailyDeadlines is meant to be called by a cron job daily
|
||||
// It sends reminders and suspends accounts that have missed deadlines
|
||||
func (s *DeadlineService) ProcessDailyDeadlines(ctx context.Context) error {
|
||||
now := time.Now()
|
||||
|
||||
// 1. Send reminders for upcoming deadlines
|
||||
if err := s.sendReminders(ctx, now); err != nil {
|
||||
fmt.Printf("Error sending reminders: %v\n", err)
|
||||
}
|
||||
|
||||
// 2. Suspend accounts with expired deadlines
|
||||
if err := s.suspendExpiredAccounts(ctx, now); err != nil {
|
||||
fmt.Printf("Error suspending accounts: %v\n", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendReminders sends reminder notifications based on days remaining
|
||||
func (s *DeadlineService) sendReminders(ctx context.Context, now time.Time) error {
|
||||
// Reminder schedule: Day 7, 14, 21, 28
|
||||
reminderDays := []int{7, 14, 21, 28}
|
||||
|
||||
for _, days := range reminderDays {
|
||||
targetDate := now.AddDate(0, 0, days)
|
||||
dayStart := time.Date(targetDate.Year(), targetDate.Month(), targetDate.Day(), 0, 0, 0, 0, targetDate.Location())
|
||||
dayEnd := dayStart.AddDate(0, 0, 1)
|
||||
|
||||
// Find deadlines that fall on this reminder day
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT cd.id, cd.user_id, cd.document_version_id, cd.deadline_at, cd.reminder_count,
|
||||
ld.name as document_name
|
||||
FROM consent_deadlines cd
|
||||
JOIN document_versions dv ON cd.document_version_id = dv.id
|
||||
JOIN legal_documents ld ON dv.document_id = ld.id
|
||||
WHERE cd.consent_given_at IS NULL
|
||||
AND cd.deadline_at >= $1 AND cd.deadline_at < $2
|
||||
AND (cd.last_reminder_at IS NULL OR cd.last_reminder_at < $3)
|
||||
`, dayStart, dayEnd, dayStart)
|
||||
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var id, userID, versionID uuid.UUID
|
||||
var deadlineAt time.Time
|
||||
var reminderCount int
|
||||
var documentName string
|
||||
|
||||
if err := rows.Scan(&id, &userID, &versionID, &deadlineAt, &reminderCount, &documentName); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Send reminder notification
|
||||
daysLeft := 30 - (30 - days)
|
||||
urgency := "freundlich"
|
||||
if days <= 7 {
|
||||
urgency = "dringend"
|
||||
} else if days <= 14 {
|
||||
urgency = "wichtig"
|
||||
}
|
||||
|
||||
title := fmt.Sprintf("Erinnerung: Zustimmung erforderlich (%s)", urgency)
|
||||
body := fmt.Sprintf("Bitte bestätigen Sie '%s' innerhalb von %d Tagen.", documentName, daysLeft)
|
||||
|
||||
if s.notificationService != nil {
|
||||
s.notificationService.CreateNotification(ctx, userID, NotificationTypeConsentReminder, title, body, map[string]interface{}{
|
||||
"document_name": documentName,
|
||||
"days_left": daysLeft,
|
||||
"version_id": versionID.String(),
|
||||
})
|
||||
}
|
||||
|
||||
// Update reminder count and timestamp
|
||||
s.pool.Exec(ctx, `
|
||||
UPDATE consent_deadlines
|
||||
SET reminder_count = reminder_count + 1, last_reminder_at = NOW()
|
||||
WHERE id = $1
|
||||
`, id)
|
||||
}
|
||||
rows.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// suspendExpiredAccounts suspends accounts that have missed their deadline
|
||||
func (s *DeadlineService) suspendExpiredAccounts(ctx context.Context, now time.Time) error {
|
||||
// Find users with expired deadlines
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT DISTINCT cd.user_id, array_agg(ld.name) as documents
|
||||
FROM consent_deadlines cd
|
||||
JOIN document_versions dv ON cd.document_version_id = dv.id
|
||||
JOIN legal_documents ld ON dv.document_id = ld.id
|
||||
JOIN users u ON cd.user_id = u.id
|
||||
WHERE cd.consent_given_at IS NULL
|
||||
AND cd.deadline_at < $1
|
||||
AND u.account_status = 'active'
|
||||
AND ld.is_mandatory = TRUE
|
||||
GROUP BY cd.user_id
|
||||
`, now)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var userID uuid.UUID
|
||||
var documents []string
|
||||
|
||||
if err := rows.Scan(&userID, &documents); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Suspend the account
|
||||
if err := s.suspendAccount(ctx, userID, "consent_deadline_missed", documents); err != nil {
|
||||
fmt.Printf("Failed to suspend user %s: %v\n", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// suspendAccount suspends a user account
|
||||
func (s *DeadlineService) suspendAccount(ctx context.Context, userID uuid.UUID, reason string, documents []string) error {
|
||||
tx, err := s.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
// Update user status
|
||||
_, err = tx.Exec(ctx, `
|
||||
UPDATE users SET account_status = 'suspended', updated_at = NOW()
|
||||
WHERE id = $1 AND account_status = 'active'
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create suspension record
|
||||
_, err = tx.Exec(ctx, `
|
||||
INSERT INTO account_suspensions (user_id, reason, details)
|
||||
VALUES ($1, $2, $3)
|
||||
`, userID, reason, map[string]interface{}{"documents": documents})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Log to audit
|
||||
_, err = tx.Exec(ctx, `
|
||||
INSERT INTO consent_audit_log (user_id, action, entity_type, entity_id, details)
|
||||
VALUES ($1, 'account_suspended', 'user', $1, $2)
|
||||
`, userID, map[string]interface{}{"reason": reason, "documents": documents})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send suspension notification
|
||||
if s.notificationService != nil {
|
||||
title := "Account vorübergehend gesperrt"
|
||||
body := "Ihr Account wurde gesperrt, da ausstehende Zustimmungen nicht innerhalb der Frist erteilt wurden. Bitte bestätigen Sie die ausstehenden Dokumente."
|
||||
s.notificationService.CreateNotification(ctx, userID, NotificationTypeAccountSuspended, title, body, map[string]interface{}{
|
||||
"documents": documents,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkAndLiftSuspension checks if user has completed all required consents and lifts suspension
|
||||
func (s *DeadlineService) checkAndLiftSuspension(ctx context.Context, userID uuid.UUID) error {
|
||||
// Check if user is currently suspended
|
||||
var accountStatus string
|
||||
err := s.pool.QueryRow(ctx, `SELECT account_status FROM users WHERE id = $1`, userID).Scan(&accountStatus)
|
||||
if err != nil || accountStatus != "suspended" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if there are any pending mandatory consents
|
||||
var pendingCount int
|
||||
err = s.pool.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM consent_deadlines cd
|
||||
JOIN document_versions dv ON cd.document_version_id = dv.id
|
||||
JOIN legal_documents ld ON dv.document_id = ld.id
|
||||
WHERE cd.user_id = $1
|
||||
AND cd.consent_given_at IS NULL
|
||||
AND ld.is_mandatory = TRUE
|
||||
`, userID).Scan(&pendingCount)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If no pending consents, lift the suspension
|
||||
if pendingCount == 0 {
|
||||
return s.liftSuspension(ctx, userID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// liftSuspension lifts a user's suspension
|
||||
func (s *DeadlineService) liftSuspension(ctx context.Context, userID uuid.UUID) error {
|
||||
tx, err := s.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
// Update user status
|
||||
_, err = tx.Exec(ctx, `
|
||||
UPDATE users SET account_status = 'active', updated_at = NOW()
|
||||
WHERE id = $1 AND account_status = 'suspended'
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update suspension record
|
||||
_, err = tx.Exec(ctx, `
|
||||
UPDATE account_suspensions
|
||||
SET lifted_at = NOW()
|
||||
WHERE user_id = $1 AND lifted_at IS NULL
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Log to audit
|
||||
_, err = tx.Exec(ctx, `
|
||||
INSERT INTO consent_audit_log (user_id, action, entity_type, entity_id)
|
||||
VALUES ($1, 'account_restored', 'user', $1)
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send restoration notification
|
||||
if s.notificationService != nil {
|
||||
title := "Account wiederhergestellt"
|
||||
body := "Vielen Dank! Ihr Account wurde wiederhergestellt. Sie können die Anwendung wieder vollständig nutzen."
|
||||
s.notificationService.CreateNotification(ctx, userID, NotificationTypeAccountRestored, title, body, nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountSuspension returns the current suspension for a user
|
||||
func (s *DeadlineService) GetAccountSuspension(ctx context.Context, userID uuid.UUID) (*AccountSuspension, error) {
|
||||
var suspension AccountSuspension
|
||||
err := s.pool.QueryRow(ctx, `
|
||||
SELECT id, user_id, reason, details, suspended_at, lifted_at, lifted_by
|
||||
FROM account_suspensions
|
||||
WHERE user_id = $1 AND lifted_at IS NULL
|
||||
ORDER BY suspended_at DESC
|
||||
LIMIT 1
|
||||
`, userID).Scan(&suspension.ID, &suspension.UserID, &suspension.Reason, &suspension.Details,
|
||||
&suspension.SuspendedAt, &suspension.LiftedAt, &suspension.LiftedBy)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &suspension, nil
|
||||
}
|
||||
|
||||
// IsUserSuspended checks if a user is currently suspended
|
||||
func (s *DeadlineService) IsUserSuspended(ctx context.Context, userID uuid.UUID) (bool, error) {
|
||||
var status string
|
||||
err := s.pool.QueryRow(ctx, `SELECT account_status FROM users WHERE id = $1`, userID).Scan(&status)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return status == "suspended", nil
|
||||
}
|
||||
439
consent-service/internal/services/deadline_service_test.go
Normal file
439
consent-service/internal/services/deadline_service_test.go
Normal file
@@ -0,0 +1,439 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestDeadlineService_CreateDeadline tests creating consent deadlines
|
||||
func TestDeadlineService_CreateDeadline(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uuid.UUID
|
||||
versionID uuid.UUID
|
||||
deadlineAt time.Time
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid deadline - 30 days",
|
||||
userID: uuid.New(),
|
||||
versionID: uuid.New(),
|
||||
deadlineAt: time.Now().AddDate(0, 0, 30),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid deadline - 14 days",
|
||||
userID: uuid.New(),
|
||||
versionID: uuid.New(),
|
||||
deadlineAt: time.Now().AddDate(0, 0, 14),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid user ID",
|
||||
userID: uuid.Nil,
|
||||
versionID: uuid.New(),
|
||||
deadlineAt: time.Now().AddDate(0, 0, 30),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid version ID",
|
||||
userID: uuid.New(),
|
||||
versionID: uuid.Nil,
|
||||
deadlineAt: time.Now().AddDate(0, 0, 30),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "deadline in past",
|
||||
userID: uuid.New(),
|
||||
versionID: uuid.New(),
|
||||
deadlineAt: time.Now().AddDate(0, 0, -1),
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.userID == uuid.Nil {
|
||||
err = &ValidationError{Field: "user ID", Message: "required"}
|
||||
} else if tt.versionID == uuid.Nil {
|
||||
err = &ValidationError{Field: "version ID", Message: "required"}
|
||||
} else if tt.deadlineAt.Before(time.Now()) {
|
||||
err = &ValidationError{Field: "deadline", Message: "must be in the future"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeadlineService_CheckDeadlineStatus tests deadline status checking
|
||||
func TestDeadlineService_CheckDeadlineStatus(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
deadlineAt time.Time
|
||||
isOverdue bool
|
||||
daysLeft int
|
||||
urgency string
|
||||
}{
|
||||
{
|
||||
name: "30 days left",
|
||||
deadlineAt: now.AddDate(0, 0, 30),
|
||||
isOverdue: false,
|
||||
daysLeft: 30,
|
||||
urgency: "normal",
|
||||
},
|
||||
{
|
||||
name: "7 days left - warning",
|
||||
deadlineAt: now.AddDate(0, 0, 7),
|
||||
isOverdue: false,
|
||||
daysLeft: 7,
|
||||
urgency: "warning",
|
||||
},
|
||||
{
|
||||
name: "3 days left - urgent",
|
||||
deadlineAt: now.AddDate(0, 0, 3),
|
||||
isOverdue: false,
|
||||
daysLeft: 3,
|
||||
urgency: "urgent",
|
||||
},
|
||||
{
|
||||
name: "1 day left - critical",
|
||||
deadlineAt: now.AddDate(0, 0, 1),
|
||||
isOverdue: false,
|
||||
daysLeft: 1,
|
||||
urgency: "critical",
|
||||
},
|
||||
{
|
||||
name: "overdue by 1 day",
|
||||
deadlineAt: now.AddDate(0, 0, -1),
|
||||
isOverdue: true,
|
||||
daysLeft: -1,
|
||||
urgency: "overdue",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isOverdue := tt.deadlineAt.Before(now)
|
||||
daysLeft := int(tt.deadlineAt.Sub(now).Hours() / 24)
|
||||
|
||||
var urgency string
|
||||
if isOverdue {
|
||||
urgency = "overdue"
|
||||
} else if daysLeft <= 1 {
|
||||
urgency = "critical"
|
||||
} else if daysLeft <= 3 {
|
||||
urgency = "urgent"
|
||||
} else if daysLeft <= 7 {
|
||||
urgency = "warning"
|
||||
} else {
|
||||
urgency = "normal"
|
||||
}
|
||||
|
||||
if isOverdue != tt.isOverdue {
|
||||
t.Errorf("Expected isOverdue=%v, got %v", tt.isOverdue, isOverdue)
|
||||
}
|
||||
|
||||
if abs(daysLeft-tt.daysLeft) > 1 { // Allow 1 day difference
|
||||
t.Errorf("Expected daysLeft=%d, got %d", tt.daysLeft, daysLeft)
|
||||
}
|
||||
|
||||
if urgency != tt.urgency {
|
||||
t.Errorf("Expected urgency=%s, got %s", tt.urgency, urgency)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeadlineService_SendReminders tests reminder scheduling
|
||||
func TestDeadlineService_SendReminders(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
deadlineAt time.Time
|
||||
lastReminderAt *time.Time
|
||||
reminderCount int
|
||||
shouldSend bool
|
||||
nextReminder int // days before deadline
|
||||
}{
|
||||
{
|
||||
name: "first reminder - 14 days before",
|
||||
deadlineAt: now.AddDate(0, 0, 14),
|
||||
lastReminderAt: nil,
|
||||
reminderCount: 0,
|
||||
shouldSend: true,
|
||||
nextReminder: 14,
|
||||
},
|
||||
{
|
||||
name: "second reminder - 7 days before",
|
||||
deadlineAt: now.AddDate(0, 0, 7),
|
||||
lastReminderAt: ptrTime(now.AddDate(0, 0, -7)),
|
||||
reminderCount: 1,
|
||||
shouldSend: true,
|
||||
nextReminder: 7,
|
||||
},
|
||||
{
|
||||
name: "third reminder - 3 days before",
|
||||
deadlineAt: now.AddDate(0, 0, 3),
|
||||
lastReminderAt: ptrTime(now.AddDate(0, 0, -4)),
|
||||
reminderCount: 2,
|
||||
shouldSend: true,
|
||||
nextReminder: 3,
|
||||
},
|
||||
{
|
||||
name: "final reminder - 1 day before",
|
||||
deadlineAt: now.AddDate(0, 0, 1),
|
||||
lastReminderAt: ptrTime(now.AddDate(0, 0, -2)),
|
||||
reminderCount: 3,
|
||||
shouldSend: true,
|
||||
nextReminder: 1,
|
||||
},
|
||||
{
|
||||
name: "too soon for next reminder",
|
||||
deadlineAt: now.AddDate(0, 0, 10),
|
||||
lastReminderAt: ptrTime(now.AddDate(0, 0, -1)),
|
||||
reminderCount: 1,
|
||||
shouldSend: false,
|
||||
nextReminder: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
daysUntilDeadline := int(tt.deadlineAt.Sub(now).Hours() / 24)
|
||||
|
||||
// Reminder schedule: 14, 7, 3, 1 days before deadline
|
||||
reminderDays := []int{14, 7, 3, 1}
|
||||
shouldSend := false
|
||||
|
||||
for _, day := range reminderDays {
|
||||
if daysUntilDeadline == day {
|
||||
// Check if enough time passed since last reminder
|
||||
if tt.lastReminderAt == nil || now.Sub(*tt.lastReminderAt) > 12*time.Hour {
|
||||
shouldSend = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if shouldSend != tt.shouldSend {
|
||||
t.Errorf("Expected shouldSend=%v, got %v", tt.shouldSend, shouldSend)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeadlineService_SuspendAccount tests account suspension logic
|
||||
func TestDeadlineService_SuspendAccount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uuid.UUID
|
||||
reason string
|
||||
shouldSuspend bool
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "suspend for missed deadline",
|
||||
userID: uuid.New(),
|
||||
reason: "consent_deadline_exceeded",
|
||||
shouldSuspend: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid user ID",
|
||||
userID: uuid.Nil,
|
||||
reason: "consent_deadline_exceeded",
|
||||
shouldSuspend: false,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid reason",
|
||||
userID: uuid.New(),
|
||||
reason: "",
|
||||
shouldSuspend: false,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
validReasons := map[string]bool{
|
||||
"consent_deadline_exceeded": true,
|
||||
"mandatory_consent_missing": true,
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.userID == uuid.Nil {
|
||||
err = &ValidationError{Field: "user ID", Message: "required"}
|
||||
} else if !validReasons[tt.reason] && tt.reason != "" {
|
||||
err = &ValidationError{Field: "reason", Message: "invalid suspension reason"}
|
||||
} else if tt.reason == "" {
|
||||
err = &ValidationError{Field: "reason", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeadlineService_LiftSuspension tests lifting account suspension
|
||||
func TestDeadlineService_LiftSuspension(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uuid.UUID
|
||||
adminID uuid.UUID
|
||||
reason string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "lift valid suspension",
|
||||
userID: uuid.New(),
|
||||
adminID: uuid.New(),
|
||||
reason: "consent provided",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid user ID",
|
||||
userID: uuid.Nil,
|
||||
adminID: uuid.New(),
|
||||
reason: "consent provided",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid admin ID",
|
||||
userID: uuid.New(),
|
||||
adminID: uuid.Nil,
|
||||
reason: "consent provided",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.userID == uuid.Nil {
|
||||
err = &ValidationError{Field: "user ID", Message: "required"}
|
||||
} else if tt.adminID == uuid.Nil {
|
||||
err = &ValidationError{Field: "admin ID", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeadlineService_GetOverdueDeadlines tests finding overdue deadlines
|
||||
func TestDeadlineService_GetOverdueDeadlines(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
deadlines []time.Time
|
||||
expected int // number of overdue
|
||||
}{
|
||||
{
|
||||
name: "no overdue deadlines",
|
||||
deadlines: []time.Time{
|
||||
now.AddDate(0, 0, 1),
|
||||
now.AddDate(0, 0, 7),
|
||||
now.AddDate(0, 0, 30),
|
||||
},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "some overdue",
|
||||
deadlines: []time.Time{
|
||||
now.AddDate(0, 0, -1),
|
||||
now.AddDate(0, 0, -5),
|
||||
now.AddDate(0, 0, 7),
|
||||
},
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
name: "all overdue",
|
||||
deadlines: []time.Time{
|
||||
now.AddDate(0, 0, -1),
|
||||
now.AddDate(0, 0, -7),
|
||||
now.AddDate(0, 0, -30),
|
||||
},
|
||||
expected: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
overdueCount := 0
|
||||
for _, deadline := range tt.deadlines {
|
||||
if deadline.Before(now) {
|
||||
overdueCount++
|
||||
}
|
||||
}
|
||||
|
||||
if overdueCount != tt.expected {
|
||||
t.Errorf("Expected %d overdue, got %d", tt.expected, overdueCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeadlineService_ProcessScheduledTasks tests scheduled task processing
|
||||
func TestDeadlineService_ProcessScheduledTasks(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
task string
|
||||
scheduledAt time.Time
|
||||
shouldProcess bool
|
||||
}{
|
||||
{
|
||||
name: "process due task",
|
||||
task: "send_reminder",
|
||||
scheduledAt: now.Add(-1 * time.Hour),
|
||||
shouldProcess: true,
|
||||
},
|
||||
{
|
||||
name: "skip future task",
|
||||
task: "send_reminder",
|
||||
scheduledAt: now.Add(1 * time.Hour),
|
||||
shouldProcess: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
shouldProcess := tt.scheduledAt.Before(now) || tt.scheduledAt.Equal(now)
|
||||
|
||||
if shouldProcess != tt.shouldProcess {
|
||||
t.Errorf("Expected shouldProcess=%v, got %v", tt.shouldProcess, shouldProcess)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func ptrTime(t time.Time) *time.Time {
|
||||
return &t
|
||||
}
|
||||
728
consent-service/internal/services/document_service_test.go
Normal file
728
consent-service/internal/services/document_service_test.go
Normal file
@@ -0,0 +1,728 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestDocumentService_CreateDocument tests creating a new legal document
|
||||
func TestDocumentService_CreateDocument(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
docType string
|
||||
docName string
|
||||
description string
|
||||
isMandatory bool
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "valid mandatory document",
|
||||
docType: "terms",
|
||||
docName: "Terms of Service",
|
||||
description: "Our terms and conditions",
|
||||
isMandatory: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid optional document",
|
||||
docType: "cookies",
|
||||
docName: "Cookie Policy",
|
||||
description: "How we use cookies",
|
||||
isMandatory: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty document type",
|
||||
docType: "",
|
||||
docName: "Test Document",
|
||||
description: "Test",
|
||||
isMandatory: true,
|
||||
expectError: true,
|
||||
errorContains: "type",
|
||||
},
|
||||
{
|
||||
name: "empty document name",
|
||||
docType: "privacy",
|
||||
docName: "",
|
||||
description: "Test",
|
||||
isMandatory: true,
|
||||
expectError: true,
|
||||
errorContains: "name",
|
||||
},
|
||||
{
|
||||
name: "invalid document type",
|
||||
docType: "invalid_type",
|
||||
docName: "Test",
|
||||
description: "Test",
|
||||
isMandatory: false,
|
||||
expectError: true,
|
||||
errorContains: "type",
|
||||
},
|
||||
}
|
||||
|
||||
validTypes := map[string]bool{
|
||||
"terms": true,
|
||||
"privacy": true,
|
||||
"cookies": true,
|
||||
"community_guidelines": true,
|
||||
"imprint": true,
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Validate inputs
|
||||
var err error
|
||||
if tt.docType == "" {
|
||||
err = &ValidationError{Field: "type", Message: "required"}
|
||||
} else if !validTypes[tt.docType] {
|
||||
err = &ValidationError{Field: "type", Message: "invalid document type"}
|
||||
} else if tt.docName == "" {
|
||||
err = &ValidationError{Field: "name", Message: "required"}
|
||||
}
|
||||
|
||||
// Assert
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error containing '%s', got nil", tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentService_UpdateDocument tests updating a document
|
||||
func TestDocumentService_UpdateDocument(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
documentID uuid.UUID
|
||||
newName string
|
||||
newActive bool
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid update",
|
||||
documentID: uuid.New(),
|
||||
newName: "Updated Name",
|
||||
newActive: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "deactivate document",
|
||||
documentID: uuid.New(),
|
||||
newName: "Test",
|
||||
newActive: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid document ID",
|
||||
documentID: uuid.Nil,
|
||||
newName: "Test",
|
||||
newActive: true,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.documentID == uuid.Nil {
|
||||
err = &ValidationError{Field: "document ID", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentService_CreateVersion tests creating a document version
|
||||
func TestDocumentService_CreateVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
documentID uuid.UUID
|
||||
version string
|
||||
language string
|
||||
title string
|
||||
content string
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "valid version - German",
|
||||
documentID: uuid.New(),
|
||||
version: "1.0.0",
|
||||
language: "de",
|
||||
title: "Nutzungsbedingungen",
|
||||
content: "<h1>Terms</h1><p>Content...</p>",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid version - English",
|
||||
documentID: uuid.New(),
|
||||
version: "1.0.0",
|
||||
language: "en",
|
||||
title: "Terms of Service",
|
||||
content: "<h1>Terms</h1><p>Content...</p>",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid version format",
|
||||
documentID: uuid.New(),
|
||||
version: "1.0",
|
||||
language: "de",
|
||||
title: "Test",
|
||||
content: "Content",
|
||||
expectError: true,
|
||||
errorContains: "version",
|
||||
},
|
||||
{
|
||||
name: "invalid language",
|
||||
documentID: uuid.New(),
|
||||
version: "1.0.0",
|
||||
language: "fr",
|
||||
title: "Test",
|
||||
content: "Content",
|
||||
expectError: true,
|
||||
errorContains: "language",
|
||||
},
|
||||
{
|
||||
name: "empty title",
|
||||
documentID: uuid.New(),
|
||||
version: "1.0.0",
|
||||
language: "de",
|
||||
title: "",
|
||||
content: "Content",
|
||||
expectError: true,
|
||||
errorContains: "title",
|
||||
},
|
||||
{
|
||||
name: "empty content",
|
||||
documentID: uuid.New(),
|
||||
version: "1.0.0",
|
||||
language: "de",
|
||||
title: "Test",
|
||||
content: "",
|
||||
expectError: true,
|
||||
errorContains: "content",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Validate semver format (X.Y.Z pattern)
|
||||
validVersion := regexp.MustCompile(`^\d+\.\d+\.\d+$`).MatchString(tt.version)
|
||||
validLanguage := tt.language == "de" || tt.language == "en"
|
||||
|
||||
var err error
|
||||
if !validVersion {
|
||||
err = &ValidationError{Field: "version", Message: "invalid format"}
|
||||
} else if !validLanguage {
|
||||
err = &ValidationError{Field: "language", Message: "must be 'de' or 'en'"}
|
||||
} else if tt.title == "" {
|
||||
err = &ValidationError{Field: "title", Message: "required"}
|
||||
} else if tt.content == "" {
|
||||
err = &ValidationError{Field: "content", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error containing '%s', got nil", tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentService_VersionStatusTransitions tests version status workflow
|
||||
func TestDocumentService_VersionStatusTransitions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fromStatus string
|
||||
toStatus string
|
||||
isAllowed bool
|
||||
}{
|
||||
// Valid transitions
|
||||
{"draft to review", "draft", "review", true},
|
||||
{"review to approved", "review", "approved", true},
|
||||
{"review to rejected", "review", "rejected", true},
|
||||
{"approved to published", "approved", "published", true},
|
||||
{"approved to scheduled", "approved", "scheduled", true},
|
||||
{"scheduled to published", "scheduled", "published", true},
|
||||
{"published to archived", "published", "archived", true},
|
||||
{"rejected to draft", "rejected", "draft", true},
|
||||
|
||||
// Invalid transitions
|
||||
{"draft to published", "draft", "published", false},
|
||||
{"draft to approved", "draft", "approved", false},
|
||||
{"review to published", "review", "published", false},
|
||||
{"published to draft", "published", "draft", false},
|
||||
{"published to review", "published", "review", false},
|
||||
{"archived to draft", "archived", "draft", false},
|
||||
{"archived to published", "archived", "published", false},
|
||||
}
|
||||
|
||||
// Define valid transitions
|
||||
validTransitions := map[string][]string{
|
||||
"draft": {"review"},
|
||||
"review": {"approved", "rejected"},
|
||||
"approved": {"published", "scheduled"},
|
||||
"scheduled": {"published"},
|
||||
"published": {"archived"},
|
||||
"rejected": {"draft"},
|
||||
"archived": {}, // terminal state
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Check if transition is allowed
|
||||
allowed := false
|
||||
if transitions, ok := validTransitions[tt.fromStatus]; ok {
|
||||
for _, validTo := range transitions {
|
||||
if validTo == tt.toStatus {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if allowed != tt.isAllowed {
|
||||
t.Errorf("Transition %s->%s: expected allowed=%v, got %v",
|
||||
tt.fromStatus, tt.toStatus, tt.isAllowed, allowed)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentService_PublishVersion tests publishing a version
|
||||
func TestDocumentService_PublishVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
versionID uuid.UUID
|
||||
currentStatus string
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "publish approved version",
|
||||
versionID: uuid.New(),
|
||||
currentStatus: "approved",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "publish scheduled version",
|
||||
versionID: uuid.New(),
|
||||
currentStatus: "scheduled",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "cannot publish draft",
|
||||
versionID: uuid.New(),
|
||||
currentStatus: "draft",
|
||||
expectError: true,
|
||||
errorContains: "draft",
|
||||
},
|
||||
{
|
||||
name: "cannot publish review",
|
||||
versionID: uuid.New(),
|
||||
currentStatus: "review",
|
||||
expectError: true,
|
||||
errorContains: "review",
|
||||
},
|
||||
{
|
||||
name: "invalid version ID",
|
||||
versionID: uuid.Nil,
|
||||
currentStatus: "approved",
|
||||
expectError: true,
|
||||
errorContains: "ID",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.versionID == uuid.Nil {
|
||||
err = &ValidationError{Field: "version ID", Message: "required"}
|
||||
} else if tt.currentStatus != "approved" && tt.currentStatus != "scheduled" {
|
||||
err = &ValidationError{Field: "status", Message: "only approved or scheduled versions can be published"}
|
||||
}
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error containing '%s', got nil", tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentService_ArchiveVersion tests archiving a version
|
||||
func TestDocumentService_ArchiveVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
versionID uuid.UUID
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "archive valid version",
|
||||
versionID: uuid.New(),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid version ID",
|
||||
versionID: uuid.Nil,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.versionID == uuid.Nil {
|
||||
err = &ValidationError{Field: "version ID", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentService_DeleteVersion tests deleting a version
|
||||
func TestDocumentService_DeleteVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
versionID uuid.UUID
|
||||
status string
|
||||
canDelete bool
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "delete draft version",
|
||||
versionID: uuid.New(),
|
||||
status: "draft",
|
||||
canDelete: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "delete rejected version",
|
||||
versionID: uuid.New(),
|
||||
status: "rejected",
|
||||
canDelete: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "cannot delete published version",
|
||||
versionID: uuid.New(),
|
||||
status: "published",
|
||||
canDelete: false,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "cannot delete approved version",
|
||||
versionID: uuid.New(),
|
||||
status: "approved",
|
||||
canDelete: false,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "cannot delete archived version",
|
||||
versionID: uuid.New(),
|
||||
status: "archived",
|
||||
canDelete: false,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Only draft and rejected can be deleted
|
||||
canDelete := tt.status == "draft" || tt.status == "rejected"
|
||||
|
||||
var err error
|
||||
if !canDelete {
|
||||
err = &ValidationError{Field: "status", Message: "only draft or rejected versions can be deleted"}
|
||||
}
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if canDelete != tt.canDelete {
|
||||
t.Errorf("Expected canDelete=%v, got %v", tt.canDelete, canDelete)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentService_GetLatestVersion tests retrieving the latest version
|
||||
func TestDocumentService_GetLatestVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
documentID uuid.UUID
|
||||
language string
|
||||
status string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "get latest German version",
|
||||
documentID: uuid.New(),
|
||||
language: "de",
|
||||
status: "published",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "get latest English version",
|
||||
documentID: uuid.New(),
|
||||
language: "en",
|
||||
status: "published",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid document ID",
|
||||
documentID: uuid.Nil,
|
||||
language: "de",
|
||||
status: "published",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.documentID == uuid.Nil {
|
||||
err = &ValidationError{Field: "document ID", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentService_CompareVersions tests version comparison
|
||||
func TestDocumentService_CompareVersions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
version1 string
|
||||
version2 string
|
||||
isDifferent bool
|
||||
}{
|
||||
{
|
||||
name: "same version",
|
||||
version1: "1.0.0",
|
||||
version2: "1.0.0",
|
||||
isDifferent: false,
|
||||
},
|
||||
{
|
||||
name: "different major version",
|
||||
version1: "2.0.0",
|
||||
version2: "1.0.0",
|
||||
isDifferent: true,
|
||||
},
|
||||
{
|
||||
name: "different minor version",
|
||||
version1: "1.1.0",
|
||||
version2: "1.0.0",
|
||||
isDifferent: true,
|
||||
},
|
||||
{
|
||||
name: "different patch version",
|
||||
version1: "1.0.1",
|
||||
version2: "1.0.0",
|
||||
isDifferent: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isDifferent := tt.version1 != tt.version2
|
||||
|
||||
if isDifferent != tt.isDifferent {
|
||||
t.Errorf("Expected isDifferent=%v, got %v", tt.isDifferent, isDifferent)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentService_ScheduledPublishing tests scheduled publishing
|
||||
func TestDocumentService_ScheduledPublishing(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
scheduledAt time.Time
|
||||
shouldPublish bool
|
||||
}{
|
||||
{
|
||||
name: "scheduled for past - should publish",
|
||||
scheduledAt: now.Add(-1 * time.Hour),
|
||||
shouldPublish: true,
|
||||
},
|
||||
{
|
||||
name: "scheduled for now - should publish",
|
||||
scheduledAt: now,
|
||||
shouldPublish: true,
|
||||
},
|
||||
{
|
||||
name: "scheduled for future - should not publish",
|
||||
scheduledAt: now.Add(1 * time.Hour),
|
||||
shouldPublish: false,
|
||||
},
|
||||
{
|
||||
name: "scheduled for tomorrow - should not publish",
|
||||
scheduledAt: now.AddDate(0, 0, 1),
|
||||
shouldPublish: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
shouldPublish := tt.scheduledAt.Before(now) || tt.scheduledAt.Equal(now)
|
||||
|
||||
if shouldPublish != tt.shouldPublish {
|
||||
t.Errorf("Expected shouldPublish=%v, got %v", tt.shouldPublish, shouldPublish)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentService_ApprovalWorkflow tests the approval workflow
|
||||
func TestDocumentService_ApprovalWorkflow(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
action string
|
||||
userRole string
|
||||
isAllowed bool
|
||||
}{
|
||||
// Admin permissions
|
||||
{"admin submit for review", "submit_review", "admin", true},
|
||||
{"admin cannot approve", "approve", "admin", false},
|
||||
{"admin can publish", "publish", "admin", true},
|
||||
|
||||
// DSB permissions
|
||||
{"dsb can approve", "approve", "data_protection_officer", true},
|
||||
{"dsb can reject", "reject", "data_protection_officer", true},
|
||||
{"dsb can publish", "publish", "data_protection_officer", true},
|
||||
|
||||
// User permissions
|
||||
{"user cannot submit", "submit_review", "user", false},
|
||||
{"user cannot approve", "approve", "user", false},
|
||||
{"user cannot publish", "publish", "user", false},
|
||||
}
|
||||
|
||||
permissions := map[string]map[string]bool{
|
||||
"admin": {
|
||||
"submit_review": true,
|
||||
"approve": false,
|
||||
"reject": false,
|
||||
"publish": true,
|
||||
},
|
||||
"data_protection_officer": {
|
||||
"submit_review": true,
|
||||
"approve": true,
|
||||
"reject": true,
|
||||
"publish": true,
|
||||
},
|
||||
"user": {
|
||||
"submit_review": false,
|
||||
"approve": false,
|
||||
"reject": false,
|
||||
"publish": false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rolePerms, ok := permissions[tt.userRole]
|
||||
if !ok {
|
||||
t.Fatalf("Unknown role: %s", tt.userRole)
|
||||
}
|
||||
|
||||
isAllowed := rolePerms[tt.action]
|
||||
|
||||
if isAllowed != tt.isAllowed {
|
||||
t.Errorf("Role %s action %s: expected allowed=%v, got %v",
|
||||
tt.userRole, tt.action, tt.isAllowed, isAllowed)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentService_FourEyesPrinciple tests the four-eyes principle
|
||||
func TestDocumentService_FourEyesPrinciple(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
createdBy uuid.UUID
|
||||
approver uuid.UUID
|
||||
approverRole string
|
||||
canApprove bool
|
||||
}{
|
||||
{
|
||||
name: "different users - DSB can approve",
|
||||
createdBy: uuid.New(),
|
||||
approver: uuid.New(),
|
||||
approverRole: "data_protection_officer",
|
||||
canApprove: true,
|
||||
},
|
||||
{
|
||||
name: "same user - DSB cannot approve own",
|
||||
createdBy: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"),
|
||||
approver: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"),
|
||||
approverRole: "data_protection_officer",
|
||||
canApprove: false,
|
||||
},
|
||||
{
|
||||
name: "same user - admin CAN approve own (exception)",
|
||||
createdBy: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"),
|
||||
approver: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"),
|
||||
approverRole: "admin",
|
||||
canApprove: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Four-eyes principle: DSB cannot approve their own work
|
||||
// Exception: Admins can (for development/testing)
|
||||
canApprove := tt.createdBy != tt.approver || tt.approverRole == "admin"
|
||||
|
||||
if canApprove != tt.canApprove {
|
||||
t.Errorf("Expected canApprove=%v, got %v", tt.canApprove, canApprove)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
947
consent-service/internal/services/dsr_service.go
Normal file
947
consent-service/internal/services/dsr_service.go
Normal file
@@ -0,0 +1,947 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// DSRService handles Data Subject Request business logic
|
||||
type DSRService struct {
|
||||
pool *pgxpool.Pool
|
||||
notificationService *NotificationService
|
||||
emailService *EmailService
|
||||
}
|
||||
|
||||
// NewDSRService creates a new DSRService
|
||||
func NewDSRService(pool *pgxpool.Pool, notificationService *NotificationService, emailService *EmailService) *DSRService {
|
||||
return &DSRService{
|
||||
pool: pool,
|
||||
notificationService: notificationService,
|
||||
emailService: emailService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetPool returns the database pool for direct queries
|
||||
func (s *DSRService) GetPool() *pgxpool.Pool {
|
||||
return s.pool
|
||||
}
|
||||
|
||||
// generateRequestNumber generates a unique request number like DSR-2025-000001
|
||||
func (s *DSRService) generateRequestNumber(ctx context.Context) (string, error) {
|
||||
var seqNum int64
|
||||
err := s.pool.QueryRow(ctx, "SELECT nextval('dsr_request_number_seq')").Scan(&seqNum)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get next sequence number: %w", err)
|
||||
}
|
||||
year := time.Now().Year()
|
||||
return fmt.Sprintf("DSR-%d-%06d", year, seqNum), nil
|
||||
}
|
||||
|
||||
// CreateRequest creates a new data subject request
|
||||
func (s *DSRService) CreateRequest(ctx context.Context, req models.CreateDSRRequest, createdBy *uuid.UUID) (*models.DataSubjectRequest, error) {
|
||||
// Validate request type
|
||||
requestType := models.DSRRequestType(req.RequestType)
|
||||
if !isValidRequestType(requestType) {
|
||||
return nil, fmt.Errorf("invalid request type: %s", req.RequestType)
|
||||
}
|
||||
|
||||
// Generate request number
|
||||
requestNumber, err := s.generateRequestNumber(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Calculate deadline
|
||||
deadlineDays := requestType.DeadlineDays()
|
||||
deadline := time.Now().AddDate(0, 0, deadlineDays)
|
||||
|
||||
// Determine priority
|
||||
priority := models.DSRPriorityNormal
|
||||
if req.Priority != "" {
|
||||
priority = models.DSRPriority(req.Priority)
|
||||
} else if requestType.IsExpedited() {
|
||||
priority = models.DSRPriorityExpedited
|
||||
}
|
||||
|
||||
// Determine source
|
||||
source := models.DSRSourceAPI
|
||||
if req.Source != "" {
|
||||
source = models.DSRSource(req.Source)
|
||||
}
|
||||
|
||||
// Serialize request details
|
||||
detailsJSON, err := json.Marshal(req.RequestDetails)
|
||||
if err != nil {
|
||||
detailsJSON = []byte("{}")
|
||||
}
|
||||
|
||||
// Try to find existing user by email
|
||||
var userID *uuid.UUID
|
||||
var foundUserID uuid.UUID
|
||||
err = s.pool.QueryRow(ctx, "SELECT id FROM users WHERE email = $1", req.RequesterEmail).Scan(&foundUserID)
|
||||
if err == nil {
|
||||
userID = &foundUserID
|
||||
}
|
||||
|
||||
// Insert request
|
||||
var dsr models.DataSubjectRequest
|
||||
err = s.pool.QueryRow(ctx, `
|
||||
INSERT INTO data_subject_requests (
|
||||
user_id, request_number, request_type, status, priority, source,
|
||||
requester_email, requester_name, requester_phone,
|
||||
request_details, deadline_at, legal_deadline_days, created_by
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
|
||||
RETURNING id, user_id, request_number, request_type, status, priority, source,
|
||||
requester_email, requester_name, requester_phone, identity_verified,
|
||||
request_details, deadline_at, legal_deadline_days, created_at, updated_at, created_by
|
||||
`, userID, requestNumber, requestType, models.DSRStatusIntake, priority, source,
|
||||
req.RequesterEmail, req.RequesterName, req.RequesterPhone,
|
||||
detailsJSON, deadline, deadlineDays, createdBy,
|
||||
).Scan(
|
||||
&dsr.ID, &dsr.UserID, &dsr.RequestNumber, &dsr.RequestType, &dsr.Status,
|
||||
&dsr.Priority, &dsr.Source, &dsr.RequesterEmail, &dsr.RequesterName,
|
||||
&dsr.RequesterPhone, &dsr.IdentityVerified, &detailsJSON,
|
||||
&dsr.DeadlineAt, &dsr.LegalDeadlineDays, &dsr.CreatedAt, &dsr.UpdatedAt, &dsr.CreatedBy,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create DSR: %w", err)
|
||||
}
|
||||
|
||||
// Parse details back
|
||||
json.Unmarshal(detailsJSON, &dsr.RequestDetails)
|
||||
|
||||
// Record initial status
|
||||
s.recordStatusChange(ctx, dsr.ID, nil, models.DSRStatusIntake, createdBy, "Anfrage eingegangen")
|
||||
|
||||
// Notify DPOs about new request
|
||||
go s.notifyNewRequest(context.Background(), &dsr)
|
||||
|
||||
return &dsr, nil
|
||||
}
|
||||
|
||||
// GetByID retrieves a DSR by ID
|
||||
func (s *DSRService) GetByID(ctx context.Context, id uuid.UUID) (*models.DataSubjectRequest, error) {
|
||||
var dsr models.DataSubjectRequest
|
||||
var detailsJSON, resultDataJSON []byte
|
||||
|
||||
err := s.pool.QueryRow(ctx, `
|
||||
SELECT id, user_id, request_number, request_type, status, priority, source,
|
||||
requester_email, requester_name, requester_phone,
|
||||
identity_verified, identity_verified_at, identity_verified_by, identity_verification_method,
|
||||
request_details, deadline_at, legal_deadline_days, extended_deadline_at, extension_reason,
|
||||
assigned_to, processing_notes, completed_at, completed_by, result_summary, result_data,
|
||||
rejected_at, rejected_by, rejection_reason, rejection_legal_basis,
|
||||
created_at, updated_at, created_by
|
||||
FROM data_subject_requests WHERE id = $1
|
||||
`, id).Scan(
|
||||
&dsr.ID, &dsr.UserID, &dsr.RequestNumber, &dsr.RequestType, &dsr.Status,
|
||||
&dsr.Priority, &dsr.Source, &dsr.RequesterEmail, &dsr.RequesterName,
|
||||
&dsr.RequesterPhone, &dsr.IdentityVerified, &dsr.IdentityVerifiedAt,
|
||||
&dsr.IdentityVerifiedBy, &dsr.IdentityVerificationMethod,
|
||||
&detailsJSON, &dsr.DeadlineAt, &dsr.LegalDeadlineDays,
|
||||
&dsr.ExtendedDeadlineAt, &dsr.ExtensionReason, &dsr.AssignedTo,
|
||||
&dsr.ProcessingNotes, &dsr.CompletedAt, &dsr.CompletedBy,
|
||||
&dsr.ResultSummary, &resultDataJSON, &dsr.RejectedAt, &dsr.RejectedBy,
|
||||
&dsr.RejectionReason, &dsr.RejectionLegalBasis,
|
||||
&dsr.CreatedAt, &dsr.UpdatedAt, &dsr.CreatedBy,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DSR not found: %w", err)
|
||||
}
|
||||
|
||||
json.Unmarshal(detailsJSON, &dsr.RequestDetails)
|
||||
json.Unmarshal(resultDataJSON, &dsr.ResultData)
|
||||
|
||||
return &dsr, nil
|
||||
}
|
||||
|
||||
// GetByNumber retrieves a DSR by request number
|
||||
func (s *DSRService) GetByNumber(ctx context.Context, requestNumber string) (*models.DataSubjectRequest, error) {
|
||||
var id uuid.UUID
|
||||
err := s.pool.QueryRow(ctx, "SELECT id FROM data_subject_requests WHERE request_number = $1", requestNumber).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DSR not found: %w", err)
|
||||
}
|
||||
return s.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// List retrieves DSRs with filters and pagination
|
||||
func (s *DSRService) List(ctx context.Context, filters models.DSRListFilters, limit, offset int) ([]models.DataSubjectRequest, int, error) {
|
||||
// Build query
|
||||
baseQuery := "FROM data_subject_requests WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
argIndex := 1
|
||||
|
||||
if filters.Status != nil && *filters.Status != "" {
|
||||
baseQuery += fmt.Sprintf(" AND status = $%d", argIndex)
|
||||
args = append(args, *filters.Status)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filters.RequestType != nil && *filters.RequestType != "" {
|
||||
baseQuery += fmt.Sprintf(" AND request_type = $%d", argIndex)
|
||||
args = append(args, *filters.RequestType)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filters.AssignedTo != nil && *filters.AssignedTo != "" {
|
||||
baseQuery += fmt.Sprintf(" AND assigned_to = $%d", argIndex)
|
||||
args = append(args, *filters.AssignedTo)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filters.Priority != nil && *filters.Priority != "" {
|
||||
baseQuery += fmt.Sprintf(" AND priority = $%d", argIndex)
|
||||
args = append(args, *filters.Priority)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filters.OverdueOnly {
|
||||
baseQuery += " AND deadline_at < NOW() AND status NOT IN ('completed', 'rejected', 'cancelled')"
|
||||
}
|
||||
|
||||
if filters.FromDate != nil {
|
||||
baseQuery += fmt.Sprintf(" AND created_at >= $%d", argIndex)
|
||||
args = append(args, *filters.FromDate)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filters.ToDate != nil {
|
||||
baseQuery += fmt.Sprintf(" AND created_at <= $%d", argIndex)
|
||||
args = append(args, *filters.ToDate)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filters.Search != nil && *filters.Search != "" {
|
||||
searchPattern := "%" + *filters.Search + "%"
|
||||
baseQuery += fmt.Sprintf(" AND (request_number ILIKE $%d OR requester_email ILIKE $%d OR requester_name ILIKE $%d)", argIndex, argIndex, argIndex)
|
||||
args = append(args, searchPattern)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
// Get total count
|
||||
var total int
|
||||
err := s.pool.QueryRow(ctx, "SELECT COUNT(*) "+baseQuery, args...).Scan(&total)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to count DSRs: %w", err)
|
||||
}
|
||||
|
||||
// Get paginated results
|
||||
query := fmt.Sprintf(`
|
||||
SELECT id, user_id, request_number, request_type, status, priority, source,
|
||||
requester_email, requester_name, requester_phone, identity_verified,
|
||||
deadline_at, legal_deadline_days, assigned_to, created_at, updated_at
|
||||
%s ORDER BY created_at DESC LIMIT $%d OFFSET $%d
|
||||
`, baseQuery, argIndex, argIndex+1)
|
||||
args = append(args, limit, offset)
|
||||
|
||||
rows, err := s.pool.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to query DSRs: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var dsrs []models.DataSubjectRequest
|
||||
for rows.Next() {
|
||||
var dsr models.DataSubjectRequest
|
||||
err := rows.Scan(
|
||||
&dsr.ID, &dsr.UserID, &dsr.RequestNumber, &dsr.RequestType, &dsr.Status,
|
||||
&dsr.Priority, &dsr.Source, &dsr.RequesterEmail, &dsr.RequesterName,
|
||||
&dsr.RequesterPhone, &dsr.IdentityVerified, &dsr.DeadlineAt,
|
||||
&dsr.LegalDeadlineDays, &dsr.AssignedTo, &dsr.CreatedAt, &dsr.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to scan DSR: %w", err)
|
||||
}
|
||||
dsrs = append(dsrs, dsr)
|
||||
}
|
||||
|
||||
return dsrs, total, nil
|
||||
}
|
||||
|
||||
// ListByUser retrieves DSRs for a specific user
|
||||
func (s *DSRService) ListByUser(ctx context.Context, userID uuid.UUID) ([]models.DataSubjectRequest, error) {
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT id, user_id, request_number, request_type, status, priority, source,
|
||||
requester_email, requester_name, deadline_at, created_at, updated_at
|
||||
FROM data_subject_requests
|
||||
WHERE user_id = $1 OR requester_email = (SELECT email FROM users WHERE id = $1)
|
||||
ORDER BY created_at DESC
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query user DSRs: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var dsrs []models.DataSubjectRequest
|
||||
for rows.Next() {
|
||||
var dsr models.DataSubjectRequest
|
||||
err := rows.Scan(
|
||||
&dsr.ID, &dsr.UserID, &dsr.RequestNumber, &dsr.RequestType, &dsr.Status,
|
||||
&dsr.Priority, &dsr.Source, &dsr.RequesterEmail, &dsr.RequesterName,
|
||||
&dsr.DeadlineAt, &dsr.CreatedAt, &dsr.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan DSR: %w", err)
|
||||
}
|
||||
dsrs = append(dsrs, dsr)
|
||||
}
|
||||
|
||||
return dsrs, nil
|
||||
}
|
||||
|
||||
// UpdateStatus changes the status of a DSR
|
||||
func (s *DSRService) UpdateStatus(ctx context.Context, id uuid.UUID, newStatus models.DSRStatus, comment string, changedBy *uuid.UUID) error {
|
||||
// Get current status
|
||||
var currentStatus models.DSRStatus
|
||||
err := s.pool.QueryRow(ctx, "SELECT status FROM data_subject_requests WHERE id = $1", id).Scan(¤tStatus)
|
||||
if err != nil {
|
||||
return fmt.Errorf("DSR not found: %w", err)
|
||||
}
|
||||
|
||||
// Validate transition
|
||||
if !isValidStatusTransition(currentStatus, newStatus) {
|
||||
return fmt.Errorf("invalid status transition from %s to %s", currentStatus, newStatus)
|
||||
}
|
||||
|
||||
// Update status
|
||||
_, err = s.pool.Exec(ctx, `
|
||||
UPDATE data_subject_requests SET status = $1, updated_at = NOW() WHERE id = $2
|
||||
`, newStatus, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update status: %w", err)
|
||||
}
|
||||
|
||||
// Record status change
|
||||
s.recordStatusChange(ctx, id, ¤tStatus, newStatus, changedBy, comment)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyIdentity marks identity as verified
|
||||
func (s *DSRService) VerifyIdentity(ctx context.Context, id uuid.UUID, method string, verifiedBy uuid.UUID) error {
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE data_subject_requests
|
||||
SET identity_verified = TRUE,
|
||||
identity_verified_at = NOW(),
|
||||
identity_verified_by = $1,
|
||||
identity_verification_method = $2,
|
||||
status = CASE WHEN status = 'intake' THEN 'identity_verification' ELSE status END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $3
|
||||
`, verifiedBy, method, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to verify identity: %w", err)
|
||||
}
|
||||
|
||||
s.recordStatusChange(ctx, id, nil, models.DSRStatusIdentityVerification, &verifiedBy, "Identität verifiziert via "+method)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AssignRequest assigns a DSR to a handler
|
||||
func (s *DSRService) AssignRequest(ctx context.Context, id uuid.UUID, assigneeID uuid.UUID, assignedBy uuid.UUID) error {
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE data_subject_requests SET assigned_to = $1, updated_at = NOW() WHERE id = $2
|
||||
`, assigneeID, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to assign DSR: %w", err)
|
||||
}
|
||||
|
||||
// Get assignee name for comment
|
||||
var assigneeName string
|
||||
s.pool.QueryRow(ctx, "SELECT COALESCE(name, email) FROM users WHERE id = $1", assigneeID).Scan(&assigneeName)
|
||||
|
||||
s.recordStatusChange(ctx, id, nil, "", &assignedBy, "Zugewiesen an "+assigneeName)
|
||||
|
||||
// Notify assignee
|
||||
go s.notifyAssignment(context.Background(), id, assigneeID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExtendDeadline extends the deadline for a DSR
|
||||
func (s *DSRService) ExtendDeadline(ctx context.Context, id uuid.UUID, reason string, days int, extendedBy uuid.UUID) error {
|
||||
// Default extension is 2 months (60 days) per Art. 12(3)
|
||||
if days <= 0 {
|
||||
days = 60
|
||||
}
|
||||
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE data_subject_requests
|
||||
SET extended_deadline_at = deadline_at + ($1 || ' days')::INTERVAL,
|
||||
extension_reason = $2,
|
||||
updated_at = NOW()
|
||||
WHERE id = $3
|
||||
`, days, reason, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to extend deadline: %w", err)
|
||||
}
|
||||
|
||||
s.recordStatusChange(ctx, id, nil, "", &extendedBy, fmt.Sprintf("Frist um %d Tage verlängert: %s", days, reason))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CompleteRequest marks a DSR as completed
|
||||
func (s *DSRService) CompleteRequest(ctx context.Context, id uuid.UUID, summary string, resultData map[string]interface{}, completedBy uuid.UUID) error {
|
||||
resultJSON, _ := json.Marshal(resultData)
|
||||
|
||||
// Get current status
|
||||
var currentStatus models.DSRStatus
|
||||
s.pool.QueryRow(ctx, "SELECT status FROM data_subject_requests WHERE id = $1", id).Scan(¤tStatus)
|
||||
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE data_subject_requests
|
||||
SET status = 'completed',
|
||||
completed_at = NOW(),
|
||||
completed_by = $1,
|
||||
result_summary = $2,
|
||||
result_data = $3,
|
||||
updated_at = NOW()
|
||||
WHERE id = $4
|
||||
`, completedBy, summary, resultJSON, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to complete DSR: %w", err)
|
||||
}
|
||||
|
||||
s.recordStatusChange(ctx, id, ¤tStatus, models.DSRStatusCompleted, &completedBy, summary)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RejectRequest rejects a DSR with legal basis
|
||||
func (s *DSRService) RejectRequest(ctx context.Context, id uuid.UUID, reason, legalBasis string, rejectedBy uuid.UUID) error {
|
||||
// Get current status
|
||||
var currentStatus models.DSRStatus
|
||||
s.pool.QueryRow(ctx, "SELECT status FROM data_subject_requests WHERE id = $1", id).Scan(¤tStatus)
|
||||
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE data_subject_requests
|
||||
SET status = 'rejected',
|
||||
rejected_at = NOW(),
|
||||
rejected_by = $1,
|
||||
rejection_reason = $2,
|
||||
rejection_legal_basis = $3,
|
||||
updated_at = NOW()
|
||||
WHERE id = $4
|
||||
`, rejectedBy, reason, legalBasis, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to reject DSR: %w", err)
|
||||
}
|
||||
|
||||
s.recordStatusChange(ctx, id, ¤tStatus, models.DSRStatusRejected, &rejectedBy, fmt.Sprintf("Abgelehnt (%s): %s", legalBasis, reason))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CancelRequest cancels a DSR (by user)
|
||||
func (s *DSRService) CancelRequest(ctx context.Context, id uuid.UUID, cancelledBy uuid.UUID) error {
|
||||
// Verify ownership
|
||||
var userID *uuid.UUID
|
||||
err := s.pool.QueryRow(ctx, "SELECT user_id FROM data_subject_requests WHERE id = $1", id).Scan(&userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("DSR not found: %w", err)
|
||||
}
|
||||
if userID == nil || *userID != cancelledBy {
|
||||
return fmt.Errorf("unauthorized: can only cancel own requests")
|
||||
}
|
||||
|
||||
// Get current status
|
||||
var currentStatus models.DSRStatus
|
||||
s.pool.QueryRow(ctx, "SELECT status FROM data_subject_requests WHERE id = $1", id).Scan(¤tStatus)
|
||||
|
||||
_, err = s.pool.Exec(ctx, `
|
||||
UPDATE data_subject_requests SET status = 'cancelled', updated_at = NOW() WHERE id = $1
|
||||
`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to cancel DSR: %w", err)
|
||||
}
|
||||
|
||||
s.recordStatusChange(ctx, id, ¤tStatus, models.DSRStatusCancelled, &cancelledBy, "Vom Antragsteller storniert")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDashboardStats returns statistics for the admin dashboard
|
||||
func (s *DSRService) GetDashboardStats(ctx context.Context) (*models.DSRDashboardStats, error) {
|
||||
stats := &models.DSRDashboardStats{
|
||||
ByType: make(map[string]int),
|
||||
ByStatus: make(map[string]int),
|
||||
}
|
||||
|
||||
// Total requests
|
||||
s.pool.QueryRow(ctx, "SELECT COUNT(*) FROM data_subject_requests").Scan(&stats.TotalRequests)
|
||||
|
||||
// Pending requests (not completed, rejected, or cancelled)
|
||||
s.pool.QueryRow(ctx, `
|
||||
SELECT COUNT(*) FROM data_subject_requests
|
||||
WHERE status NOT IN ('completed', 'rejected', 'cancelled')
|
||||
`).Scan(&stats.PendingRequests)
|
||||
|
||||
// Overdue requests
|
||||
s.pool.QueryRow(ctx, `
|
||||
SELECT COUNT(*) FROM data_subject_requests
|
||||
WHERE COALESCE(extended_deadline_at, deadline_at) < NOW()
|
||||
AND status NOT IN ('completed', 'rejected', 'cancelled')
|
||||
`).Scan(&stats.OverdueRequests)
|
||||
|
||||
// Completed this month
|
||||
s.pool.QueryRow(ctx, `
|
||||
SELECT COUNT(*) FROM data_subject_requests
|
||||
WHERE status = 'completed'
|
||||
AND completed_at >= DATE_TRUNC('month', NOW())
|
||||
`).Scan(&stats.CompletedThisMonth)
|
||||
|
||||
// Average processing days
|
||||
s.pool.QueryRow(ctx, `
|
||||
SELECT COALESCE(AVG(EXTRACT(EPOCH FROM (completed_at - created_at)) / 86400), 0)
|
||||
FROM data_subject_requests WHERE status = 'completed'
|
||||
`).Scan(&stats.AverageProcessingDays)
|
||||
|
||||
// Count by type
|
||||
rows, _ := s.pool.Query(ctx, `
|
||||
SELECT request_type, COUNT(*) FROM data_subject_requests GROUP BY request_type
|
||||
`)
|
||||
for rows.Next() {
|
||||
var t string
|
||||
var count int
|
||||
rows.Scan(&t, &count)
|
||||
stats.ByType[t] = count
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
// Count by status
|
||||
rows, _ = s.pool.Query(ctx, `
|
||||
SELECT status, COUNT(*) FROM data_subject_requests GROUP BY status
|
||||
`)
|
||||
for rows.Next() {
|
||||
var s string
|
||||
var count int
|
||||
rows.Scan(&s, &count)
|
||||
stats.ByStatus[s] = count
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
// Upcoming deadlines (next 7 days)
|
||||
rows, _ = s.pool.Query(ctx, `
|
||||
SELECT id, request_number, request_type, status, requester_email, deadline_at
|
||||
FROM data_subject_requests
|
||||
WHERE COALESCE(extended_deadline_at, deadline_at) BETWEEN NOW() AND NOW() + INTERVAL '7 days'
|
||||
AND status NOT IN ('completed', 'rejected', 'cancelled')
|
||||
ORDER BY deadline_at ASC LIMIT 10
|
||||
`)
|
||||
for rows.Next() {
|
||||
var dsr models.DataSubjectRequest
|
||||
rows.Scan(&dsr.ID, &dsr.RequestNumber, &dsr.RequestType, &dsr.Status, &dsr.RequesterEmail, &dsr.DeadlineAt)
|
||||
stats.UpcomingDeadlines = append(stats.UpcomingDeadlines, dsr)
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetStatusHistory retrieves the status history for a DSR
|
||||
func (s *DSRService) GetStatusHistory(ctx context.Context, requestID uuid.UUID) ([]models.DSRStatusHistory, error) {
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT id, request_id, from_status, to_status, changed_by, comment, metadata, created_at
|
||||
FROM dsr_status_history WHERE request_id = $1 ORDER BY created_at DESC
|
||||
`, requestID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query status history: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var history []models.DSRStatusHistory
|
||||
for rows.Next() {
|
||||
var h models.DSRStatusHistory
|
||||
var metadataJSON []byte
|
||||
err := rows.Scan(&h.ID, &h.RequestID, &h.FromStatus, &h.ToStatus, &h.ChangedBy, &h.Comment, &metadataJSON, &h.CreatedAt)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
json.Unmarshal(metadataJSON, &h.Metadata)
|
||||
history = append(history, h)
|
||||
}
|
||||
|
||||
return history, nil
|
||||
}
|
||||
|
||||
// GetCommunications retrieves communications for a DSR
|
||||
func (s *DSRService) GetCommunications(ctx context.Context, requestID uuid.UUID) ([]models.DSRCommunication, error) {
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT id, request_id, direction, channel, communication_type, template_version_id,
|
||||
subject, body_html, body_text, recipient_email, sent_at, error_message,
|
||||
attachments, created_at, created_by
|
||||
FROM dsr_communications WHERE request_id = $1 ORDER BY created_at DESC
|
||||
`, requestID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query communications: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var comms []models.DSRCommunication
|
||||
for rows.Next() {
|
||||
var c models.DSRCommunication
|
||||
var attachmentsJSON []byte
|
||||
err := rows.Scan(&c.ID, &c.RequestID, &c.Direction, &c.Channel, &c.CommunicationType,
|
||||
&c.TemplateVersionID, &c.Subject, &c.BodyHTML, &c.BodyText, &c.RecipientEmail,
|
||||
&c.SentAt, &c.ErrorMessage, &attachmentsJSON, &c.CreatedAt, &c.CreatedBy)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
json.Unmarshal(attachmentsJSON, &c.Attachments)
|
||||
comms = append(comms, c)
|
||||
}
|
||||
|
||||
return comms, nil
|
||||
}
|
||||
|
||||
// SendCommunication sends a communication for a DSR
|
||||
func (s *DSRService) SendCommunication(ctx context.Context, requestID uuid.UUID, req models.SendDSRCommunicationRequest, sentBy uuid.UUID) error {
|
||||
// Get DSR details
|
||||
dsr, err := s.GetByID(ctx, requestID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get template if specified
|
||||
var subject, bodyHTML, bodyText string
|
||||
if req.TemplateVersionID != nil {
|
||||
templateVersionID, _ := uuid.Parse(*req.TemplateVersionID)
|
||||
err := s.pool.QueryRow(ctx, `
|
||||
SELECT subject, body_html, body_text FROM dsr_template_versions WHERE id = $1 AND status = 'published'
|
||||
`, templateVersionID).Scan(&subject, &bodyHTML, &bodyText)
|
||||
if err != nil {
|
||||
return fmt.Errorf("template version not found or not published: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Use custom content if provided
|
||||
if req.CustomSubject != nil {
|
||||
subject = *req.CustomSubject
|
||||
}
|
||||
if req.CustomBody != nil {
|
||||
bodyHTML = *req.CustomBody
|
||||
bodyText = stripHTML(*req.CustomBody)
|
||||
}
|
||||
|
||||
// Replace variables
|
||||
variables := map[string]string{
|
||||
"requester_name": stringOrDefault(dsr.RequesterName, "Antragsteller/in"),
|
||||
"request_number": dsr.RequestNumber,
|
||||
"request_type_de": dsr.RequestType.Label(),
|
||||
"request_date": dsr.CreatedAt.Format("02.01.2006"),
|
||||
"deadline_date": dsr.DeadlineAt.Format("02.01.2006"),
|
||||
}
|
||||
for k, v := range req.Variables {
|
||||
variables[k] = v
|
||||
}
|
||||
subject = replaceVariables(subject, variables)
|
||||
bodyHTML = replaceVariables(bodyHTML, variables)
|
||||
bodyText = replaceVariables(bodyText, variables)
|
||||
|
||||
// Send email
|
||||
if s.emailService != nil {
|
||||
err = s.emailService.SendEmail(dsr.RequesterEmail, subject, bodyHTML, bodyText)
|
||||
if err != nil {
|
||||
// Log error but continue
|
||||
_, _ = s.pool.Exec(ctx, `
|
||||
INSERT INTO dsr_communications (request_id, direction, channel, communication_type,
|
||||
template_version_id, subject, body_html, body_text, recipient_email, error_message, created_by)
|
||||
VALUES ($1, 'outbound', 'email', $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
`, requestID, req.CommunicationType, req.TemplateVersionID, subject, bodyHTML, bodyText,
|
||||
dsr.RequesterEmail, err.Error(), sentBy)
|
||||
return fmt.Errorf("failed to send email: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Log communication
|
||||
now := time.Now()
|
||||
_, err = s.pool.Exec(ctx, `
|
||||
INSERT INTO dsr_communications (request_id, direction, channel, communication_type,
|
||||
template_version_id, subject, body_html, body_text, recipient_email, sent_at, created_by)
|
||||
VALUES ($1, 'outbound', 'email', $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
`, requestID, req.CommunicationType, req.TemplateVersionID, subject, bodyHTML, bodyText,
|
||||
dsr.RequesterEmail, now, sentBy)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// InitErasureExceptionChecks initializes exception checks for an erasure request
|
||||
func (s *DSRService) InitErasureExceptionChecks(ctx context.Context, requestID uuid.UUID) error {
|
||||
exceptions := []struct {
|
||||
Type string
|
||||
Description string
|
||||
}{
|
||||
{models.DSRExceptionFreedomExpression, "Ausübung des Rechts auf freie Meinungsäußerung und Information (Art. 17 Abs. 3 lit. a)"},
|
||||
{models.DSRExceptionLegalObligation, "Erfüllung einer rechtlichen Verpflichtung oder öffentlichen Aufgabe (Art. 17 Abs. 3 lit. b)"},
|
||||
{models.DSRExceptionPublicHealth, "Gründe des öffentlichen Interesses im Bereich der öffentlichen Gesundheit (Art. 17 Abs. 3 lit. c)"},
|
||||
{models.DSRExceptionArchiving, "Im öffentlichen Interesse liegende Archivzwecke, Forschung oder Statistik (Art. 17 Abs. 3 lit. d)"},
|
||||
{models.DSRExceptionLegalClaims, "Geltendmachung, Ausübung oder Verteidigung von Rechtsansprüchen (Art. 17 Abs. 3 lit. e)"},
|
||||
}
|
||||
|
||||
for _, exc := range exceptions {
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
INSERT INTO dsr_exception_checks (request_id, exception_type, description)
|
||||
VALUES ($1, $2, $3) ON CONFLICT DO NOTHING
|
||||
`, requestID, exc.Type, exc.Description)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create exception check: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetExceptionChecks retrieves exception checks for a DSR
|
||||
func (s *DSRService) GetExceptionChecks(ctx context.Context, requestID uuid.UUID) ([]models.DSRExceptionCheck, error) {
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT id, request_id, exception_type, description, applies, checked_by, checked_at, notes, created_at
|
||||
FROM dsr_exception_checks WHERE request_id = $1 ORDER BY created_at
|
||||
`, requestID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query exception checks: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var checks []models.DSRExceptionCheck
|
||||
for rows.Next() {
|
||||
var c models.DSRExceptionCheck
|
||||
err := rows.Scan(&c.ID, &c.RequestID, &c.ExceptionType, &c.Description, &c.Applies,
|
||||
&c.CheckedBy, &c.CheckedAt, &c.Notes, &c.CreatedAt)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
checks = append(checks, c)
|
||||
}
|
||||
|
||||
return checks, nil
|
||||
}
|
||||
|
||||
// UpdateExceptionCheck updates an exception check
|
||||
func (s *DSRService) UpdateExceptionCheck(ctx context.Context, checkID uuid.UUID, applies bool, notes *string, checkedBy uuid.UUID) error {
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE dsr_exception_checks
|
||||
SET applies = $1, notes = $2, checked_by = $3, checked_at = NOW()
|
||||
WHERE id = $4
|
||||
`, applies, notes, checkedBy, checkID)
|
||||
return err
|
||||
}
|
||||
|
||||
// ProcessDeadlines checks for approaching and overdue deadlines
|
||||
func (s *DSRService) ProcessDeadlines(ctx context.Context) error {
|
||||
now := time.Now()
|
||||
|
||||
// Find requests with deadlines in 3 days
|
||||
threeDaysAhead := now.AddDate(0, 0, 3)
|
||||
rows, _ := s.pool.Query(ctx, `
|
||||
SELECT id, request_number, request_type, assigned_to, deadline_at
|
||||
FROM data_subject_requests
|
||||
WHERE COALESCE(extended_deadline_at, deadline_at) BETWEEN $1 AND $2
|
||||
AND status NOT IN ('completed', 'rejected', 'cancelled')
|
||||
`, now, threeDaysAhead)
|
||||
for rows.Next() {
|
||||
var id uuid.UUID
|
||||
var requestNumber, requestType string
|
||||
var assignedTo *uuid.UUID
|
||||
var deadline time.Time
|
||||
rows.Scan(&id, &requestNumber, &requestType, &assignedTo, &deadline)
|
||||
|
||||
// Notify assigned user or all DPOs
|
||||
if assignedTo != nil {
|
||||
s.notifyDeadlineWarning(ctx, id, *assignedTo, requestNumber, deadline, 3)
|
||||
} else {
|
||||
s.notifyAllDPOs(ctx, id, requestNumber, "Frist in 3 Tagen", deadline)
|
||||
}
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
// Find requests with deadlines in 1 day
|
||||
oneDayAhead := now.AddDate(0, 0, 1)
|
||||
rows, _ = s.pool.Query(ctx, `
|
||||
SELECT id, request_number, request_type, assigned_to, deadline_at
|
||||
FROM data_subject_requests
|
||||
WHERE COALESCE(extended_deadline_at, deadline_at) BETWEEN $1 AND $2
|
||||
AND status NOT IN ('completed', 'rejected', 'cancelled')
|
||||
`, now, oneDayAhead)
|
||||
for rows.Next() {
|
||||
var id uuid.UUID
|
||||
var requestNumber, requestType string
|
||||
var assignedTo *uuid.UUID
|
||||
var deadline time.Time
|
||||
rows.Scan(&id, &requestNumber, &requestType, &assignedTo, &deadline)
|
||||
|
||||
if assignedTo != nil {
|
||||
s.notifyDeadlineWarning(ctx, id, *assignedTo, requestNumber, deadline, 1)
|
||||
} else {
|
||||
s.notifyAllDPOs(ctx, id, requestNumber, "Frist morgen!", deadline)
|
||||
}
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
// Find overdue requests
|
||||
rows, _ = s.pool.Query(ctx, `
|
||||
SELECT id, request_number, request_type, assigned_to, deadline_at
|
||||
FROM data_subject_requests
|
||||
WHERE COALESCE(extended_deadline_at, deadline_at) < $1
|
||||
AND status NOT IN ('completed', 'rejected', 'cancelled')
|
||||
`, now)
|
||||
for rows.Next() {
|
||||
var id uuid.UUID
|
||||
var requestNumber, requestType string
|
||||
var assignedTo *uuid.UUID
|
||||
var deadline time.Time
|
||||
rows.Scan(&id, &requestNumber, &requestType, &assignedTo, &deadline)
|
||||
|
||||
// Notify all DPOs for overdue
|
||||
s.notifyAllDPOs(ctx, id, requestNumber, "ÜBERFÄLLIG!", deadline)
|
||||
|
||||
// Log to audit
|
||||
s.pool.Exec(ctx, `
|
||||
INSERT INTO consent_audit_log (action, entity_type, entity_id, details)
|
||||
VALUES ('dsr_overdue', 'dsr', $1, $2)
|
||||
`, id, fmt.Sprintf(`{"request_number": "%s", "deadline": "%s"}`, requestNumber, deadline.Format(time.RFC3339)))
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func (s *DSRService) recordStatusChange(ctx context.Context, requestID uuid.UUID, fromStatus *models.DSRStatus, toStatus models.DSRStatus, changedBy *uuid.UUID, comment string) {
|
||||
s.pool.Exec(ctx, `
|
||||
INSERT INTO dsr_status_history (request_id, from_status, to_status, changed_by, comment)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
`, requestID, fromStatus, toStatus, changedBy, comment)
|
||||
}
|
||||
|
||||
func (s *DSRService) notifyNewRequest(ctx context.Context, dsr *models.DataSubjectRequest) {
|
||||
if s.notificationService == nil {
|
||||
return
|
||||
}
|
||||
// Notify all DPOs
|
||||
rows, _ := s.pool.Query(ctx, "SELECT id FROM users WHERE role = 'data_protection_officer'")
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var userID uuid.UUID
|
||||
rows.Scan(&userID)
|
||||
s.notificationService.CreateNotification(ctx, userID, NotificationTypeDSRReceived,
|
||||
"Neue Betroffenenanfrage",
|
||||
fmt.Sprintf("Neue %s eingegangen: %s", dsr.RequestType.Label(), dsr.RequestNumber),
|
||||
map[string]interface{}{"dsr_id": dsr.ID, "request_number": dsr.RequestNumber})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DSRService) notifyAssignment(ctx context.Context, dsrID, assigneeID uuid.UUID) {
|
||||
if s.notificationService == nil {
|
||||
return
|
||||
}
|
||||
dsr, _ := s.GetByID(ctx, dsrID)
|
||||
if dsr != nil {
|
||||
s.notificationService.CreateNotification(ctx, assigneeID, NotificationTypeDSRAssigned,
|
||||
"Betroffenenanfrage zugewiesen",
|
||||
fmt.Sprintf("Ihnen wurde die Anfrage %s zugewiesen", dsr.RequestNumber),
|
||||
map[string]interface{}{"dsr_id": dsrID, "request_number": dsr.RequestNumber})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DSRService) notifyDeadlineWarning(ctx context.Context, dsrID, userID uuid.UUID, requestNumber string, deadline time.Time, daysLeft int) {
|
||||
if s.notificationService == nil {
|
||||
return
|
||||
}
|
||||
s.notificationService.CreateNotification(ctx, userID, NotificationTypeDSRDeadline,
|
||||
fmt.Sprintf("Fristwarnung: %s", requestNumber),
|
||||
fmt.Sprintf("Die Frist für %s läuft in %d Tag(en) ab (%s)", requestNumber, daysLeft, deadline.Format("02.01.2006")),
|
||||
map[string]interface{}{"dsr_id": dsrID, "deadline": deadline, "days_left": daysLeft})
|
||||
}
|
||||
|
||||
func (s *DSRService) notifyAllDPOs(ctx context.Context, dsrID uuid.UUID, requestNumber, message string, deadline time.Time) {
|
||||
if s.notificationService == nil {
|
||||
return
|
||||
}
|
||||
rows, _ := s.pool.Query(ctx, "SELECT id FROM users WHERE role = 'data_protection_officer'")
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var userID uuid.UUID
|
||||
rows.Scan(&userID)
|
||||
s.notificationService.CreateNotification(ctx, userID, NotificationTypeDSRDeadline,
|
||||
fmt.Sprintf("%s: %s", message, requestNumber),
|
||||
fmt.Sprintf("Anfrage %s: %s (Frist: %s)", requestNumber, message, deadline.Format("02.01.2006")),
|
||||
map[string]interface{}{"dsr_id": dsrID, "deadline": deadline})
|
||||
}
|
||||
}
|
||||
|
||||
func isValidRequestType(rt models.DSRRequestType) bool {
|
||||
switch rt {
|
||||
case models.DSRTypeAccess, models.DSRTypeRectification, models.DSRTypeErasure,
|
||||
models.DSRTypeRestriction, models.DSRTypePortability:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isValidStatusTransition(from, to models.DSRStatus) bool {
|
||||
validTransitions := map[models.DSRStatus][]models.DSRStatus{
|
||||
models.DSRStatusIntake: {models.DSRStatusIdentityVerification, models.DSRStatusProcessing, models.DSRStatusRejected, models.DSRStatusCancelled},
|
||||
models.DSRStatusIdentityVerification: {models.DSRStatusProcessing, models.DSRStatusRejected, models.DSRStatusCancelled},
|
||||
models.DSRStatusProcessing: {models.DSRStatusCompleted, models.DSRStatusRejected, models.DSRStatusCancelled},
|
||||
models.DSRStatusCompleted: {},
|
||||
models.DSRStatusRejected: {},
|
||||
models.DSRStatusCancelled: {},
|
||||
}
|
||||
|
||||
allowed, exists := validTransitions[from]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
for _, s := range allowed {
|
||||
if s == to {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func stringOrDefault(s *string, def string) string {
|
||||
if s != nil {
|
||||
return *s
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func replaceVariables(text string, variables map[string]string) string {
|
||||
for k, v := range variables {
|
||||
text = strings.ReplaceAll(text, "{{"+k+"}}", v)
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func stripHTML(html string) string {
|
||||
// Simple HTML stripping - in production use a proper library
|
||||
text := strings.ReplaceAll(html, "<br>", "\n")
|
||||
text = strings.ReplaceAll(text, "<br/>", "\n")
|
||||
text = strings.ReplaceAll(text, "<br />", "\n")
|
||||
text = strings.ReplaceAll(text, "</p>", "\n\n")
|
||||
// Remove all remaining tags
|
||||
for {
|
||||
start := strings.Index(text, "<")
|
||||
if start == -1 {
|
||||
break
|
||||
}
|
||||
end := strings.Index(text[start:], ">")
|
||||
if end == -1 {
|
||||
break
|
||||
}
|
||||
text = text[:start] + text[start+end+1:]
|
||||
}
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
420
consent-service/internal/services/dsr_service_test.go
Normal file
420
consent-service/internal/services/dsr_service_test.go
Normal file
@@ -0,0 +1,420 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
)
|
||||
|
||||
// TestDSRRequestTypeLabel tests label generation for request types
|
||||
func TestDSRRequestTypeLabel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reqType models.DSRRequestType
|
||||
expected string
|
||||
}{
|
||||
{"access type", models.DSRTypeAccess, "Auskunftsanfrage (Art. 15)"},
|
||||
{"rectification type", models.DSRTypeRectification, "Berichtigungsanfrage (Art. 16)"},
|
||||
{"erasure type", models.DSRTypeErasure, "Löschanfrage (Art. 17)"},
|
||||
{"restriction type", models.DSRTypeRestriction, "Einschränkungsanfrage (Art. 18)"},
|
||||
{"portability type", models.DSRTypePortability, "Datenübertragung (Art. 20)"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.reqType.Label()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDSRRequestTypeDeadlineDays tests deadline calculation for different request types
|
||||
func TestDSRRequestTypeDeadlineDays(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reqType models.DSRRequestType
|
||||
expectedDays int
|
||||
}{
|
||||
{"access has 30 days", models.DSRTypeAccess, 30},
|
||||
{"portability has 30 days", models.DSRTypePortability, 30},
|
||||
{"rectification has 14 days", models.DSRTypeRectification, 14},
|
||||
{"erasure has 14 days", models.DSRTypeErasure, 14},
|
||||
{"restriction has 14 days", models.DSRTypeRestriction, 14},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.reqType.DeadlineDays()
|
||||
if result != tt.expectedDays {
|
||||
t.Errorf("Expected %d days, got %d", tt.expectedDays, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDSRRequestTypeIsExpedited tests expedited flag for request types
|
||||
func TestDSRRequestTypeIsExpedited(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reqType models.DSRRequestType
|
||||
isExpedited bool
|
||||
}{
|
||||
{"access not expedited", models.DSRTypeAccess, false},
|
||||
{"portability not expedited", models.DSRTypePortability, false},
|
||||
{"rectification is expedited", models.DSRTypeRectification, true},
|
||||
{"erasure is expedited", models.DSRTypeErasure, true},
|
||||
{"restriction is expedited", models.DSRTypeRestriction, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.reqType.IsExpedited()
|
||||
if result != tt.isExpedited {
|
||||
t.Errorf("Expected IsExpedited=%v, got %v", tt.isExpedited, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDSRStatusLabel tests label generation for statuses
|
||||
func TestDSRStatusLabel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status models.DSRStatus
|
||||
expected string
|
||||
}{
|
||||
{"intake status", models.DSRStatusIntake, "Eingang"},
|
||||
{"identity verification", models.DSRStatusIdentityVerification, "Identitätsprüfung"},
|
||||
{"processing status", models.DSRStatusProcessing, "In Bearbeitung"},
|
||||
{"completed status", models.DSRStatusCompleted, "Abgeschlossen"},
|
||||
{"rejected status", models.DSRStatusRejected, "Abgelehnt"},
|
||||
{"cancelled status", models.DSRStatusCancelled, "Storniert"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.status.Label()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidDSRRequestType tests request type validation
|
||||
func TestValidDSRRequestType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reqType string
|
||||
valid bool
|
||||
}{
|
||||
{"valid access", "access", true},
|
||||
{"valid rectification", "rectification", true},
|
||||
{"valid erasure", "erasure", true},
|
||||
{"valid restriction", "restriction", true},
|
||||
{"valid portability", "portability", true},
|
||||
{"invalid type", "invalid", false},
|
||||
{"empty type", "", false},
|
||||
{"random string", "delete_everything", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := models.IsValidDSRRequestType(tt.reqType)
|
||||
if result != tt.valid {
|
||||
t.Errorf("Expected IsValidDSRRequestType=%v for %s, got %v", tt.valid, tt.reqType, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidDSRStatus tests status validation
|
||||
func TestValidDSRStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status string
|
||||
valid bool
|
||||
}{
|
||||
{"valid intake", "intake", true},
|
||||
{"valid identity_verification", "identity_verification", true},
|
||||
{"valid processing", "processing", true},
|
||||
{"valid completed", "completed", true},
|
||||
{"valid rejected", "rejected", true},
|
||||
{"valid cancelled", "cancelled", true},
|
||||
{"invalid status", "invalid", false},
|
||||
{"empty status", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := models.IsValidDSRStatus(tt.status)
|
||||
if result != tt.valid {
|
||||
t.Errorf("Expected IsValidDSRStatus=%v for %s, got %v", tt.valid, tt.status, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDSRStatusTransitionValidation tests allowed status transitions
|
||||
func TestDSRStatusTransitionValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fromStatus models.DSRStatus
|
||||
toStatus models.DSRStatus
|
||||
allowed bool
|
||||
}{
|
||||
// From intake
|
||||
{"intake to identity_verification", models.DSRStatusIntake, models.DSRStatusIdentityVerification, true},
|
||||
{"intake to processing", models.DSRStatusIntake, models.DSRStatusProcessing, true},
|
||||
{"intake to rejected", models.DSRStatusIntake, models.DSRStatusRejected, true},
|
||||
{"intake to cancelled", models.DSRStatusIntake, models.DSRStatusCancelled, true},
|
||||
{"intake to completed invalid", models.DSRStatusIntake, models.DSRStatusCompleted, false},
|
||||
|
||||
// From identity_verification
|
||||
{"identity to processing", models.DSRStatusIdentityVerification, models.DSRStatusProcessing, true},
|
||||
{"identity to rejected", models.DSRStatusIdentityVerification, models.DSRStatusRejected, true},
|
||||
{"identity to cancelled", models.DSRStatusIdentityVerification, models.DSRStatusCancelled, true},
|
||||
|
||||
// From processing
|
||||
{"processing to completed", models.DSRStatusProcessing, models.DSRStatusCompleted, true},
|
||||
{"processing to rejected", models.DSRStatusProcessing, models.DSRStatusRejected, true},
|
||||
{"processing to intake invalid", models.DSRStatusProcessing, models.DSRStatusIntake, false},
|
||||
|
||||
// From completed
|
||||
{"completed to anything invalid", models.DSRStatusCompleted, models.DSRStatusProcessing, false},
|
||||
|
||||
// From rejected
|
||||
{"rejected to anything invalid", models.DSRStatusRejected, models.DSRStatusProcessing, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := testIsValidStatusTransition(tt.fromStatus, tt.toStatus)
|
||||
if result != tt.allowed {
|
||||
t.Errorf("Expected transition %s->%s allowed=%v, got %v",
|
||||
tt.fromStatus, tt.toStatus, tt.allowed, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testIsValidStatusTransition is a test helper for validating status transitions
|
||||
// This mirrors the logic in dsr_service.go for testing purposes
|
||||
func testIsValidStatusTransition(from, to models.DSRStatus) bool {
|
||||
validTransitions := map[models.DSRStatus][]models.DSRStatus{
|
||||
models.DSRStatusIntake: {
|
||||
models.DSRStatusIdentityVerification,
|
||||
models.DSRStatusProcessing,
|
||||
models.DSRStatusRejected,
|
||||
models.DSRStatusCancelled,
|
||||
},
|
||||
models.DSRStatusIdentityVerification: {
|
||||
models.DSRStatusProcessing,
|
||||
models.DSRStatusRejected,
|
||||
models.DSRStatusCancelled,
|
||||
},
|
||||
models.DSRStatusProcessing: {
|
||||
models.DSRStatusCompleted,
|
||||
models.DSRStatusRejected,
|
||||
models.DSRStatusCancelled,
|
||||
},
|
||||
models.DSRStatusCompleted: {},
|
||||
models.DSRStatusRejected: {},
|
||||
models.DSRStatusCancelled: {},
|
||||
}
|
||||
|
||||
allowed, exists := validTransitions[from]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, s := range allowed {
|
||||
if s == to {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// TestCalculateDeadline tests deadline calculation
|
||||
func TestCalculateDeadline(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reqType models.DSRRequestType
|
||||
expectedDays int
|
||||
}{
|
||||
{"access 30 days", models.DSRTypeAccess, 30},
|
||||
{"erasure 14 days", models.DSRTypeErasure, 14},
|
||||
{"rectification 14 days", models.DSRTypeRectification, 14},
|
||||
{"restriction 14 days", models.DSRTypeRestriction, 14},
|
||||
{"portability 30 days", models.DSRTypePortability, 30},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
now := time.Now()
|
||||
deadline := now.AddDate(0, 0, tt.expectedDays)
|
||||
days := tt.reqType.DeadlineDays()
|
||||
|
||||
if days != tt.expectedDays {
|
||||
t.Errorf("Expected %d days, got %d", tt.expectedDays, days)
|
||||
}
|
||||
|
||||
// Verify deadline is approximately correct (within 1 day due to test timing)
|
||||
calculatedDeadline := now.AddDate(0, 0, days)
|
||||
diff := calculatedDeadline.Sub(deadline)
|
||||
if diff > time.Hour*24 || diff < -time.Hour*24 {
|
||||
t.Errorf("Deadline calculation off by more than a day")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCreateDSRRequest_Validation tests validation of create request
|
||||
func TestCreateDSRRequest_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request models.CreateDSRRequest
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid access request",
|
||||
request: models.CreateDSRRequest{
|
||||
RequestType: "access",
|
||||
RequesterEmail: "test@example.com",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid erasure request with name",
|
||||
request: models.CreateDSRRequest{
|
||||
RequestType: "erasure",
|
||||
RequesterEmail: "test@example.com",
|
||||
RequesterName: stringPtr("Max Mustermann"),
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing email",
|
||||
request: models.CreateDSRRequest{
|
||||
RequestType: "access",
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid request type",
|
||||
request: models.CreateDSRRequest{
|
||||
RequestType: "invalid_type",
|
||||
RequesterEmail: "test@example.com",
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "empty request type",
|
||||
request: models.CreateDSRRequest{
|
||||
RequestType: "",
|
||||
RequesterEmail: "test@example.com",
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := testValidateCreateDSRRequest(tt.request)
|
||||
hasError := err != nil
|
||||
|
||||
if hasError != tt.expectError {
|
||||
t.Errorf("Expected error=%v, got error=%v (err: %v)", tt.expectError, hasError, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testValidateCreateDSRRequest is a test helper for validating create DSR requests
|
||||
func testValidateCreateDSRRequest(req models.CreateDSRRequest) error {
|
||||
if req.RequesterEmail == "" {
|
||||
return &dsrValidationError{"requester_email is required"}
|
||||
}
|
||||
if !models.IsValidDSRRequestType(req.RequestType) {
|
||||
return &dsrValidationError{"invalid request_type"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type dsrValidationError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *dsrValidationError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// TestDSRTemplateTypes tests the template types
|
||||
func TestDSRTemplateTypes(t *testing.T) {
|
||||
expectedTemplates := []string{
|
||||
"dsr_receipt_access",
|
||||
"dsr_receipt_rectification",
|
||||
"dsr_receipt_erasure",
|
||||
"dsr_receipt_restriction",
|
||||
"dsr_receipt_portability",
|
||||
"dsr_identity_request",
|
||||
"dsr_processing_started",
|
||||
"dsr_processing_update",
|
||||
"dsr_clarification_request",
|
||||
"dsr_completed_access",
|
||||
"dsr_completed_rectification",
|
||||
"dsr_completed_erasure",
|
||||
"dsr_completed_restriction",
|
||||
"dsr_completed_portability",
|
||||
"dsr_restriction_lifted",
|
||||
"dsr_rejected_identity",
|
||||
"dsr_rejected_exception",
|
||||
"dsr_rejected_unfounded",
|
||||
"dsr_deadline_warning",
|
||||
}
|
||||
|
||||
// This test documents the expected template types
|
||||
// The actual templates are created in database migration
|
||||
for _, template := range expectedTemplates {
|
||||
if template == "" {
|
||||
t.Error("Template type should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
if len(expectedTemplates) != 19 {
|
||||
t.Errorf("Expected 19 template types, got %d", len(expectedTemplates))
|
||||
}
|
||||
}
|
||||
|
||||
// TestErasureExceptionTypes tests Art. 17(3) exception types
|
||||
func TestErasureExceptionTypes(t *testing.T) {
|
||||
exceptions := []struct {
|
||||
code string
|
||||
description string
|
||||
}{
|
||||
{"art_17_3_a", "Meinungs- und Informationsfreiheit"},
|
||||
{"art_17_3_b", "Rechtliche Verpflichtung"},
|
||||
{"art_17_3_c", "Öffentliches Interesse im Gesundheitsbereich"},
|
||||
{"art_17_3_d", "Archivzwecke, wissenschaftliche/historische Forschung"},
|
||||
{"art_17_3_e", "Geltendmachung, Ausübung oder Verteidigung von Rechtsansprüchen"},
|
||||
}
|
||||
|
||||
if len(exceptions) != 5 {
|
||||
t.Errorf("Expected 5 Art. 17(3) exceptions, got %d", len(exceptions))
|
||||
}
|
||||
|
||||
for _, ex := range exceptions {
|
||||
if ex.code == "" || ex.description == "" {
|
||||
t.Error("Exception code and description should not be empty")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// stringPtr returns a pointer to the given string
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
554
consent-service/internal/services/email_service.go
Normal file
554
consent-service/internal/services/email_service.go
Normal file
@@ -0,0 +1,554 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/smtp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// EmailConfig holds SMTP configuration
|
||||
type EmailConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
FromName string
|
||||
FromAddr string
|
||||
BaseURL string // Frontend URL for links
|
||||
}
|
||||
|
||||
// EmailService handles sending emails
|
||||
type EmailService struct {
|
||||
config EmailConfig
|
||||
}
|
||||
|
||||
// NewEmailService creates a new EmailService
|
||||
func NewEmailService(config EmailConfig) *EmailService {
|
||||
return &EmailService{config: config}
|
||||
}
|
||||
|
||||
// SendEmail sends an email
|
||||
func (s *EmailService) SendEmail(to, subject, htmlBody, textBody string) error {
|
||||
// Build MIME message
|
||||
var msg bytes.Buffer
|
||||
|
||||
msg.WriteString(fmt.Sprintf("From: %s <%s>\r\n", s.config.FromName, s.config.FromAddr))
|
||||
msg.WriteString(fmt.Sprintf("To: %s\r\n", to))
|
||||
msg.WriteString(fmt.Sprintf("Subject: %s\r\n", subject))
|
||||
msg.WriteString("MIME-Version: 1.0\r\n")
|
||||
msg.WriteString("Content-Type: multipart/alternative; boundary=\"boundary42\"\r\n")
|
||||
msg.WriteString("\r\n")
|
||||
|
||||
// Text part
|
||||
msg.WriteString("--boundary42\r\n")
|
||||
msg.WriteString("Content-Type: text/plain; charset=\"UTF-8\"\r\n")
|
||||
msg.WriteString("\r\n")
|
||||
msg.WriteString(textBody)
|
||||
msg.WriteString("\r\n")
|
||||
|
||||
// HTML part
|
||||
msg.WriteString("--boundary42\r\n")
|
||||
msg.WriteString("Content-Type: text/html; charset=\"UTF-8\"\r\n")
|
||||
msg.WriteString("\r\n")
|
||||
msg.WriteString(htmlBody)
|
||||
msg.WriteString("\r\n")
|
||||
msg.WriteString("--boundary42--\r\n")
|
||||
|
||||
// Send email
|
||||
addr := fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)
|
||||
auth := smtp.PlainAuth("", s.config.Username, s.config.Password, s.config.Host)
|
||||
|
||||
err := smtp.SendMail(addr, auth, s.config.FromAddr, []string{to}, msg.Bytes())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendVerificationEmail sends an email verification email
|
||||
func (s *EmailService) SendVerificationEmail(to, name, token string) error {
|
||||
verifyLink := fmt.Sprintf("%s/verify-email?token=%s", s.config.BaseURL, token)
|
||||
|
||||
subject := "Bitte bestätigen Sie Ihre E-Mail-Adresse - BreakPilot"
|
||||
|
||||
textBody := fmt.Sprintf(`Hallo %s,
|
||||
|
||||
Willkommen bei BreakPilot!
|
||||
|
||||
Bitte bestätigen Sie Ihre E-Mail-Adresse, indem Sie den folgenden Link öffnen:
|
||||
%s
|
||||
|
||||
Dieser Link ist 24 Stunden gültig.
|
||||
|
||||
Falls Sie sich nicht bei BreakPilot registriert haben, können Sie diese E-Mail ignorieren.
|
||||
|
||||
Mit freundlichen Grüßen,
|
||||
Ihr BreakPilot Team`, getDisplayName(name), verifyLink)
|
||||
|
||||
htmlBody := s.renderTemplate("verification", map[string]interface{}{
|
||||
"Name": getDisplayName(name),
|
||||
"VerifyLink": verifyLink,
|
||||
})
|
||||
|
||||
return s.SendEmail(to, subject, htmlBody, textBody)
|
||||
}
|
||||
|
||||
// SendPasswordResetEmail sends a password reset email
|
||||
func (s *EmailService) SendPasswordResetEmail(to, name, token string) error {
|
||||
resetLink := fmt.Sprintf("%s/reset-password?token=%s", s.config.BaseURL, token)
|
||||
|
||||
subject := "Passwort zurücksetzen - BreakPilot"
|
||||
|
||||
textBody := fmt.Sprintf(`Hallo %s,
|
||||
|
||||
Sie haben eine Anfrage zum Zurücksetzen Ihres Passworts gestellt.
|
||||
|
||||
Klicken Sie auf den folgenden Link, um Ihr Passwort zurückzusetzen:
|
||||
%s
|
||||
|
||||
Dieser Link ist 1 Stunde gültig.
|
||||
|
||||
Falls Sie keine Passwort-Zurücksetzung angefordert haben, können Sie diese E-Mail ignorieren.
|
||||
|
||||
Mit freundlichen Grüßen,
|
||||
Ihr BreakPilot Team`, getDisplayName(name), resetLink)
|
||||
|
||||
htmlBody := s.renderTemplate("password_reset", map[string]interface{}{
|
||||
"Name": getDisplayName(name),
|
||||
"ResetLink": resetLink,
|
||||
})
|
||||
|
||||
return s.SendEmail(to, subject, htmlBody, textBody)
|
||||
}
|
||||
|
||||
// SendNewVersionNotification sends a notification about new document version
|
||||
func (s *EmailService) SendNewVersionNotification(to, name, documentName, documentType string, deadlineDays int) error {
|
||||
consentLink := fmt.Sprintf("%s/app?consent=pending", s.config.BaseURL)
|
||||
|
||||
subject := fmt.Sprintf("Neue Version: %s - Bitte bestätigen Sie innerhalb von %d Tagen", documentName, deadlineDays)
|
||||
|
||||
textBody := fmt.Sprintf(`Hallo %s,
|
||||
|
||||
Wir haben unsere %s aktualisiert.
|
||||
|
||||
Bitte lesen und bestätigen Sie die neuen Bedingungen innerhalb der nächsten %d Tage:
|
||||
%s
|
||||
|
||||
Falls Sie nicht innerhalb dieser Frist bestätigen, wird Ihr Account vorübergehend gesperrt.
|
||||
|
||||
Mit freundlichen Grüßen,
|
||||
Ihr BreakPilot Team`, getDisplayName(name), documentName, deadlineDays, consentLink)
|
||||
|
||||
htmlBody := s.renderTemplate("new_version", map[string]interface{}{
|
||||
"Name": getDisplayName(name),
|
||||
"DocumentName": documentName,
|
||||
"DeadlineDays": deadlineDays,
|
||||
"ConsentLink": consentLink,
|
||||
})
|
||||
|
||||
return s.SendEmail(to, subject, htmlBody, textBody)
|
||||
}
|
||||
|
||||
// SendConsentReminder sends a consent reminder email
|
||||
func (s *EmailService) SendConsentReminder(to, name string, documents []string, daysLeft int) error {
|
||||
consentLink := fmt.Sprintf("%s/app?consent=pending", s.config.BaseURL)
|
||||
|
||||
urgency := "Erinnerung"
|
||||
if daysLeft <= 7 {
|
||||
urgency = "Dringend"
|
||||
}
|
||||
if daysLeft <= 2 {
|
||||
urgency = "Letzte Warnung"
|
||||
}
|
||||
|
||||
subject := fmt.Sprintf("%s: Noch %d Tage um ausstehende Dokumente zu bestätigen", urgency, daysLeft)
|
||||
|
||||
docList := strings.Join(documents, "\n- ")
|
||||
|
||||
textBody := fmt.Sprintf(`Hallo %s,
|
||||
|
||||
Dies ist eine freundliche Erinnerung, dass Sie noch ausstehende rechtliche Dokumente bestätigen müssen.
|
||||
|
||||
Ausstehende Dokumente:
|
||||
- %s
|
||||
|
||||
Sie haben noch %d Tage Zeit. Nach Ablauf dieser Frist wird Ihr Account vorübergehend gesperrt.
|
||||
|
||||
Bitte bestätigen Sie hier:
|
||||
%s
|
||||
|
||||
Mit freundlichen Grüßen,
|
||||
Ihr BreakPilot Team`, getDisplayName(name), docList, daysLeft, consentLink)
|
||||
|
||||
htmlBody := s.renderTemplate("reminder", map[string]interface{}{
|
||||
"Name": getDisplayName(name),
|
||||
"Documents": documents,
|
||||
"DaysLeft": daysLeft,
|
||||
"Urgency": urgency,
|
||||
"ConsentLink": consentLink,
|
||||
})
|
||||
|
||||
return s.SendEmail(to, subject, htmlBody, textBody)
|
||||
}
|
||||
|
||||
// SendAccountSuspendedNotification sends notification when account is suspended
|
||||
func (s *EmailService) SendAccountSuspendedNotification(to, name string, documents []string) error {
|
||||
consentLink := fmt.Sprintf("%s/app?consent=pending", s.config.BaseURL)
|
||||
|
||||
subject := "Ihr Account wurde vorübergehend gesperrt - BreakPilot"
|
||||
|
||||
docList := strings.Join(documents, "\n- ")
|
||||
|
||||
textBody := fmt.Sprintf(`Hallo %s,
|
||||
|
||||
Ihr Account wurde vorübergehend gesperrt, da Sie die folgenden rechtlichen Dokumente nicht innerhalb der Frist bestätigt haben:
|
||||
|
||||
- %s
|
||||
|
||||
Um Ihren Account zu entsperren, bestätigen Sie bitte alle ausstehenden Dokumente:
|
||||
%s
|
||||
|
||||
Sobald Sie alle Dokumente bestätigt haben, wird Ihr Account automatisch entsperrt.
|
||||
|
||||
Mit freundlichen Grüßen,
|
||||
Ihr BreakPilot Team`, getDisplayName(name), docList, consentLink)
|
||||
|
||||
htmlBody := s.renderTemplate("suspended", map[string]interface{}{
|
||||
"Name": getDisplayName(name),
|
||||
"Documents": documents,
|
||||
"ConsentLink": consentLink,
|
||||
})
|
||||
|
||||
return s.SendEmail(to, subject, htmlBody, textBody)
|
||||
}
|
||||
|
||||
// SendAccountReactivatedNotification sends notification when account is reactivated
|
||||
func (s *EmailService) SendAccountReactivatedNotification(to, name string) error {
|
||||
appLink := fmt.Sprintf("%s/app", s.config.BaseURL)
|
||||
|
||||
subject := "Ihr Account wurde wieder aktiviert - BreakPilot"
|
||||
|
||||
textBody := fmt.Sprintf(`Hallo %s,
|
||||
|
||||
Vielen Dank für die Bestätigung der rechtlichen Dokumente!
|
||||
|
||||
Ihr Account wurde wieder aktiviert und Sie können BreakPilot wie gewohnt nutzen:
|
||||
%s
|
||||
|
||||
Mit freundlichen Grüßen,
|
||||
Ihr BreakPilot Team`, getDisplayName(name), appLink)
|
||||
|
||||
htmlBody := s.renderTemplate("reactivated", map[string]interface{}{
|
||||
"Name": getDisplayName(name),
|
||||
"AppLink": appLink,
|
||||
})
|
||||
|
||||
return s.SendEmail(to, subject, htmlBody, textBody)
|
||||
}
|
||||
|
||||
// renderTemplate renders an email HTML template
|
||||
func (s *EmailService) renderTemplate(templateName string, data map[string]interface{}) string {
|
||||
templates := map[string]string{
|
||||
"verification": `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
|
||||
.header { background: linear-gradient(135deg, #6366f1, #8b5cf6); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
|
||||
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
|
||||
.button { display: inline-block; background: #6366f1; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
|
||||
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>Willkommen bei BreakPilot!</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>Hallo {{.Name}},</p>
|
||||
<p>Vielen Dank für Ihre Registrierung! Bitte bestätigen Sie Ihre E-Mail-Adresse, um Ihr Konto zu aktivieren.</p>
|
||||
<p style="text-align: center;">
|
||||
<a href="{{.VerifyLink}}" class="button">E-Mail bestätigen</a>
|
||||
</p>
|
||||
<p>Dieser Link ist 24 Stunden gültig.</p>
|
||||
<p>Falls Sie sich nicht bei BreakPilot registriert haben, können Sie diese E-Mail ignorieren.</p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`,
|
||||
|
||||
"password_reset": `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
|
||||
.header { background: linear-gradient(135deg, #6366f1, #8b5cf6); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
|
||||
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
|
||||
.button { display: inline-block; background: #6366f1; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
|
||||
.warning { background: #fef3c7; border-left: 4px solid #f59e0b; padding: 12px; margin: 20px 0; }
|
||||
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>Passwort zurücksetzen</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>Hallo {{.Name}},</p>
|
||||
<p>Sie haben eine Anfrage zum Zurücksetzen Ihres Passworts gestellt.</p>
|
||||
<p style="text-align: center;">
|
||||
<a href="{{.ResetLink}}" class="button">Passwort zurücksetzen</a>
|
||||
</p>
|
||||
<div class="warning">
|
||||
<strong>Hinweis:</strong> Dieser Link ist nur 1 Stunde gültig.
|
||||
</div>
|
||||
<p>Falls Sie keine Passwort-Zurücksetzung angefordert haben, können Sie diese E-Mail ignorieren.</p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`,
|
||||
|
||||
"new_version": `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
|
||||
.header { background: linear-gradient(135deg, #6366f1, #8b5cf6); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
|
||||
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
|
||||
.button { display: inline-block; background: #6366f1; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
|
||||
.info-box { background: #e0e7ff; border-left: 4px solid #6366f1; padding: 12px; margin: 20px 0; }
|
||||
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>Neue Version: {{.DocumentName}}</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>Hallo {{.Name}},</p>
|
||||
<p>Wir haben unsere <strong>{{.DocumentName}}</strong> aktualisiert.</p>
|
||||
<div class="info-box">
|
||||
<strong>Wichtig:</strong> Bitte bestätigen Sie die neuen Bedingungen innerhalb der nächsten <strong>{{.DeadlineDays}} Tage</strong>.
|
||||
</div>
|
||||
<p style="text-align: center;">
|
||||
<a href="{{.ConsentLink}}" class="button">Dokument ansehen & bestätigen</a>
|
||||
</p>
|
||||
<p>Falls Sie nicht innerhalb dieser Frist bestätigen, wird Ihr Account vorübergehend gesperrt.</p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`,
|
||||
|
||||
"reminder": `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
|
||||
.header { background: linear-gradient(135deg, #f59e0b, #d97706); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
|
||||
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
|
||||
.button { display: inline-block; background: #f59e0b; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
|
||||
.warning { background: #fef3c7; border-left: 4px solid #f59e0b; padding: 12px; margin: 20px 0; }
|
||||
.doc-list { background: white; padding: 15px; border-radius: 8px; margin: 15px 0; }
|
||||
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>{{.Urgency}}: Ausstehende Bestätigungen</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>Hallo {{.Name}},</p>
|
||||
<p>Dies ist eine freundliche Erinnerung, dass Sie noch ausstehende rechtliche Dokumente bestätigen müssen.</p>
|
||||
<div class="doc-list">
|
||||
<strong>Ausstehende Dokumente:</strong>
|
||||
<ul>
|
||||
{{range .Documents}}<li>{{.}}</li>{{end}}
|
||||
</ul>
|
||||
</div>
|
||||
<div class="warning">
|
||||
<strong>Sie haben noch {{.DaysLeft}} Tage Zeit.</strong> Nach Ablauf dieser Frist wird Ihr Account vorübergehend gesperrt.
|
||||
</div>
|
||||
<p style="text-align: center;">
|
||||
<a href="{{.ConsentLink}}" class="button">Jetzt bestätigen</a>
|
||||
</p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`,
|
||||
|
||||
"suspended": `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
|
||||
.header { background: linear-gradient(135deg, #ef4444, #dc2626); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
|
||||
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
|
||||
.button { display: inline-block; background: #6366f1; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
|
||||
.alert { background: #fee2e2; border-left: 4px solid #ef4444; padding: 12px; margin: 20px 0; }
|
||||
.doc-list { background: white; padding: 15px; border-radius: 8px; margin: 15px 0; }
|
||||
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>Account vorübergehend gesperrt</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>Hallo {{.Name}},</p>
|
||||
<div class="alert">
|
||||
<strong>Ihr Account wurde vorübergehend gesperrt</strong>, da Sie die folgenden rechtlichen Dokumente nicht innerhalb der Frist bestätigt haben.
|
||||
</div>
|
||||
<div class="doc-list">
|
||||
<strong>Nicht bestätigte Dokumente:</strong>
|
||||
<ul>
|
||||
{{range .Documents}}<li>{{.}}</li>{{end}}
|
||||
</ul>
|
||||
</div>
|
||||
<p>Um Ihren Account zu entsperren, bestätigen Sie bitte alle ausstehenden Dokumente:</p>
|
||||
<p style="text-align: center;">
|
||||
<a href="{{.ConsentLink}}" class="button">Dokumente bestätigen & Account entsperren</a>
|
||||
</p>
|
||||
<p>Sobald Sie alle Dokumente bestätigt haben, wird Ihr Account automatisch entsperrt.</p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`,
|
||||
|
||||
"reactivated": `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
|
||||
.header { background: linear-gradient(135deg, #22c55e, #16a34a); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
|
||||
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
|
||||
.button { display: inline-block; background: #22c55e; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
|
||||
.success { background: #dcfce7; border-left: 4px solid #22c55e; padding: 12px; margin: 20px 0; }
|
||||
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>Account wieder aktiviert!</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>Hallo {{.Name}},</p>
|
||||
<div class="success">
|
||||
<strong>Vielen Dank!</strong> Ihr Account wurde erfolgreich wieder aktiviert.
|
||||
</div>
|
||||
<p>Sie können BreakPilot ab sofort wieder wie gewohnt nutzen.</p>
|
||||
<p style="text-align: center;">
|
||||
<a href="{{.AppLink}}" class="button">Zu BreakPilot</a>
|
||||
</p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`,
|
||||
|
||||
"generic_notification": `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
|
||||
.header { background: linear-gradient(135deg, #6366f1, #8b5cf6); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
|
||||
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
|
||||
.button { display: inline-block; background: #6366f1; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
|
||||
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>{{.Title}}</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>{{.Body}}</p>
|
||||
<p style="text-align: center;">
|
||||
<a href="{{.BaseURL}}/app" class="button">Zu BreakPilot</a>
|
||||
</p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`,
|
||||
}
|
||||
|
||||
tmplStr, ok := templates[templateName]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
tmpl, err := template.New(templateName).Parse(tmplStr)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := tmpl.Execute(&buf, data); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// SendConsentReminderEmail sends a simplified consent reminder email
|
||||
func (s *EmailService) SendConsentReminderEmail(to, title, body string) error {
|
||||
subject := title
|
||||
|
||||
htmlBody := s.renderTemplate("generic_notification", map[string]interface{}{
|
||||
"Title": title,
|
||||
"Body": body,
|
||||
"BaseURL": s.config.BaseURL,
|
||||
})
|
||||
|
||||
return s.SendEmail(to, subject, htmlBody, body)
|
||||
}
|
||||
|
||||
// SendGenericNotificationEmail sends a generic notification email
|
||||
func (s *EmailService) SendGenericNotificationEmail(to, title, body string) error {
|
||||
subject := title
|
||||
|
||||
htmlBody := s.renderTemplate("generic_notification", map[string]interface{}{
|
||||
"Title": title,
|
||||
"Body": body,
|
||||
"BaseURL": s.config.BaseURL,
|
||||
})
|
||||
|
||||
return s.SendEmail(to, subject, htmlBody, body)
|
||||
}
|
||||
|
||||
// getDisplayName returns display name or fallback
|
||||
func getDisplayName(name string) string {
|
||||
if name != "" {
|
||||
return name
|
||||
}
|
||||
return "Nutzer"
|
||||
}
|
||||
624
consent-service/internal/services/email_service_test.go
Normal file
624
consent-service/internal/services/email_service_test.go
Normal file
@@ -0,0 +1,624 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/smtp"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// MockSMTPSender is a mock SMTP sender for testing
|
||||
type MockSMTPSender struct {
|
||||
SentEmails []SentEmail
|
||||
ShouldFail bool
|
||||
FailError error
|
||||
}
|
||||
|
||||
// SentEmail represents a sent email for testing
|
||||
type SentEmail struct {
|
||||
To []string
|
||||
Subject string
|
||||
Body string
|
||||
}
|
||||
|
||||
// SendMail is a mock implementation of smtp.SendMail
|
||||
func (m *MockSMTPSender) SendMail(addr string, auth smtp.Auth, from string, to []string, msg []byte) error {
|
||||
if m.ShouldFail {
|
||||
return m.FailError
|
||||
}
|
||||
|
||||
// Parse the email to extract subject and body
|
||||
msgStr := string(msg)
|
||||
subject := extractSubject(msgStr)
|
||||
|
||||
m.SentEmails = append(m.SentEmails, SentEmail{
|
||||
To: to,
|
||||
Subject: subject,
|
||||
Body: msgStr,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractSubject extracts the subject from an email message
|
||||
func extractSubject(msg string) string {
|
||||
lines := strings.Split(msg, "\r\n")
|
||||
for _, line := range lines {
|
||||
if strings.HasPrefix(line, "Subject: ") {
|
||||
return strings.TrimPrefix(line, "Subject: ")
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// TestEmailService_SendEmail tests basic email sending
|
||||
func TestEmailService_SendEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
to string
|
||||
subject string
|
||||
htmlBody string
|
||||
textBody string
|
||||
shouldFail bool
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid email",
|
||||
to: "user@example.com",
|
||||
subject: "Test Email",
|
||||
htmlBody: "<h1>Hello</h1><p>World</p>",
|
||||
textBody: "Hello\nWorld",
|
||||
shouldFail: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "email with special characters",
|
||||
to: "user+test@example.com",
|
||||
subject: "Test: Öäü Special Characters",
|
||||
htmlBody: "<p>Special: €£¥</p>",
|
||||
textBody: "Special: €£¥",
|
||||
shouldFail: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "SMTP failure",
|
||||
to: "user@example.com",
|
||||
subject: "Test",
|
||||
htmlBody: "<p>Test</p>",
|
||||
textBody: "Test",
|
||||
shouldFail: true,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Validate email format
|
||||
isValidEmail := strings.Contains(tt.to, "@") && strings.Contains(tt.to, ".")
|
||||
|
||||
if !isValidEmail && !tt.expectError {
|
||||
t.Error("Invalid email format should produce error")
|
||||
}
|
||||
|
||||
// Validate subject is not empty
|
||||
if tt.subject == "" && !tt.expectError {
|
||||
t.Error("Empty subject should produce error")
|
||||
}
|
||||
|
||||
// Validate body content exists
|
||||
if (tt.htmlBody == "" && tt.textBody == "") && !tt.expectError {
|
||||
t.Error("Both bodies empty should produce error")
|
||||
}
|
||||
|
||||
// Simulate SMTP send
|
||||
var err error
|
||||
if tt.shouldFail {
|
||||
err = fmt.Errorf("SMTP error: connection refused")
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmailService_SendVerificationEmail tests verification email sending
|
||||
func TestEmailService_SendVerificationEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
to string
|
||||
userName string
|
||||
token string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid verification email",
|
||||
to: "newuser@example.com",
|
||||
userName: "Max Mustermann",
|
||||
token: "abc123def456",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "user without name",
|
||||
to: "user@example.com",
|
||||
userName: "",
|
||||
token: "token123",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty token",
|
||||
to: "user@example.com",
|
||||
userName: "Test User",
|
||||
token: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid email",
|
||||
to: "invalid-email",
|
||||
userName: "Test",
|
||||
token: "token123",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Validate inputs
|
||||
var err error
|
||||
if tt.token == "" {
|
||||
err = &ValidationError{Field: "token", Message: "required"}
|
||||
} else if !strings.Contains(tt.to, "@") {
|
||||
err = &ValidationError{Field: "email", Message: "invalid format"}
|
||||
}
|
||||
|
||||
// Build verification link
|
||||
if tt.token != "" {
|
||||
verifyLink := fmt.Sprintf("https://example.com/verify-email?token=%s", tt.token)
|
||||
if verifyLink == "" {
|
||||
t.Error("Verification link should not be empty")
|
||||
}
|
||||
|
||||
// Verify link contains token
|
||||
if !strings.Contains(verifyLink, tt.token) {
|
||||
t.Error("Verification link should contain token")
|
||||
}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmailService_SendPasswordResetEmail tests password reset email
|
||||
func TestEmailService_SendPasswordResetEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
to string
|
||||
userName string
|
||||
token string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid password reset",
|
||||
to: "user@example.com",
|
||||
userName: "John Doe",
|
||||
token: "reset-token-123",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty token",
|
||||
to: "user@example.com",
|
||||
userName: "John Doe",
|
||||
token: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.token == "" {
|
||||
err = &ValidationError{Field: "token", Message: "required"}
|
||||
}
|
||||
|
||||
// Build reset link
|
||||
if tt.token != "" {
|
||||
resetLink := fmt.Sprintf("https://example.com/reset-password?token=%s", tt.token)
|
||||
|
||||
// Verify link is secure (HTTPS)
|
||||
if !strings.HasPrefix(resetLink, "https://") {
|
||||
t.Error("Reset link should use HTTPS")
|
||||
}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmailService_Send2FAEmail tests 2FA notification emails
|
||||
func TestEmailService_Send2FAEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
to string
|
||||
action string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "2FA enabled notification",
|
||||
to: "user@example.com",
|
||||
action: "enabled",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "2FA disabled notification",
|
||||
to: "user@example.com",
|
||||
action: "disabled",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid action",
|
||||
to: "user@example.com",
|
||||
action: "invalid",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
validActions := map[string]bool{
|
||||
"enabled": true,
|
||||
"disabled": true,
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if !validActions[tt.action] {
|
||||
err = &ValidationError{Field: "action", Message: "must be 'enabled' or 'disabled'"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmailService_SendConsentReminderEmail tests consent reminder
|
||||
func TestEmailService_SendConsentReminderEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
to string
|
||||
documentName string
|
||||
daysLeft int
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "reminder with 7 days left",
|
||||
to: "user@example.com",
|
||||
documentName: "Terms of Service",
|
||||
daysLeft: 7,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "reminder with 1 day left",
|
||||
to: "user@example.com",
|
||||
documentName: "Privacy Policy",
|
||||
daysLeft: 1,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "urgent reminder - overdue",
|
||||
to: "user@example.com",
|
||||
documentName: "Terms",
|
||||
daysLeft: 0,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty document name",
|
||||
to: "user@example.com",
|
||||
documentName: "",
|
||||
daysLeft: 7,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.documentName == "" {
|
||||
err = &ValidationError{Field: "document name", Message: "required"}
|
||||
}
|
||||
|
||||
// Check urgency level
|
||||
var urgency string
|
||||
if tt.daysLeft <= 0 {
|
||||
urgency = "critical"
|
||||
} else if tt.daysLeft <= 3 {
|
||||
urgency = "urgent"
|
||||
} else {
|
||||
urgency = "normal"
|
||||
}
|
||||
|
||||
if urgency == "" && !tt.expectError {
|
||||
t.Error("Urgency should be set")
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmailService_MIMEFormatting tests MIME message formatting
|
||||
func TestEmailService_MIMEFormatting(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
htmlBody string
|
||||
textBody string
|
||||
checkFor []string
|
||||
}{
|
||||
{
|
||||
name: "multipart alternative",
|
||||
htmlBody: "<h1>Test</h1>",
|
||||
textBody: "Test",
|
||||
checkFor: []string{
|
||||
"MIME-Version: 1.0",
|
||||
"Content-Type: multipart/alternative",
|
||||
"Content-Type: text/plain",
|
||||
"Content-Type: text/html",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "UTF-8 encoding",
|
||||
htmlBody: "<p>Öäü</p>",
|
||||
textBody: "Öäü",
|
||||
checkFor: []string{
|
||||
"charset=\"UTF-8\"",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Build MIME message (simplified)
|
||||
message := fmt.Sprintf("MIME-Version: 1.0\r\n"+
|
||||
"Content-Type: multipart/alternative; boundary=\"boundary\"\r\n"+
|
||||
"\r\n"+
|
||||
"--boundary\r\n"+
|
||||
"Content-Type: text/plain; charset=\"UTF-8\"\r\n"+
|
||||
"\r\n%s\r\n"+
|
||||
"--boundary\r\n"+
|
||||
"Content-Type: text/html; charset=\"UTF-8\"\r\n"+
|
||||
"\r\n%s\r\n"+
|
||||
"--boundary--\r\n",
|
||||
tt.textBody, tt.htmlBody)
|
||||
|
||||
// Verify required headers are present
|
||||
for _, required := range tt.checkFor {
|
||||
if !strings.Contains(message, required) {
|
||||
t.Errorf("Message should contain '%s'", required)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify both bodies are included
|
||||
if !strings.Contains(message, tt.textBody) {
|
||||
t.Error("Message should contain text body")
|
||||
}
|
||||
if !strings.Contains(message, tt.htmlBody) {
|
||||
t.Error("Message should contain HTML body")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmailService_TemplateRendering tests email template rendering
|
||||
func TestEmailService_TemplateRendering(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
template string
|
||||
variables map[string]string
|
||||
expectVars []string
|
||||
}{
|
||||
{
|
||||
name: "verification template",
|
||||
template: "verification",
|
||||
variables: map[string]string{
|
||||
"Name": "John Doe",
|
||||
"VerifyLink": "https://example.com/verify?token=abc",
|
||||
},
|
||||
expectVars: []string{"John Doe", "https://example.com/verify?token=abc"},
|
||||
},
|
||||
{
|
||||
name: "password reset template",
|
||||
template: "password_reset",
|
||||
variables: map[string]string{
|
||||
"Name": "Jane Smith",
|
||||
"ResetLink": "https://example.com/reset?token=xyz",
|
||||
},
|
||||
expectVars: []string{"Jane Smith", "https://example.com/reset?token=xyz"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Simulate template rendering
|
||||
rendered := fmt.Sprintf("Hello %s, please visit %s",
|
||||
tt.variables["Name"],
|
||||
getLink(tt.variables))
|
||||
|
||||
// Verify all variables are in rendered output
|
||||
for _, expectedVar := range tt.expectVars {
|
||||
if !strings.Contains(rendered, expectedVar) {
|
||||
t.Errorf("Rendered template should contain '%s'", expectedVar)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmailService_EmailValidation tests email address validation
|
||||
func TestEmailService_EmailValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
email string
|
||||
isValid bool
|
||||
}{
|
||||
{"user@example.com", true},
|
||||
{"user+tag@example.com", true},
|
||||
{"user.name@example.co.uk", true},
|
||||
{"user@subdomain.example.com", true},
|
||||
{"invalid", false},
|
||||
{"@example.com", false},
|
||||
{"user@", false},
|
||||
{"user@.com", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.email, func(t *testing.T) {
|
||||
// RFC 5322 compliant email validation pattern
|
||||
emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$`)
|
||||
isValid := emailRegex.MatchString(tt.email)
|
||||
|
||||
if isValid != tt.isValid {
|
||||
t.Errorf("Email %s: expected valid=%v, got %v", tt.email, tt.isValid, isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmailService_SMTPConfig tests SMTP configuration
|
||||
func TestEmailService_SMTPConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config EmailConfig
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
config: EmailConfig{
|
||||
Host: "smtp.example.com",
|
||||
Port: 587,
|
||||
Username: "user@example.com",
|
||||
Password: "password",
|
||||
FromName: "BreakPilot",
|
||||
FromAddr: "noreply@example.com",
|
||||
BaseURL: "https://example.com",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing host",
|
||||
config: EmailConfig{
|
||||
Port: 587,
|
||||
Username: "user@example.com",
|
||||
Password: "password",
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid port",
|
||||
config: EmailConfig{
|
||||
Host: "smtp.example.com",
|
||||
Port: 0,
|
||||
Username: "user@example.com",
|
||||
Password: "password",
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.config.Host == "" {
|
||||
err = &ValidationError{Field: "host", Message: "required"}
|
||||
} else if tt.config.Port <= 0 || tt.config.Port > 65535 {
|
||||
err = &ValidationError{Field: "port", Message: "must be between 1 and 65535"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmailService_RateLimiting tests email rate limiting logic
|
||||
func TestEmailService_RateLimiting(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
emailsSent int
|
||||
timeWindow int // minutes
|
||||
limit int
|
||||
expectThrottle bool
|
||||
}{
|
||||
{
|
||||
name: "under limit",
|
||||
emailsSent: 5,
|
||||
timeWindow: 60,
|
||||
limit: 10,
|
||||
expectThrottle: false,
|
||||
},
|
||||
{
|
||||
name: "at limit",
|
||||
emailsSent: 10,
|
||||
timeWindow: 60,
|
||||
limit: 10,
|
||||
expectThrottle: false,
|
||||
},
|
||||
{
|
||||
name: "over limit",
|
||||
emailsSent: 15,
|
||||
timeWindow: 60,
|
||||
limit: 10,
|
||||
expectThrottle: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
shouldThrottle := tt.emailsSent > tt.limit
|
||||
|
||||
if shouldThrottle != tt.expectThrottle {
|
||||
t.Errorf("Expected throttle=%v, got %v", tt.expectThrottle, shouldThrottle)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func getLink(vars map[string]string) string {
|
||||
if link, ok := vars["VerifyLink"]; ok {
|
||||
return link
|
||||
}
|
||||
if link, ok := vars["ResetLink"]; ok {
|
||||
return link
|
||||
}
|
||||
return ""
|
||||
}
|
||||
1673
consent-service/internal/services/email_template_service.go
Normal file
1673
consent-service/internal/services/email_template_service.go
Normal file
File diff suppressed because it is too large
Load Diff
698
consent-service/internal/services/email_template_service_test.go
Normal file
698
consent-service/internal/services/email_template_service_test.go
Normal file
@@ -0,0 +1,698 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
)
|
||||
|
||||
// ========================================
|
||||
// Test All 19 Email Categories
|
||||
// ========================================
|
||||
|
||||
// TestEmailTemplateService_GetDefaultTemplateContent tests default content generation for each email type
|
||||
func TestEmailTemplateService_GetDefaultTemplateContent(t *testing.T) {
|
||||
service := &EmailTemplateService{}
|
||||
|
||||
// All 19 email categories
|
||||
tests := []struct {
|
||||
name string
|
||||
emailType string
|
||||
language string
|
||||
wantSubject bool
|
||||
wantBodyHTML bool
|
||||
wantBodyText bool
|
||||
}{
|
||||
// Auth Lifecycle (10 types)
|
||||
{"welcome_de", models.EmailTypeWelcome, "de", true, true, true},
|
||||
{"email_verification_de", models.EmailTypeEmailVerification, "de", true, true, true},
|
||||
{"password_reset_de", models.EmailTypePasswordReset, "de", true, true, true},
|
||||
{"password_changed_de", models.EmailTypePasswordChanged, "de", true, true, true},
|
||||
{"2fa_enabled_de", models.EmailType2FAEnabled, "de", true, true, true},
|
||||
{"2fa_disabled_de", models.EmailType2FADisabled, "de", true, true, true},
|
||||
{"new_device_login_de", models.EmailTypeNewDeviceLogin, "de", true, true, true},
|
||||
{"suspicious_activity_de", models.EmailTypeSuspiciousActivity, "de", true, true, true},
|
||||
{"account_locked_de", models.EmailTypeAccountLocked, "de", true, true, true},
|
||||
{"account_unlocked_de", models.EmailTypeAccountUnlocked, "de", true, true, true},
|
||||
|
||||
// GDPR/Privacy (5 types)
|
||||
{"deletion_requested_de", models.EmailTypeDeletionRequested, "de", true, true, true},
|
||||
{"deletion_confirmed_de", models.EmailTypeDeletionConfirmed, "de", true, true, true},
|
||||
{"data_export_ready_de", models.EmailTypeDataExportReady, "de", true, true, true},
|
||||
{"email_changed_de", models.EmailTypeEmailChanged, "de", true, true, true},
|
||||
{"email_change_verify_de", models.EmailTypeEmailChangeVerify, "de", true, true, true},
|
||||
|
||||
// Consent Management (4 types)
|
||||
{"new_version_published_de", models.EmailTypeNewVersionPublished, "de", true, true, true},
|
||||
{"consent_reminder_de", models.EmailTypeConsentReminder, "de", true, true, true},
|
||||
{"consent_deadline_warning_de", models.EmailTypeConsentDeadlineWarning, "de", true, true, true},
|
||||
{"account_suspended_de", models.EmailTypeAccountSuspended, "de", true, true, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
subject, bodyHTML, bodyText := service.GetDefaultTemplateContent(tt.emailType, tt.language)
|
||||
|
||||
if tt.wantSubject && subject == "" {
|
||||
t.Errorf("GetDefaultTemplateContent(%s, %s): expected subject, got empty string", tt.emailType, tt.language)
|
||||
}
|
||||
if tt.wantBodyHTML && bodyHTML == "" {
|
||||
t.Errorf("GetDefaultTemplateContent(%s, %s): expected bodyHTML, got empty string", tt.emailType, tt.language)
|
||||
}
|
||||
if tt.wantBodyText && bodyText == "" {
|
||||
t.Errorf("GetDefaultTemplateContent(%s, %s): expected bodyText, got empty string", tt.emailType, tt.language)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmailTemplateService_GetDefaultTemplateContent_UnknownType tests default content for unknown type
|
||||
func TestEmailTemplateService_GetDefaultTemplateContent_UnknownType(t *testing.T) {
|
||||
service := &EmailTemplateService{}
|
||||
|
||||
subject, bodyHTML, bodyText := service.GetDefaultTemplateContent("unknown_type", "de")
|
||||
|
||||
// The service returns a fallback for unknown types
|
||||
if subject == "" {
|
||||
t.Errorf("GetDefaultTemplateContent(unknown_type, de): expected fallback subject, got empty")
|
||||
}
|
||||
if bodyHTML == "" {
|
||||
t.Errorf("GetDefaultTemplateContent(unknown_type, de): expected fallback bodyHTML, got empty")
|
||||
}
|
||||
if bodyText == "" {
|
||||
t.Errorf("GetDefaultTemplateContent(unknown_type, de): expected fallback bodyText, got empty")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmailTemplateService_GetDefaultTemplateContent_UnsupportedLanguage tests fallback for unsupported language
|
||||
func TestEmailTemplateService_GetDefaultTemplateContent_UnsupportedLanguage(t *testing.T) {
|
||||
service := &EmailTemplateService{}
|
||||
|
||||
// Test with unsupported language - should return fallback
|
||||
subject, bodyHTML, bodyText := service.GetDefaultTemplateContent(models.EmailTypeWelcome, "fr")
|
||||
|
||||
// Should return fallback (not empty, but generic)
|
||||
if subject == "" || bodyHTML == "" || bodyText == "" {
|
||||
t.Error("GetDefaultTemplateContent should return fallback for unsupported language")
|
||||
}
|
||||
}
|
||||
|
||||
// TestReplaceVariables tests variable replacement in templates
|
||||
func TestReplaceVariables(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
template string
|
||||
variables map[string]string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "single variable",
|
||||
template: "Hallo {{user_name}}!",
|
||||
variables: map[string]string{"user_name": "Max"},
|
||||
expected: "Hallo Max!",
|
||||
},
|
||||
{
|
||||
name: "multiple variables",
|
||||
template: "Hallo {{user_name}}, klicken Sie hier: {{reset_link}}",
|
||||
variables: map[string]string{"user_name": "Max", "reset_link": "https://example.com"},
|
||||
expected: "Hallo Max, klicken Sie hier: https://example.com",
|
||||
},
|
||||
{
|
||||
name: "no variables",
|
||||
template: "Hallo Welt!",
|
||||
variables: map[string]string{},
|
||||
expected: "Hallo Welt!",
|
||||
},
|
||||
{
|
||||
name: "missing variable - not replaced",
|
||||
template: "Hallo {{user_name}} und {{missing}}!",
|
||||
variables: map[string]string{"user_name": "Max"},
|
||||
expected: "Hallo Max und {{missing}}!",
|
||||
},
|
||||
{
|
||||
name: "empty template",
|
||||
template: "",
|
||||
variables: map[string]string{"user_name": "Max"},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "variable with special characters",
|
||||
template: "IP: {{ip_address}}",
|
||||
variables: map[string]string{"ip_address": "192.168.1.1"},
|
||||
expected: "IP: 192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "variable with URL",
|
||||
template: "Link: {{verification_url}}",
|
||||
variables: map[string]string{"verification_url": "https://example.com/verify?token=abc123&user=test"},
|
||||
expected: "Link: https://example.com/verify?token=abc123&user=test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := testReplaceVariables(tt.template, tt.variables)
|
||||
if result != tt.expected {
|
||||
t.Errorf("replaceVariables() = %s, want %s", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testReplaceVariables is a test helper function for variable replacement
|
||||
func testReplaceVariables(template string, variables map[string]string) string {
|
||||
result := template
|
||||
for key, value := range variables {
|
||||
placeholder := "{{" + key + "}}"
|
||||
for i := 0; i < len(result); i++ {
|
||||
idx := testFindSubstring(result, placeholder)
|
||||
if idx == -1 {
|
||||
break
|
||||
}
|
||||
result = result[:idx] + value + result[idx+len(placeholder):]
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func testFindSubstring(s, substr string) int {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// TestEmailTypeConstantsExist verifies that all expected email types are defined
|
||||
func TestEmailTypeConstantsExist(t *testing.T) {
|
||||
// Test that all 19 email type constants are defined and produce non-empty templates
|
||||
types := []string{
|
||||
// Auth Lifecycle
|
||||
models.EmailTypeWelcome,
|
||||
models.EmailTypeEmailVerification,
|
||||
models.EmailTypePasswordReset,
|
||||
models.EmailTypePasswordChanged,
|
||||
models.EmailType2FAEnabled,
|
||||
models.EmailType2FADisabled,
|
||||
models.EmailTypeNewDeviceLogin,
|
||||
models.EmailTypeSuspiciousActivity,
|
||||
models.EmailTypeAccountLocked,
|
||||
models.EmailTypeAccountUnlocked,
|
||||
// GDPR/Privacy
|
||||
models.EmailTypeDeletionRequested,
|
||||
models.EmailTypeDeletionConfirmed,
|
||||
models.EmailTypeDataExportReady,
|
||||
models.EmailTypeEmailChanged,
|
||||
models.EmailTypeEmailChangeVerify,
|
||||
// Consent Management
|
||||
models.EmailTypeNewVersionPublished,
|
||||
models.EmailTypeConsentReminder,
|
||||
models.EmailTypeConsentDeadlineWarning,
|
||||
models.EmailTypeAccountSuspended,
|
||||
}
|
||||
|
||||
service := &EmailTemplateService{}
|
||||
|
||||
for _, emailType := range types {
|
||||
t.Run(emailType, func(t *testing.T) {
|
||||
subject, bodyHTML, _ := service.GetDefaultTemplateContent(emailType, "de")
|
||||
if subject == "" {
|
||||
t.Errorf("Email type %s has no default subject", emailType)
|
||||
}
|
||||
if bodyHTML == "" {
|
||||
t.Errorf("Email type %s has no default body HTML", emailType)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Verify we have exactly 19 types
|
||||
if len(types) != 19 {
|
||||
t.Errorf("Expected 19 email types, got %d", len(types))
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmailTemplateService_ValidateTemplateContent tests template content validation
|
||||
func TestEmailTemplateService_ValidateTemplateContent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
subject string
|
||||
bodyHTML string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "valid content",
|
||||
subject: "Test Subject",
|
||||
bodyHTML: "<p>Test Body</p>",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "empty subject",
|
||||
subject: "",
|
||||
bodyHTML: "<p>Test Body</p>",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
subject: "Test Subject",
|
||||
bodyHTML: "",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "both empty",
|
||||
subject: "",
|
||||
bodyHTML: "",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "whitespace only subject",
|
||||
subject: " ",
|
||||
bodyHTML: "<p>Test Body</p>",
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := testValidateTemplateContent(tt.subject, tt.bodyHTML)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("validateTemplateContent() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testValidateTemplateContent is a test helper function to validate template content
|
||||
func testValidateTemplateContent(subject, bodyHTML string) error {
|
||||
if strings.TrimSpace(subject) == "" {
|
||||
return &templateValidationError{Field: "subject", Message: "subject is required"}
|
||||
}
|
||||
if strings.TrimSpace(bodyHTML) == "" {
|
||||
return &templateValidationError{Field: "body_html", Message: "body_html is required"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// templateValidationError represents a validation error in email templates
|
||||
type templateValidationError struct {
|
||||
Field string
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *templateValidationError) Error() string {
|
||||
return e.Field + ": " + e.Message
|
||||
}
|
||||
|
||||
// TestGetTestVariablesForType tests that test variables are properly generated for each email type
|
||||
func TestGetTestVariablesForType(t *testing.T) {
|
||||
tests := []struct {
|
||||
emailType string
|
||||
expectedVars []string
|
||||
}{
|
||||
// Auth Lifecycle
|
||||
{models.EmailTypeWelcome, []string{"user_name", "app_name"}},
|
||||
{models.EmailTypeEmailVerification, []string{"user_name", "verification_url"}},
|
||||
{models.EmailTypePasswordReset, []string{"reset_url"}},
|
||||
{models.EmailTypePasswordChanged, []string{"user_name", "changed_at"}},
|
||||
{models.EmailType2FAEnabled, []string{"user_name", "enabled_at"}},
|
||||
{models.EmailType2FADisabled, []string{"user_name", "disabled_at"}},
|
||||
{models.EmailTypeNewDeviceLogin, []string{"device", "location", "ip_address", "login_time"}},
|
||||
{models.EmailTypeSuspiciousActivity, []string{"activity_type", "activity_time"}},
|
||||
{models.EmailTypeAccountLocked, []string{"locked_at", "reason"}},
|
||||
{models.EmailTypeAccountUnlocked, []string{"unlocked_at"}},
|
||||
// GDPR/Privacy
|
||||
{models.EmailTypeDeletionRequested, []string{"deletion_date", "cancel_url"}},
|
||||
{models.EmailTypeDeletionConfirmed, []string{"deleted_at"}},
|
||||
{models.EmailTypeDataExportReady, []string{"download_url", "expires_in"}},
|
||||
{models.EmailTypeEmailChanged, []string{"old_email", "new_email"}},
|
||||
// Consent Management
|
||||
{models.EmailTypeNewVersionPublished, []string{"document_name", "version"}},
|
||||
{models.EmailTypeConsentReminder, []string{"document_name", "days_left"}},
|
||||
{models.EmailTypeConsentDeadlineWarning, []string{"document_name", "hours_left"}},
|
||||
{models.EmailTypeAccountSuspended, []string{"suspended_at", "reason"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.emailType, func(t *testing.T) {
|
||||
vars := getTestVariablesForType(tt.emailType)
|
||||
for _, expected := range tt.expectedVars {
|
||||
if _, ok := vars[expected]; !ok {
|
||||
t.Errorf("getTestVariablesForType(%s) missing variable %s", tt.emailType, expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// getTestVariablesForType returns test variables for a given email type
|
||||
func getTestVariablesForType(emailType string) map[string]string {
|
||||
// Common variables
|
||||
vars := map[string]string{
|
||||
"user_name": "Max Mustermann",
|
||||
"user_email": "max@example.com",
|
||||
"app_name": "BreakPilot",
|
||||
"app_url": "https://breakpilot.app",
|
||||
"support_url": "https://breakpilot.app/support",
|
||||
"support_email": "support@breakpilot.app",
|
||||
"security_url": "https://breakpilot.app/security",
|
||||
"login_url": "https://breakpilot.app/login",
|
||||
}
|
||||
|
||||
switch emailType {
|
||||
case models.EmailTypeEmailVerification:
|
||||
vars["verification_url"] = "https://breakpilot.app/verify?token=xyz789"
|
||||
vars["verification_code"] = "ABC123"
|
||||
vars["expires_in"] = "24 Stunden"
|
||||
|
||||
case models.EmailTypePasswordReset:
|
||||
vars["reset_url"] = "https://breakpilot.app/reset?token=abc123"
|
||||
vars["reset_code"] = "RST456"
|
||||
vars["expires_in"] = "1 Stunde"
|
||||
vars["ip_address"] = "192.168.1.1"
|
||||
|
||||
case models.EmailTypePasswordChanged:
|
||||
vars["changed_at"] = "14.12.2025 15:30 Uhr"
|
||||
vars["ip_address"] = "192.168.1.1"
|
||||
vars["device_info"] = "Chrome auf MacOS"
|
||||
|
||||
case models.EmailType2FAEnabled:
|
||||
vars["enabled_at"] = "14.12.2025 15:30 Uhr"
|
||||
vars["device_info"] = "Chrome auf MacOS"
|
||||
|
||||
case models.EmailType2FADisabled:
|
||||
vars["disabled_at"] = "14.12.2025 15:30 Uhr"
|
||||
vars["ip_address"] = "192.168.1.1"
|
||||
|
||||
case models.EmailTypeNewDeviceLogin:
|
||||
vars["device"] = "Chrome auf MacOS"
|
||||
vars["device_info"] = "Chrome auf MacOS"
|
||||
vars["location"] = "Berlin, Deutschland"
|
||||
vars["ip_address"] = "192.168.1.1"
|
||||
vars["login_time"] = "14.12.2025 15:30 Uhr"
|
||||
|
||||
case models.EmailTypeSuspiciousActivity:
|
||||
vars["activity_type"] = "Mehrere fehlgeschlagene Logins"
|
||||
vars["activity_time"] = "14.12.2025 15:30 Uhr"
|
||||
vars["ip_address"] = "192.168.1.1"
|
||||
|
||||
case models.EmailTypeAccountLocked:
|
||||
vars["locked_at"] = "14.12.2025 15:30 Uhr"
|
||||
vars["reason"] = "Zu viele fehlgeschlagene Login-Versuche"
|
||||
vars["unlock_time"] = "14.12.2025 16:30 Uhr"
|
||||
|
||||
case models.EmailTypeAccountUnlocked:
|
||||
vars["unlocked_at"] = "14.12.2025 16:30 Uhr"
|
||||
|
||||
case models.EmailTypeDeletionRequested:
|
||||
vars["requested_at"] = "14.12.2025 15:30 Uhr"
|
||||
vars["deletion_date"] = "14.01.2026"
|
||||
vars["cancel_url"] = "https://breakpilot.app/cancel-deletion?token=del123"
|
||||
vars["data_info"] = "Profildaten, Consent-Historie, Audit-Logs"
|
||||
|
||||
case models.EmailTypeDeletionConfirmed:
|
||||
vars["deleted_at"] = "14.01.2026 00:00 Uhr"
|
||||
vars["feedback_url"] = "https://breakpilot.app/feedback"
|
||||
|
||||
case models.EmailTypeDataExportReady:
|
||||
vars["download_url"] = "https://breakpilot.app/download/export123"
|
||||
vars["expires_in"] = "7 Tage"
|
||||
vars["file_size"] = "2.5 MB"
|
||||
|
||||
case models.EmailTypeEmailChanged:
|
||||
vars["old_email"] = "old@example.com"
|
||||
vars["new_email"] = "new@example.com"
|
||||
vars["changed_at"] = "14.12.2025 15:30 Uhr"
|
||||
|
||||
case models.EmailTypeEmailChangeVerify:
|
||||
vars["new_email"] = "new@example.com"
|
||||
vars["verification_url"] = "https://breakpilot.app/verify-email?token=ver123"
|
||||
vars["expires_in"] = "24 Stunden"
|
||||
|
||||
case models.EmailTypeNewVersionPublished:
|
||||
vars["document_name"] = "Datenschutzerklärung"
|
||||
vars["document_type"] = "privacy"
|
||||
vars["version"] = "2.0.0"
|
||||
vars["consent_url"] = "https://breakpilot.app/consent"
|
||||
vars["deadline"] = "31.12.2025"
|
||||
|
||||
case models.EmailTypeConsentReminder:
|
||||
vars["document_name"] = "Nutzungsbedingungen"
|
||||
vars["days_left"] = "7"
|
||||
vars["consent_url"] = "https://breakpilot.app/consent"
|
||||
vars["deadline"] = "21.12.2025"
|
||||
|
||||
case models.EmailTypeConsentDeadlineWarning:
|
||||
vars["document_name"] = "Nutzungsbedingungen"
|
||||
vars["hours_left"] = "24 Stunden"
|
||||
vars["consent_url"] = "https://breakpilot.app/consent"
|
||||
vars["consequences"] = "Ihr Konto wird temporär suspendiert."
|
||||
|
||||
case models.EmailTypeAccountSuspended:
|
||||
vars["suspended_at"] = "14.12.2025 15:30 Uhr"
|
||||
vars["reason"] = "Fehlende Zustimmung zu Pflichtdokumenten"
|
||||
vars["documents"] = "- Nutzungsbedingungen v2.0\n- Datenschutzerklärung v3.0"
|
||||
vars["consent_url"] = "https://breakpilot.app/consent"
|
||||
}
|
||||
|
||||
return vars
|
||||
}
|
||||
|
||||
// TestEmailTemplateService_HTMLEscape tests that HTML is properly escaped in text version
|
||||
func TestEmailTemplateService_HTMLEscape(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
html string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple paragraph",
|
||||
html: "<p>Hello World</p>",
|
||||
expected: "Hello World",
|
||||
},
|
||||
{
|
||||
name: "link",
|
||||
html: `<a href="https://example.com">Click here</a>`,
|
||||
expected: "Click here",
|
||||
},
|
||||
{
|
||||
name: "bold text",
|
||||
html: "<strong>Important</strong>",
|
||||
expected: "Important",
|
||||
},
|
||||
{
|
||||
name: "nested tags",
|
||||
html: "<div><p><strong>Nested</strong> text</p></div>",
|
||||
expected: "Nested text",
|
||||
},
|
||||
{
|
||||
name: "multiple tags",
|
||||
html: "<h1>Title</h1><p>Paragraph</p>",
|
||||
expected: "TitleParagraph",
|
||||
},
|
||||
{
|
||||
name: "self-closing tag",
|
||||
html: "Line1<br/>Line2",
|
||||
expected: "Line1Line2",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := stripHTMLTags(tt.html)
|
||||
if result != tt.expected {
|
||||
t.Errorf("stripHTMLTags() = %s, want %s", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// stripHTMLTags removes HTML tags from a string
|
||||
func stripHTMLTags(html string) string {
|
||||
result := ""
|
||||
inTag := false
|
||||
for _, r := range html {
|
||||
if r == '<' {
|
||||
inTag = true
|
||||
continue
|
||||
}
|
||||
if r == '>' {
|
||||
inTag = false
|
||||
continue
|
||||
}
|
||||
if !inTag {
|
||||
result += string(r)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// TestEmailTemplateService_AllTemplatesHaveVariables tests that all templates define their required variables
|
||||
func TestEmailTemplateService_AllTemplatesHaveVariables(t *testing.T) {
|
||||
service := &EmailTemplateService{}
|
||||
templateTypes := service.GetAllTemplateTypes()
|
||||
|
||||
for _, tt := range templateTypes {
|
||||
t.Run(tt.TemplateType, func(t *testing.T) {
|
||||
// Get default template content
|
||||
subject, bodyHTML, bodyText := service.GetDefaultTemplateContent(tt.TemplateType, "de")
|
||||
|
||||
// Check that variables defined in template type are present in the content
|
||||
for _, varName := range tt.Variables {
|
||||
placeholder := "{{" + varName + "}}"
|
||||
foundInSubject := strings.Contains(subject, placeholder)
|
||||
foundInHTML := strings.Contains(bodyHTML, placeholder)
|
||||
foundInText := strings.Contains(bodyText, placeholder)
|
||||
|
||||
// Variable should be present in at least one of subject, HTML or text
|
||||
if !foundInSubject && !foundInHTML && !foundInText {
|
||||
// Note: This is a warning, not an error, as some variables might be optional
|
||||
t.Logf("Warning: Variable %s defined for %s but not found in template content", varName, tt.TemplateType)
|
||||
}
|
||||
}
|
||||
|
||||
// Check that all variables in content are defined
|
||||
re := regexp.MustCompile(`\{\{(\w+)\}\}`)
|
||||
allMatches := re.FindAllStringSubmatch(subject+bodyHTML+bodyText, -1)
|
||||
definedVars := make(map[string]bool)
|
||||
for _, v := range tt.Variables {
|
||||
definedVars[v] = true
|
||||
}
|
||||
|
||||
for _, match := range allMatches {
|
||||
if len(match) > 1 {
|
||||
varName := match[1]
|
||||
if !definedVars[varName] {
|
||||
t.Logf("Warning: Variable {{%s}} found in template but not defined in variables list for %s", varName, tt.TemplateType)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmailTemplateService_TemplateVariableDescriptions tests that all variables have descriptions
|
||||
func TestEmailTemplateService_TemplateVariableDescriptions(t *testing.T) {
|
||||
service := &EmailTemplateService{}
|
||||
templateTypes := service.GetAllTemplateTypes()
|
||||
|
||||
for _, tt := range templateTypes {
|
||||
t.Run(tt.TemplateType, func(t *testing.T) {
|
||||
for _, varName := range tt.Variables {
|
||||
if desc, ok := tt.Descriptions[varName]; !ok || desc == "" {
|
||||
t.Errorf("Variable %s in %s has no description", varName, tt.TemplateType)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmailTemplateService_GermanTemplatesAreComplete tests that all German templates are fully translated
|
||||
func TestEmailTemplateService_GermanTemplatesAreComplete(t *testing.T) {
|
||||
service := &EmailTemplateService{}
|
||||
|
||||
emailTypes := []string{
|
||||
models.EmailTypeWelcome,
|
||||
models.EmailTypeEmailVerification,
|
||||
models.EmailTypePasswordReset,
|
||||
models.EmailTypePasswordChanged,
|
||||
models.EmailType2FAEnabled,
|
||||
models.EmailType2FADisabled,
|
||||
models.EmailTypeNewDeviceLogin,
|
||||
models.EmailTypeSuspiciousActivity,
|
||||
models.EmailTypeAccountLocked,
|
||||
models.EmailTypeAccountUnlocked,
|
||||
models.EmailTypeDeletionRequested,
|
||||
models.EmailTypeDeletionConfirmed,
|
||||
models.EmailTypeDataExportReady,
|
||||
models.EmailTypeEmailChanged,
|
||||
models.EmailTypeNewVersionPublished,
|
||||
models.EmailTypeConsentReminder,
|
||||
models.EmailTypeConsentDeadlineWarning,
|
||||
models.EmailTypeAccountSuspended,
|
||||
}
|
||||
|
||||
germanKeywords := []string{"Hallo", "freundlichen", "Grüßen", "BreakPilot", "Ihr"}
|
||||
|
||||
for _, emailType := range emailTypes {
|
||||
t.Run(emailType, func(t *testing.T) {
|
||||
subject, bodyHTML, bodyText := service.GetDefaultTemplateContent(emailType, "de")
|
||||
|
||||
// Check that German text is present
|
||||
foundGerman := false
|
||||
for _, keyword := range germanKeywords {
|
||||
if strings.Contains(bodyHTML, keyword) || strings.Contains(bodyText, keyword) {
|
||||
foundGerman = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !foundGerman {
|
||||
t.Errorf("Template %s does not appear to be in German", emailType)
|
||||
}
|
||||
|
||||
// Check that subject is not just the fallback
|
||||
if subject == "No template" {
|
||||
t.Errorf("Template %s has fallback subject instead of German subject", emailType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmailTemplateService_HTMLStructure tests that HTML templates have valid structure
|
||||
func TestEmailTemplateService_HTMLStructure(t *testing.T) {
|
||||
service := &EmailTemplateService{}
|
||||
|
||||
emailTypes := []string{
|
||||
models.EmailTypeWelcome,
|
||||
models.EmailTypeEmailVerification,
|
||||
models.EmailTypePasswordReset,
|
||||
}
|
||||
|
||||
for _, emailType := range emailTypes {
|
||||
t.Run(emailType, func(t *testing.T) {
|
||||
_, bodyHTML, _ := service.GetDefaultTemplateContent(emailType, "de")
|
||||
|
||||
// Check for basic HTML structure
|
||||
if !strings.Contains(bodyHTML, "<!DOCTYPE html>") {
|
||||
t.Errorf("Template %s missing DOCTYPE", emailType)
|
||||
}
|
||||
if !strings.Contains(bodyHTML, "<html>") {
|
||||
t.Errorf("Template %s missing <html> tag", emailType)
|
||||
}
|
||||
if !strings.Contains(bodyHTML, "</html>") {
|
||||
t.Errorf("Template %s missing closing </html> tag", emailType)
|
||||
}
|
||||
if !strings.Contains(bodyHTML, "<body") {
|
||||
t.Errorf("Template %s missing <body> tag", emailType)
|
||||
}
|
||||
if !strings.Contains(bodyHTML, "</body>") {
|
||||
t.Errorf("Template %s missing closing </body> tag", emailType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkReplaceVariables benchmarks variable replacement
|
||||
func BenchmarkReplaceVariables(b *testing.B) {
|
||||
template := "Hallo {{user_name}}, Ihr Link: {{reset_url}}, gültig bis {{expires_in}}"
|
||||
variables := map[string]string{
|
||||
"user_name": "Max Mustermann",
|
||||
"reset_url": "https://example.com/reset?token=abc123",
|
||||
"expires_in": "24 Stunden",
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
replaceVariables(template, variables)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkStripHTMLTags benchmarks HTML tag stripping
|
||||
func BenchmarkStripHTMLTags(b *testing.B) {
|
||||
html := "<html><body><h1>Title</h1><p>This is a <strong>test</strong> paragraph with <a href='#'>links</a>.</p></body></html>"
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
stripHTMLTags(html)
|
||||
}
|
||||
}
|
||||
543
consent-service/internal/services/grade_service.go
Normal file
543
consent-service/internal/services/grade_service.go
Normal file
@@ -0,0 +1,543 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/database"
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
"github.com/breakpilot/consent-service/internal/services/matrix"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// GradeService handles grade management and notifications
|
||||
type GradeService struct {
|
||||
db *database.DB
|
||||
matrix *matrix.MatrixService
|
||||
}
|
||||
|
||||
// NewGradeService creates a new grade service
|
||||
func NewGradeService(db *database.DB, matrixService *matrix.MatrixService) *GradeService {
|
||||
return &GradeService{
|
||||
db: db,
|
||||
matrix: matrixService,
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Grade CRUD
|
||||
// ========================================
|
||||
|
||||
// CreateGrade creates a new grade for a student
|
||||
func (s *GradeService) CreateGrade(ctx context.Context, req models.CreateGradeRequest, teacherID uuid.UUID) (*models.Grade, error) {
|
||||
studentID, err := uuid.Parse(req.StudentID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid student ID: %w", err)
|
||||
}
|
||||
|
||||
subjectID, err := uuid.Parse(req.SubjectID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid subject ID: %w", err)
|
||||
}
|
||||
|
||||
schoolYearID, err := uuid.Parse(req.SchoolYearID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid school year ID: %w", err)
|
||||
}
|
||||
|
||||
date, err := time.Parse("2006-01-02", req.Date)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid date format: %w", err)
|
||||
}
|
||||
|
||||
// Get default grade scale for the school
|
||||
var gradeScaleID uuid.UUID
|
||||
var schoolID uuid.UUID
|
||||
err = s.db.Pool.QueryRow(ctx, `
|
||||
SELECT gs.id, gs.school_id
|
||||
FROM grade_scales gs
|
||||
JOIN students st ON st.school_id = gs.school_id
|
||||
WHERE st.id = $1 AND gs.is_default = true`, studentID).Scan(&gradeScaleID, &schoolID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get grade scale: %w", err)
|
||||
}
|
||||
|
||||
weight := req.Weight
|
||||
if weight == 0 {
|
||||
weight = 1.0
|
||||
}
|
||||
|
||||
grade := &models.Grade{
|
||||
ID: uuid.New(),
|
||||
StudentID: studentID,
|
||||
SubjectID: subjectID,
|
||||
TeacherID: teacherID,
|
||||
SchoolYearID: schoolYearID,
|
||||
GradeScaleID: gradeScaleID,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
Weight: weight,
|
||||
Date: date,
|
||||
Title: req.Title,
|
||||
Description: req.Description,
|
||||
IsVisible: true,
|
||||
Semester: req.Semester,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO grades (id, student_id, subject_id, teacher_id, school_year_id, grade_scale_id, type, value, weight, date, title, description, is_visible, semester, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
|
||||
RETURNING id`
|
||||
|
||||
err = s.db.Pool.QueryRow(ctx, query,
|
||||
grade.ID, grade.StudentID, grade.SubjectID, grade.TeacherID,
|
||||
grade.SchoolYearID, grade.GradeScaleID, grade.Type, grade.Value,
|
||||
grade.Weight, grade.Date, grade.Title, grade.Description,
|
||||
grade.IsVisible, grade.Semester, grade.CreatedAt, grade.UpdatedAt,
|
||||
).Scan(&grade.ID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create grade: %w", err)
|
||||
}
|
||||
|
||||
// Send notification to parents if grade is visible
|
||||
if grade.IsVisible {
|
||||
go s.notifyParentsOfNewGrade(context.Background(), grade)
|
||||
}
|
||||
|
||||
return grade, nil
|
||||
}
|
||||
|
||||
// GetGrade retrieves a grade by ID
|
||||
func (s *GradeService) GetGrade(ctx context.Context, gradeID uuid.UUID) (*models.Grade, error) {
|
||||
query := `
|
||||
SELECT id, student_id, subject_id, teacher_id, school_year_id, grade_scale_id, type, value, weight, date, title, description, is_visible, semester, created_at, updated_at
|
||||
FROM grades
|
||||
WHERE id = $1`
|
||||
|
||||
grade := &models.Grade{}
|
||||
err := s.db.Pool.QueryRow(ctx, query, gradeID).Scan(
|
||||
&grade.ID, &grade.StudentID, &grade.SubjectID, &grade.TeacherID,
|
||||
&grade.SchoolYearID, &grade.GradeScaleID, &grade.Type, &grade.Value,
|
||||
&grade.Weight, &grade.Date, &grade.Title, &grade.Description,
|
||||
&grade.IsVisible, &grade.Semester, &grade.CreatedAt, &grade.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get grade: %w", err)
|
||||
}
|
||||
|
||||
return grade, nil
|
||||
}
|
||||
|
||||
// UpdateGrade updates an existing grade
|
||||
func (s *GradeService) UpdateGrade(ctx context.Context, gradeID uuid.UUID, value float64, title, description *string) error {
|
||||
query := `
|
||||
UPDATE grades
|
||||
SET value = $1, title = COALESCE($2, title), description = COALESCE($3, description), updated_at = NOW()
|
||||
WHERE id = $4`
|
||||
|
||||
result, err := s.db.Pool.Exec(ctx, query, value, title, description, gradeID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update grade: %w", err)
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
return fmt.Errorf("grade not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteGrade deletes a grade
|
||||
func (s *GradeService) DeleteGrade(ctx context.Context, gradeID uuid.UUID) error {
|
||||
result, err := s.db.Pool.Exec(ctx, `DELETE FROM grades WHERE id = $1`, gradeID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete grade: %w", err)
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
return fmt.Errorf("grade not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Grade Queries
|
||||
// ========================================
|
||||
|
||||
// GetStudentGrades gets all grades for a student in a school year
|
||||
func (s *GradeService) GetStudentGrades(ctx context.Context, studentID, schoolYearID uuid.UUID) ([]models.Grade, error) {
|
||||
query := `
|
||||
SELECT id, student_id, subject_id, teacher_id, school_year_id, grade_scale_id, type, value, weight, date, title, description, is_visible, semester, created_at, updated_at
|
||||
FROM grades
|
||||
WHERE student_id = $1 AND school_year_id = $2 AND is_visible = true
|
||||
ORDER BY date DESC`
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, query, studentID, schoolYearID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get student grades: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var grades []models.Grade
|
||||
for rows.Next() {
|
||||
var grade models.Grade
|
||||
err := rows.Scan(
|
||||
&grade.ID, &grade.StudentID, &grade.SubjectID, &grade.TeacherID,
|
||||
&grade.SchoolYearID, &grade.GradeScaleID, &grade.Type, &grade.Value,
|
||||
&grade.Weight, &grade.Date, &grade.Title, &grade.Description,
|
||||
&grade.IsVisible, &grade.Semester, &grade.CreatedAt, &grade.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan grade: %w", err)
|
||||
}
|
||||
grades = append(grades, grade)
|
||||
}
|
||||
|
||||
return grades, nil
|
||||
}
|
||||
|
||||
// GetStudentGradesBySubject gets grades for a student in a specific subject
|
||||
func (s *GradeService) GetStudentGradesBySubject(ctx context.Context, studentID, subjectID, schoolYearID uuid.UUID, semester int) ([]models.Grade, error) {
|
||||
query := `
|
||||
SELECT id, student_id, subject_id, teacher_id, school_year_id, grade_scale_id, type, value, weight, date, title, description, is_visible, semester, created_at, updated_at
|
||||
FROM grades
|
||||
WHERE student_id = $1 AND subject_id = $2 AND school_year_id = $3 AND semester = $4 AND is_visible = true
|
||||
ORDER BY date DESC`
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, query, studentID, subjectID, schoolYearID, semester)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get grades by subject: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var grades []models.Grade
|
||||
for rows.Next() {
|
||||
var grade models.Grade
|
||||
err := rows.Scan(
|
||||
&grade.ID, &grade.StudentID, &grade.SubjectID, &grade.TeacherID,
|
||||
&grade.SchoolYearID, &grade.GradeScaleID, &grade.Type, &grade.Value,
|
||||
&grade.Weight, &grade.Date, &grade.Title, &grade.Description,
|
||||
&grade.IsVisible, &grade.Semester, &grade.CreatedAt, &grade.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan grade: %w", err)
|
||||
}
|
||||
grades = append(grades, grade)
|
||||
}
|
||||
|
||||
return grades, nil
|
||||
}
|
||||
|
||||
// GetClassGradesBySubject gets all grades for a class in a subject (Notenspiegel)
|
||||
func (s *GradeService) GetClassGradesBySubject(ctx context.Context, classID, subjectID, schoolYearID uuid.UUID, semester int) ([]models.StudentGradeOverview, error) {
|
||||
// Get all students in the class
|
||||
studentsQuery := `
|
||||
SELECT id, first_name, last_name
|
||||
FROM students
|
||||
WHERE class_id = $1 AND is_active = true
|
||||
ORDER BY last_name, first_name`
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, studentsQuery, classID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get students: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var students []struct {
|
||||
ID uuid.UUID
|
||||
FirstName string
|
||||
LastName string
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var student struct {
|
||||
ID uuid.UUID
|
||||
FirstName string
|
||||
LastName string
|
||||
}
|
||||
if err := rows.Scan(&student.ID, &student.FirstName, &student.LastName); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan student: %w", err)
|
||||
}
|
||||
students = append(students, student)
|
||||
}
|
||||
|
||||
// Get subject info
|
||||
var subject models.Subject
|
||||
err = s.db.Pool.QueryRow(ctx, `SELECT id, school_id, name, short_name, color, is_active, created_at FROM subjects WHERE id = $1`, subjectID).Scan(
|
||||
&subject.ID, &subject.SchoolID, &subject.Name, &subject.ShortName, &subject.Color, &subject.IsActive, &subject.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get subject: %w", err)
|
||||
}
|
||||
|
||||
var overviews []models.StudentGradeOverview
|
||||
|
||||
for _, student := range students {
|
||||
grades, err := s.GetStudentGradesBySubject(ctx, student.ID, subjectID, schoolYearID, semester)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Calculate averages
|
||||
var totalWeight, weightedSum float64
|
||||
var oralWeight, oralSum float64
|
||||
var examWeight, examSum float64
|
||||
|
||||
for _, grade := range grades {
|
||||
totalWeight += grade.Weight
|
||||
weightedSum += grade.Value * grade.Weight
|
||||
|
||||
if grade.Type == models.GradeTypeOral || grade.Type == models.GradeTypeParticipation {
|
||||
oralWeight += grade.Weight
|
||||
oralSum += grade.Value * grade.Weight
|
||||
} else if grade.Type == models.GradeTypeExam || grade.Type == models.GradeTypeTest {
|
||||
examWeight += grade.Weight
|
||||
examSum += grade.Value * grade.Weight
|
||||
}
|
||||
}
|
||||
|
||||
var average, oralAverage, examAverage float64
|
||||
if totalWeight > 0 {
|
||||
average = weightedSum / totalWeight
|
||||
}
|
||||
if oralWeight > 0 {
|
||||
oralAverage = oralSum / oralWeight
|
||||
}
|
||||
if examWeight > 0 {
|
||||
examAverage = examSum / examWeight
|
||||
}
|
||||
|
||||
overview := models.StudentGradeOverview{
|
||||
Student: models.Student{
|
||||
ID: student.ID,
|
||||
FirstName: student.FirstName,
|
||||
LastName: student.LastName,
|
||||
},
|
||||
Subject: subject,
|
||||
Grades: grades,
|
||||
Average: average,
|
||||
OralAverage: oralAverage,
|
||||
ExamAverage: examAverage,
|
||||
Semester: semester,
|
||||
}
|
||||
|
||||
overviews = append(overviews, overview)
|
||||
}
|
||||
|
||||
return overviews, nil
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Grade Statistics
|
||||
// ========================================
|
||||
|
||||
// GetStudentGradeAverage calculates the overall grade average for a student
|
||||
func (s *GradeService) GetStudentGradeAverage(ctx context.Context, studentID, schoolYearID uuid.UUID, semester int) (float64, error) {
|
||||
query := `
|
||||
SELECT COALESCE(SUM(value * weight) / NULLIF(SUM(weight), 0), 0)
|
||||
FROM grades
|
||||
WHERE student_id = $1 AND school_year_id = $2 AND semester = $3 AND is_visible = true`
|
||||
|
||||
var average float64
|
||||
err := s.db.Pool.QueryRow(ctx, query, studentID, schoolYearID, semester).Scan(&average)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to calculate average: %w", err)
|
||||
}
|
||||
|
||||
return average, nil
|
||||
}
|
||||
|
||||
// GetSubjectGradeStatistics gets grade statistics for a subject in a class
|
||||
func (s *GradeService) GetSubjectGradeStatistics(ctx context.Context, classID, subjectID, schoolYearID uuid.UUID, semester int) (map[string]interface{}, error) {
|
||||
query := `
|
||||
SELECT
|
||||
COUNT(DISTINCT g.student_id) as student_count,
|
||||
AVG(g.value) as class_average,
|
||||
MIN(g.value) as best_grade,
|
||||
MAX(g.value) as worst_grade,
|
||||
COUNT(*) as total_grades
|
||||
FROM grades g
|
||||
JOIN students s ON g.student_id = s.id
|
||||
WHERE s.class_id = $1 AND g.subject_id = $2 AND g.school_year_id = $3 AND g.semester = $4 AND g.is_visible = true`
|
||||
|
||||
var studentCount, totalGrades int
|
||||
var classAverage, bestGrade, worstGrade float64
|
||||
|
||||
err := s.db.Pool.QueryRow(ctx, query, classID, subjectID, schoolYearID, semester).Scan(
|
||||
&studentCount, &classAverage, &bestGrade, &worstGrade, &totalGrades,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get statistics: %w", err)
|
||||
}
|
||||
|
||||
// Grade distribution (for German grades 1-6)
|
||||
distributionQuery := `
|
||||
SELECT
|
||||
COUNT(CASE WHEN value >= 1 AND value < 1.5 THEN 1 END) as grade_1,
|
||||
COUNT(CASE WHEN value >= 1.5 AND value < 2.5 THEN 1 END) as grade_2,
|
||||
COUNT(CASE WHEN value >= 2.5 AND value < 3.5 THEN 1 END) as grade_3,
|
||||
COUNT(CASE WHEN value >= 3.5 AND value < 4.5 THEN 1 END) as grade_4,
|
||||
COUNT(CASE WHEN value >= 4.5 AND value < 5.5 THEN 1 END) as grade_5,
|
||||
COUNT(CASE WHEN value >= 5.5 THEN 1 END) as grade_6
|
||||
FROM grades g
|
||||
JOIN students s ON g.student_id = s.id
|
||||
WHERE s.class_id = $1 AND g.subject_id = $2 AND g.school_year_id = $3 AND g.semester = $4 AND g.is_visible = true AND g.type IN ('exam', 'test')`
|
||||
|
||||
var g1, g2, g3, g4, g5, g6 int
|
||||
err = s.db.Pool.QueryRow(ctx, distributionQuery, classID, subjectID, schoolYearID, semester).Scan(
|
||||
&g1, &g2, &g3, &g4, &g5, &g6,
|
||||
)
|
||||
if err != nil {
|
||||
// Non-fatal, continue without distribution
|
||||
g1, g2, g3, g4, g5, g6 = 0, 0, 0, 0, 0, 0
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"student_count": studentCount,
|
||||
"class_average": classAverage,
|
||||
"best_grade": bestGrade,
|
||||
"worst_grade": worstGrade,
|
||||
"total_grades": totalGrades,
|
||||
"distribution": map[string]int{
|
||||
"1": g1,
|
||||
"2": g2,
|
||||
"3": g3,
|
||||
"4": g4,
|
||||
"5": g5,
|
||||
"6": g6,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Grade Comments
|
||||
// ========================================
|
||||
|
||||
// AddGradeComment adds a comment to a grade
|
||||
func (s *GradeService) AddGradeComment(ctx context.Context, gradeID, teacherID uuid.UUID, comment string, isPrivate bool) (*models.GradeComment, error) {
|
||||
gradeComment := &models.GradeComment{
|
||||
ID: uuid.New(),
|
||||
GradeID: gradeID,
|
||||
TeacherID: teacherID,
|
||||
Comment: comment,
|
||||
IsPrivate: isPrivate,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO grade_comments (id, grade_id, teacher_id, comment, is_private, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
RETURNING id`
|
||||
|
||||
err := s.db.Pool.QueryRow(ctx, query,
|
||||
gradeComment.ID, gradeComment.GradeID, gradeComment.TeacherID,
|
||||
gradeComment.Comment, gradeComment.IsPrivate, gradeComment.CreatedAt,
|
||||
).Scan(&gradeComment.ID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add grade comment: %w", err)
|
||||
}
|
||||
|
||||
return gradeComment, nil
|
||||
}
|
||||
|
||||
// GetGradeComments gets comments for a grade
|
||||
func (s *GradeService) GetGradeComments(ctx context.Context, gradeID uuid.UUID, includePrivate bool) ([]models.GradeComment, error) {
|
||||
query := `
|
||||
SELECT id, grade_id, teacher_id, comment, is_private, created_at
|
||||
FROM grade_comments
|
||||
WHERE grade_id = $1`
|
||||
|
||||
if !includePrivate {
|
||||
query += ` AND is_private = false`
|
||||
}
|
||||
|
||||
query += ` ORDER BY created_at DESC`
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, query, gradeID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get grade comments: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var comments []models.GradeComment
|
||||
for rows.Next() {
|
||||
var comment models.GradeComment
|
||||
err := rows.Scan(
|
||||
&comment.ID, &comment.GradeID, &comment.TeacherID,
|
||||
&comment.Comment, &comment.IsPrivate, &comment.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan grade comment: %w", err)
|
||||
}
|
||||
comments = append(comments, comment)
|
||||
}
|
||||
|
||||
return comments, nil
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Parent Notifications
|
||||
// ========================================
|
||||
|
||||
func (s *GradeService) notifyParentsOfNewGrade(ctx context.Context, grade *models.Grade) {
|
||||
if s.matrix == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Get student info and Matrix room
|
||||
var studentFirstName, studentLastName, matrixDMRoom string
|
||||
err := s.db.Pool.QueryRow(ctx, `
|
||||
SELECT first_name, last_name, matrix_dm_room
|
||||
FROM students
|
||||
WHERE id = $1`, grade.StudentID).Scan(&studentFirstName, &studentLastName, &matrixDMRoom)
|
||||
if err != nil || matrixDMRoom == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Get subject name
|
||||
var subjectName string
|
||||
err = s.db.Pool.QueryRow(ctx, `SELECT name FROM subjects WHERE id = $1`, grade.SubjectID).Scan(&subjectName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
studentName := studentFirstName + " " + studentLastName
|
||||
gradeType := s.getGradeTypeDisplayName(grade.Type)
|
||||
|
||||
// Send Matrix notification
|
||||
err = s.matrix.SendGradeNotification(ctx, matrixDMRoom, studentName, subjectName, gradeType, grade.Value)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to send grade notification: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GradeService) getGradeTypeDisplayName(gradeType string) string {
|
||||
switch gradeType {
|
||||
case models.GradeTypeExam:
|
||||
return "Klassenarbeit"
|
||||
case models.GradeTypeTest:
|
||||
return "Test"
|
||||
case models.GradeTypeOral:
|
||||
return "Mündliche Note"
|
||||
case models.GradeTypeHomework:
|
||||
return "Hausaufgabe"
|
||||
case models.GradeTypeProject:
|
||||
return "Projekt"
|
||||
case models.GradeTypeParticipation:
|
||||
return "Mitarbeit"
|
||||
case models.GradeTypeSemester:
|
||||
return "Halbjahreszeugnis"
|
||||
case models.GradeTypeFinal:
|
||||
return "Zeugnisnote"
|
||||
default:
|
||||
return gradeType
|
||||
}
|
||||
}
|
||||
532
consent-service/internal/services/grade_service_test.go
Normal file
532
consent-service/internal/services/grade_service_test.go
Normal file
@@ -0,0 +1,532 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestValidateGrade tests grade validation
|
||||
func TestValidateGrade(t *testing.T) {
|
||||
schoolYearID := uuid.New()
|
||||
gradeScaleID := uuid.New()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
grade models.Grade
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "valid grade 1",
|
||||
grade: models.Grade{
|
||||
StudentID: uuid.New(),
|
||||
SubjectID: uuid.New(),
|
||||
TeacherID: uuid.New(),
|
||||
SchoolYearID: schoolYearID,
|
||||
GradeScaleID: gradeScaleID,
|
||||
Value: 1.0,
|
||||
Type: models.GradeTypeExam,
|
||||
Weight: 1.0,
|
||||
Date: time.Now(),
|
||||
Semester: 1,
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "valid grade 6",
|
||||
grade: models.Grade{
|
||||
StudentID: uuid.New(),
|
||||
SubjectID: uuid.New(),
|
||||
TeacherID: uuid.New(),
|
||||
SchoolYearID: schoolYearID,
|
||||
GradeScaleID: gradeScaleID,
|
||||
Value: 6.0,
|
||||
Type: models.GradeTypeOral,
|
||||
Weight: 0.5,
|
||||
Date: time.Now(),
|
||||
Semester: 2,
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "valid grade with plus (1.3)",
|
||||
grade: models.Grade{
|
||||
StudentID: uuid.New(),
|
||||
SubjectID: uuid.New(),
|
||||
TeacherID: uuid.New(),
|
||||
SchoolYearID: schoolYearID,
|
||||
GradeScaleID: gradeScaleID,
|
||||
Value: 1.3,
|
||||
Type: models.GradeTypeTest,
|
||||
Weight: 0.25,
|
||||
Date: time.Now(),
|
||||
Semester: 1,
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "invalid grade 0",
|
||||
grade: models.Grade{
|
||||
StudentID: uuid.New(),
|
||||
SubjectID: uuid.New(),
|
||||
TeacherID: uuid.New(),
|
||||
SchoolYearID: schoolYearID,
|
||||
GradeScaleID: gradeScaleID,
|
||||
Value: 0.0,
|
||||
Type: models.GradeTypeExam,
|
||||
Weight: 1.0,
|
||||
Date: time.Now(),
|
||||
Semester: 1,
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "invalid grade 7",
|
||||
grade: models.Grade{
|
||||
StudentID: uuid.New(),
|
||||
SubjectID: uuid.New(),
|
||||
TeacherID: uuid.New(),
|
||||
SchoolYearID: schoolYearID,
|
||||
GradeScaleID: gradeScaleID,
|
||||
Value: 7.0,
|
||||
Type: models.GradeTypeExam,
|
||||
Weight: 1.0,
|
||||
Date: time.Now(),
|
||||
Semester: 1,
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "missing student ID",
|
||||
grade: models.Grade{
|
||||
StudentID: uuid.Nil,
|
||||
SubjectID: uuid.New(),
|
||||
TeacherID: uuid.New(),
|
||||
SchoolYearID: schoolYearID,
|
||||
GradeScaleID: gradeScaleID,
|
||||
Value: 2.0,
|
||||
Type: models.GradeTypeExam,
|
||||
Weight: 1.0,
|
||||
Date: time.Now(),
|
||||
Semester: 1,
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "invalid weight negative",
|
||||
grade: models.Grade{
|
||||
StudentID: uuid.New(),
|
||||
SubjectID: uuid.New(),
|
||||
TeacherID: uuid.New(),
|
||||
SchoolYearID: schoolYearID,
|
||||
GradeScaleID: gradeScaleID,
|
||||
Value: 2.0,
|
||||
Type: models.GradeTypeExam,
|
||||
Weight: -0.5,
|
||||
Date: time.Now(),
|
||||
Semester: 1,
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "invalid semester 0",
|
||||
grade: models.Grade{
|
||||
StudentID: uuid.New(),
|
||||
SubjectID: uuid.New(),
|
||||
TeacherID: uuid.New(),
|
||||
SchoolYearID: schoolYearID,
|
||||
GradeScaleID: gradeScaleID,
|
||||
Value: 2.0,
|
||||
Type: models.GradeTypeExam,
|
||||
Weight: 1.0,
|
||||
Date: time.Now(),
|
||||
Semester: 0,
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "invalid semester 3",
|
||||
grade: models.Grade{
|
||||
StudentID: uuid.New(),
|
||||
SubjectID: uuid.New(),
|
||||
TeacherID: uuid.New(),
|
||||
SchoolYearID: schoolYearID,
|
||||
GradeScaleID: gradeScaleID,
|
||||
Value: 2.0,
|
||||
Type: models.GradeTypeExam,
|
||||
Weight: 1.0,
|
||||
Date: time.Now(),
|
||||
Semester: 3,
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "invalid type",
|
||||
grade: models.Grade{
|
||||
StudentID: uuid.New(),
|
||||
SubjectID: uuid.New(),
|
||||
TeacherID: uuid.New(),
|
||||
SchoolYearID: schoolYearID,
|
||||
GradeScaleID: gradeScaleID,
|
||||
Value: 2.0,
|
||||
Type: "invalid_type",
|
||||
Weight: 1.0,
|
||||
Date: time.Now(),
|
||||
Semester: 1,
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := validateGrade(tt.grade)
|
||||
if isValid != tt.expectValid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.expectValid, isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// validateGrade validates a grade
|
||||
func validateGrade(grade models.Grade) bool {
|
||||
if grade.StudentID == uuid.Nil {
|
||||
return false
|
||||
}
|
||||
if grade.SubjectID == uuid.Nil {
|
||||
return false
|
||||
}
|
||||
if grade.TeacherID == uuid.Nil {
|
||||
return false
|
||||
}
|
||||
// German grading scale: 1 (best) to 6 (worst)
|
||||
if grade.Value < 1.0 || grade.Value > 6.0 {
|
||||
return false
|
||||
}
|
||||
if grade.Weight < 0 {
|
||||
return false
|
||||
}
|
||||
if grade.Semester < 1 || grade.Semester > 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Validate type
|
||||
validTypes := map[string]bool{
|
||||
models.GradeTypeExam: true,
|
||||
models.GradeTypeTest: true,
|
||||
models.GradeTypeOral: true,
|
||||
models.GradeTypeHomework: true,
|
||||
models.GradeTypeProject: true,
|
||||
models.GradeTypeParticipation: true,
|
||||
models.GradeTypeSemester: true,
|
||||
models.GradeTypeFinal: true,
|
||||
}
|
||||
|
||||
if !validTypes[grade.Type] {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// TestCalculateWeightedAverage tests weighted average calculation
|
||||
func TestCalculateWeightedAverage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
grades []models.Grade
|
||||
expectedAverage float64
|
||||
}{
|
||||
{
|
||||
name: "simple average equal weights",
|
||||
grades: []models.Grade{
|
||||
{Value: 1.0, Weight: 1.0},
|
||||
{Value: 2.0, Weight: 1.0},
|
||||
{Value: 3.0, Weight: 1.0},
|
||||
},
|
||||
expectedAverage: 2.0,
|
||||
},
|
||||
{
|
||||
name: "weighted average different weights",
|
||||
grades: []models.Grade{
|
||||
{Value: 1.0, Weight: 2.0}, // Exam counts double
|
||||
{Value: 3.0, Weight: 1.0},
|
||||
},
|
||||
// (1*2 + 3*1) / (2+1) = 5/3 = 1.67
|
||||
expectedAverage: 1.67,
|
||||
},
|
||||
{
|
||||
name: "single grade",
|
||||
grades: []models.Grade{
|
||||
{Value: 2.5, Weight: 1.0},
|
||||
},
|
||||
expectedAverage: 2.5,
|
||||
},
|
||||
{
|
||||
name: "empty grades",
|
||||
grades: []models.Grade{},
|
||||
expectedAverage: 0.0,
|
||||
},
|
||||
{
|
||||
name: "all same grades",
|
||||
grades: []models.Grade{
|
||||
{Value: 2.0, Weight: 1.0},
|
||||
{Value: 2.0, Weight: 1.0},
|
||||
{Value: 2.0, Weight: 1.0},
|
||||
},
|
||||
expectedAverage: 2.0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
avg := calculateWeightedAverage(tt.grades)
|
||||
// Allow small floating point differences
|
||||
if avg < tt.expectedAverage-0.01 || avg > tt.expectedAverage+0.01 {
|
||||
t.Errorf("expected average=%.2f, got average=%.2f", tt.expectedAverage, avg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// calculateWeightedAverage calculates weighted average of grades
|
||||
func calculateWeightedAverage(grades []models.Grade) float64 {
|
||||
if len(grades) == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
var weightedSum float64
|
||||
var totalWeight float64
|
||||
|
||||
for _, g := range grades {
|
||||
weightedSum += g.Value * g.Weight
|
||||
totalWeight += g.Weight
|
||||
}
|
||||
|
||||
if totalWeight == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
avg := weightedSum / totalWeight
|
||||
// Round to 2 decimal places
|
||||
return float64(int(avg*100)) / 100
|
||||
}
|
||||
|
||||
// TestGradeDistribution tests grade distribution calculation
|
||||
func TestGradeDistribution(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
grades []models.Grade
|
||||
expectedDist map[int]int
|
||||
}{
|
||||
{
|
||||
name: "varied distribution",
|
||||
grades: []models.Grade{
|
||||
{Value: 1.0}, {Value: 1.3},
|
||||
{Value: 2.0}, {Value: 2.0}, {Value: 2.7},
|
||||
{Value: 3.0}, {Value: 3.0}, {Value: 3.0},
|
||||
{Value: 4.0}, {Value: 4.3},
|
||||
{Value: 5.0},
|
||||
},
|
||||
expectedDist: map[int]int{
|
||||
1: 2, // 1.0, 1.3 (rounded: 1, 1)
|
||||
2: 2, // 2.0, 2.0 (rounded: 2, 2)
|
||||
3: 4, // 2.7, 3.0, 3.0, 3.0 (rounded: 3, 3, 3, 3)
|
||||
4: 2, // 4.0, 4.3 (rounded: 4, 4)
|
||||
5: 1, // 5.0 (rounded: 5)
|
||||
6: 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty grades",
|
||||
grades: []models.Grade{},
|
||||
expectedDist: map[int]int{1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0},
|
||||
},
|
||||
{
|
||||
name: "all same grade",
|
||||
grades: []models.Grade{
|
||||
{Value: 2.0},
|
||||
{Value: 2.0},
|
||||
{Value: 2.0},
|
||||
},
|
||||
expectedDist: map[int]int{1: 0, 2: 3, 3: 0, 4: 0, 5: 0, 6: 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dist := calculateGradeDistribution(tt.grades)
|
||||
for grade, count := range tt.expectedDist {
|
||||
if dist[grade] != count {
|
||||
t.Errorf("grade %d: expected count=%d, got count=%d", grade, count, dist[grade])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// calculateGradeDistribution calculates how many grades fall into each category
|
||||
func calculateGradeDistribution(grades []models.Grade) map[int]int {
|
||||
dist := map[int]int{1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0}
|
||||
|
||||
for _, g := range grades {
|
||||
// Round to nearest integer for distribution
|
||||
roundedGrade := int(g.Value + 0.5)
|
||||
if roundedGrade < 1 {
|
||||
roundedGrade = 1
|
||||
}
|
||||
if roundedGrade > 6 {
|
||||
roundedGrade = 6
|
||||
}
|
||||
dist[roundedGrade]++
|
||||
}
|
||||
|
||||
return dist
|
||||
}
|
||||
|
||||
// TestGradePointConversion tests conversion between grades and points (Oberstufe)
|
||||
func TestGradePointConversion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
grade float64
|
||||
expectedPoints int
|
||||
}{
|
||||
{"grade 1.0 = 15 points", 1.0, 15},
|
||||
{"grade 1.3 = 14 points", 1.3, 14},
|
||||
{"grade 1.7 = 13 points", 1.7, 13},
|
||||
{"grade 2.0 = 12 points", 2.0, 12},
|
||||
{"grade 2.3 = 11 points", 2.3, 11},
|
||||
{"grade 2.7 = 10 points", 2.7, 10},
|
||||
{"grade 3.0 = 9 points", 3.0, 9},
|
||||
{"grade 3.3 = 8 points", 3.3, 8},
|
||||
{"grade 3.7 = 7 points", 3.7, 7},
|
||||
{"grade 4.0 = 6 points", 4.0, 6},
|
||||
{"grade 4.3 = 5 points", 4.3, 5},
|
||||
{"grade 4.7 = 4 points", 4.7, 4},
|
||||
{"grade 5.0 = 3 points", 5.0, 3},
|
||||
{"grade 5.3 = 2 points", 5.3, 2},
|
||||
{"grade 5.7 = 1 point", 5.7, 1},
|
||||
{"grade 6.0 = 0 points", 6.0, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
points := gradeToPoints(tt.grade)
|
||||
if points != tt.expectedPoints {
|
||||
t.Errorf("expected points=%d, got points=%d", tt.expectedPoints, points)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// gradeToPoints converts German grade (1-6) to Oberstufe points (0-15)
|
||||
func gradeToPoints(grade float64) int {
|
||||
// Mapping based on German school system
|
||||
if grade <= 1.0 {
|
||||
return 15
|
||||
} else if grade <= 1.3 {
|
||||
return 14
|
||||
} else if grade <= 1.7 {
|
||||
return 13
|
||||
} else if grade <= 2.0 {
|
||||
return 12
|
||||
} else if grade <= 2.3 {
|
||||
return 11
|
||||
} else if grade <= 2.7 {
|
||||
return 10
|
||||
} else if grade <= 3.0 {
|
||||
return 9
|
||||
} else if grade <= 3.3 {
|
||||
return 8
|
||||
} else if grade <= 3.7 {
|
||||
return 7
|
||||
} else if grade <= 4.0 {
|
||||
return 6
|
||||
} else if grade <= 4.3 {
|
||||
return 5
|
||||
} else if grade <= 4.7 {
|
||||
return 4
|
||||
} else if grade <= 5.0 {
|
||||
return 3
|
||||
} else if grade <= 5.3 {
|
||||
return 2
|
||||
} else if grade <= 5.7 {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// TestFindBestAndWorstGrade tests finding best and worst grades
|
||||
func TestFindBestAndWorstGrade(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
grades []models.Grade
|
||||
expectedBest float64
|
||||
expectedWorst float64
|
||||
}{
|
||||
{
|
||||
name: "varied grades",
|
||||
grades: []models.Grade{
|
||||
{Value: 2.0},
|
||||
{Value: 1.0},
|
||||
{Value: 3.0},
|
||||
{Value: 5.0},
|
||||
{Value: 2.0},
|
||||
},
|
||||
expectedBest: 1.0,
|
||||
expectedWorst: 5.0,
|
||||
},
|
||||
{
|
||||
name: "all same",
|
||||
grades: []models.Grade{
|
||||
{Value: 2.0},
|
||||
{Value: 2.0},
|
||||
},
|
||||
expectedBest: 2.0,
|
||||
expectedWorst: 2.0,
|
||||
},
|
||||
{
|
||||
name: "single grade",
|
||||
grades: []models.Grade{
|
||||
{Value: 3.0},
|
||||
},
|
||||
expectedBest: 3.0,
|
||||
expectedWorst: 3.0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
best, worst := findBestAndWorstGrade(tt.grades)
|
||||
if best != tt.expectedBest {
|
||||
t.Errorf("expected best=%.1f, got best=%.1f", tt.expectedBest, best)
|
||||
}
|
||||
if worst != tt.expectedWorst {
|
||||
t.Errorf("expected worst=%.1f, got worst=%.1f", tt.expectedWorst, worst)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// findBestAndWorstGrade finds the best (lowest) and worst (highest) grade
|
||||
func findBestAndWorstGrade(grades []models.Grade) (best, worst float64) {
|
||||
if len(grades) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
best = grades[0].Value
|
||||
worst = grades[0].Value
|
||||
|
||||
for _, g := range grades[1:] {
|
||||
if g.Value < best {
|
||||
best = g.Value
|
||||
}
|
||||
if g.Value > worst {
|
||||
worst = g.Value
|
||||
}
|
||||
}
|
||||
|
||||
return best, worst
|
||||
}
|
||||
340
consent-service/internal/services/jitsi/game_meetings.go
Normal file
340
consent-service/internal/services/jitsi/game_meetings.go
Normal file
@@ -0,0 +1,340 @@
|
||||
package jitsi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ========================================
|
||||
// Breakpilot Drive Game Meeting Types
|
||||
// ========================================
|
||||
|
||||
// GameMeetingMode represents different game video call modes
|
||||
type GameMeetingMode string
|
||||
|
||||
const (
|
||||
GameMeetingCoop GameMeetingMode = "coop" // Co-Op voice/video
|
||||
GameMeetingChallenge GameMeetingMode = "challenge" // 1v1 face-off
|
||||
GameMeetingClassRace GameMeetingMode = "class_race" // Teacher supervises
|
||||
GameMeetingTeamHuddle GameMeetingMode = "team_huddle" // Quick team sync
|
||||
)
|
||||
|
||||
// GameMeetingConfig holds configuration for game video meetings
|
||||
type GameMeetingConfig struct {
|
||||
SessionID string `json:"session_id"`
|
||||
Mode GameMeetingMode `json:"mode"`
|
||||
HostID string `json:"host_id"`
|
||||
HostName string `json:"host_name"`
|
||||
Players []GamePlayer `json:"players"`
|
||||
EnableVideo bool `json:"enable_video"`
|
||||
EnableVoice bool `json:"enable_voice"`
|
||||
TeacherID string `json:"teacher_id,omitempty"`
|
||||
TeacherName string `json:"teacher_name,omitempty"`
|
||||
ClassName string `json:"class_name,omitempty"`
|
||||
}
|
||||
|
||||
// GamePlayer represents a player in the meeting
|
||||
type GamePlayer struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
IsModerator bool `json:"is_moderator,omitempty"`
|
||||
}
|
||||
|
||||
// GameMeetingLink extends MeetingLink with game-specific info
|
||||
type GameMeetingLink struct {
|
||||
*MeetingLink
|
||||
SessionID string `json:"session_id"`
|
||||
Mode GameMeetingMode `json:"mode"`
|
||||
Players []string `json:"players"`
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Game Meeting Creation
|
||||
// ========================================
|
||||
|
||||
// CreateCoopMeeting creates a video call for Co-Op gameplay (2-4 players)
|
||||
func (s *JitsiService) CreateCoopMeeting(ctx context.Context, config GameMeetingConfig) (*GameMeetingLink, error) {
|
||||
roomName := fmt.Sprintf("bp-coop-%s", config.SessionID[:8])
|
||||
|
||||
meeting := Meeting{
|
||||
RoomName: roomName,
|
||||
DisplayName: config.HostName,
|
||||
Subject: "Breakpilot Drive - Co-Op Session",
|
||||
Moderator: true,
|
||||
Config: &MeetingConfig{
|
||||
StartWithAudioMuted: !config.EnableVoice,
|
||||
StartWithVideoMuted: !config.EnableVideo,
|
||||
RequireDisplayName: true,
|
||||
EnableLobby: false, // Direct join for co-op
|
||||
DisableDeepLinking: true,
|
||||
},
|
||||
}
|
||||
|
||||
link, err := s.CreateMeetingLink(ctx, meeting)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create co-op meeting: %w", err)
|
||||
}
|
||||
|
||||
playerIDs := make([]string, len(config.Players))
|
||||
for i, p := range config.Players {
|
||||
playerIDs[i] = p.ID
|
||||
}
|
||||
|
||||
return &GameMeetingLink{
|
||||
MeetingLink: link,
|
||||
SessionID: config.SessionID,
|
||||
Mode: GameMeetingCoop,
|
||||
Players: playerIDs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateChallengeMeeting creates a 1v1 video call for challenges
|
||||
func (s *JitsiService) CreateChallengeMeeting(ctx context.Context, config GameMeetingConfig, challengerName string, opponentName string) (*GameMeetingLink, error) {
|
||||
roomName := fmt.Sprintf("bp-challenge-%s", config.SessionID[:8])
|
||||
|
||||
meeting := Meeting{
|
||||
RoomName: roomName,
|
||||
DisplayName: challengerName,
|
||||
Subject: fmt.Sprintf("Challenge: %s vs %s", challengerName, opponentName),
|
||||
Moderator: false, // Both players are equal
|
||||
Config: &MeetingConfig{
|
||||
StartWithAudioMuted: false, // Voice enabled for trash talk
|
||||
StartWithVideoMuted: !config.EnableVideo,
|
||||
RequireDisplayName: true,
|
||||
EnableLobby: false,
|
||||
DisableDeepLinking: true,
|
||||
},
|
||||
}
|
||||
|
||||
link, err := s.CreateMeetingLink(ctx, meeting)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create challenge meeting: %w", err)
|
||||
}
|
||||
|
||||
return &GameMeetingLink{
|
||||
MeetingLink: link,
|
||||
SessionID: config.SessionID,
|
||||
Mode: GameMeetingChallenge,
|
||||
Players: []string{config.HostID},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateClassRaceMeeting creates a video call for teacher-supervised class races
|
||||
func (s *JitsiService) CreateClassRaceMeeting(ctx context.Context, config GameMeetingConfig) (*GameMeetingLink, error) {
|
||||
roomName := fmt.Sprintf("bp-klasse-%s-%s",
|
||||
s.sanitizeRoomName(config.ClassName),
|
||||
time.Now().Format("150405"))
|
||||
|
||||
// Teacher is moderator
|
||||
meeting := Meeting{
|
||||
RoomName: roomName,
|
||||
DisplayName: config.TeacherName,
|
||||
Subject: fmt.Sprintf("Klassenrennen: %s", config.ClassName),
|
||||
Moderator: true,
|
||||
Config: &MeetingConfig{
|
||||
StartWithAudioMuted: true, // Students muted by default
|
||||
StartWithVideoMuted: true, // Video off for performance
|
||||
RequireDisplayName: true,
|
||||
EnableLobby: true, // Teacher admits students
|
||||
EnableRecording: false, // No recording for minors
|
||||
DisableDeepLinking: true,
|
||||
},
|
||||
Features: &MeetingFeatures{
|
||||
Recording: false,
|
||||
Transcription: false,
|
||||
},
|
||||
}
|
||||
|
||||
link, err := s.CreateMeetingLink(ctx, meeting)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create class race meeting: %w", err)
|
||||
}
|
||||
|
||||
playerIDs := make([]string, len(config.Players))
|
||||
for i, p := range config.Players {
|
||||
playerIDs[i] = p.ID
|
||||
}
|
||||
|
||||
return &GameMeetingLink{
|
||||
MeetingLink: link,
|
||||
SessionID: config.SessionID,
|
||||
Mode: GameMeetingClassRace,
|
||||
Players: playerIDs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateTeamHuddleMeeting creates a quick sync meeting for teams
|
||||
func (s *JitsiService) CreateTeamHuddleMeeting(ctx context.Context, config GameMeetingConfig, teamName string) (*GameMeetingLink, error) {
|
||||
roomName := fmt.Sprintf("bp-team-%s-%s",
|
||||
s.sanitizeRoomName(teamName),
|
||||
config.SessionID[:8])
|
||||
|
||||
meeting := Meeting{
|
||||
RoomName: roomName,
|
||||
DisplayName: config.HostName,
|
||||
Subject: fmt.Sprintf("Team %s - Huddle", teamName),
|
||||
Duration: 5, // Short 5-minute huddles
|
||||
Moderator: true,
|
||||
Config: &MeetingConfig{
|
||||
StartWithAudioMuted: false, // Voice on for quick sync
|
||||
StartWithVideoMuted: true, // Video optional
|
||||
RequireDisplayName: true,
|
||||
EnableLobby: false,
|
||||
DisableDeepLinking: true,
|
||||
},
|
||||
}
|
||||
|
||||
link, err := s.CreateMeetingLink(ctx, meeting)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create team huddle: %w", err)
|
||||
}
|
||||
|
||||
playerIDs := make([]string, len(config.Players))
|
||||
for i, p := range config.Players {
|
||||
playerIDs[i] = p.ID
|
||||
}
|
||||
|
||||
return &GameMeetingLink{
|
||||
MeetingLink: link,
|
||||
SessionID: config.SessionID,
|
||||
Mode: GameMeetingTeamHuddle,
|
||||
Players: playerIDs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Game-Specific Meeting Configurations
|
||||
// ========================================
|
||||
|
||||
// GetGameEmbedConfig returns optimized config for embedding in Unity WebGL
|
||||
func (s *JitsiService) GetGameEmbedConfig(enableVideo bool, enableVoice bool) *MeetingConfig {
|
||||
return &MeetingConfig{
|
||||
StartWithAudioMuted: !enableVoice,
|
||||
StartWithVideoMuted: !enableVideo,
|
||||
RequireDisplayName: true,
|
||||
EnableLobby: false,
|
||||
DisableDeepLinking: true, // Important for iframe embedding
|
||||
}
|
||||
}
|
||||
|
||||
// BuildGameEmbedURL creates a URL optimized for Unity WebGL embedding
|
||||
func (s *JitsiService) BuildGameEmbedURL(roomName string, playerName string, enableVideo bool, enableVoice bool) string {
|
||||
config := s.GetGameEmbedConfig(enableVideo, enableVoice)
|
||||
return s.BuildEmbedURL(roomName, playerName, config)
|
||||
}
|
||||
|
||||
// BuildUnityIFrameParams returns parameters for Unity's WebGL iframe
|
||||
func (s *JitsiService) BuildUnityIFrameParams(link *GameMeetingLink, playerName string) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"domain": s.extractDomain(),
|
||||
"roomName": link.RoomName,
|
||||
"displayName": playerName,
|
||||
"jwt": link.JWT,
|
||||
"configOverwrite": map[string]interface{}{
|
||||
"startWithAudioMuted": false,
|
||||
"startWithVideoMuted": true,
|
||||
"disableDeepLinking": true,
|
||||
"prejoinPageEnabled": false,
|
||||
"enableWelcomePage": false,
|
||||
"enableClosePage": false,
|
||||
"disableInviteFunctions": true,
|
||||
},
|
||||
"interfaceConfigOverwrite": map[string]interface{}{
|
||||
"DISABLE_JOIN_LEAVE_NOTIFICATIONS": true,
|
||||
"MOBILE_APP_PROMO": false,
|
||||
"SHOW_CHROME_EXTENSION_BANNER": false,
|
||||
"TOOLBAR_BUTTONS": []string{
|
||||
"microphone", "camera", "hangup", "chat",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Spectator Mode (for teachers/parents)
|
||||
// ========================================
|
||||
|
||||
// CreateSpectatorLink creates a view-only link for observers
|
||||
func (s *JitsiService) CreateSpectatorLink(ctx context.Context, roomName string, spectatorName string) (*MeetingLink, error) {
|
||||
meeting := Meeting{
|
||||
RoomName: roomName,
|
||||
DisplayName: fmt.Sprintf("[Zuschauer] %s", spectatorName),
|
||||
Moderator: false,
|
||||
Config: &MeetingConfig{
|
||||
StartWithAudioMuted: true,
|
||||
StartWithVideoMuted: true,
|
||||
DisableDeepLinking: true,
|
||||
},
|
||||
}
|
||||
|
||||
return s.CreateMeetingLink(ctx, meeting)
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Helper Functions
|
||||
// ========================================
|
||||
|
||||
// extractDomain extracts the domain from baseURL
|
||||
func (s *JitsiService) extractDomain() string {
|
||||
// Remove protocol prefix
|
||||
domain := s.baseURL
|
||||
if len(domain) > 8 && domain[:8] == "https://" {
|
||||
domain = domain[8:]
|
||||
} else if len(domain) > 7 && domain[:7] == "http://" {
|
||||
domain = domain[7:]
|
||||
}
|
||||
// Remove port if present
|
||||
for i, c := range domain {
|
||||
if c == ':' || c == '/' {
|
||||
domain = domain[:i]
|
||||
break
|
||||
}
|
||||
}
|
||||
return domain
|
||||
}
|
||||
|
||||
// ValidateGameMeetingConfig validates configuration before creating meeting
|
||||
func ValidateGameMeetingConfig(config GameMeetingConfig) error {
|
||||
if config.SessionID == "" {
|
||||
return fmt.Errorf("session_id is required")
|
||||
}
|
||||
|
||||
if config.Mode == "" {
|
||||
return fmt.Errorf("mode is required")
|
||||
}
|
||||
|
||||
if config.HostID == "" {
|
||||
return fmt.Errorf("host_id is required")
|
||||
}
|
||||
|
||||
if config.HostName == "" {
|
||||
return fmt.Errorf("host_name is required")
|
||||
}
|
||||
|
||||
switch config.Mode {
|
||||
case GameMeetingCoop:
|
||||
if len(config.Players) < 2 || len(config.Players) > 4 {
|
||||
return fmt.Errorf("co-op mode requires 2-4 players")
|
||||
}
|
||||
case GameMeetingChallenge:
|
||||
if len(config.Players) != 2 {
|
||||
return fmt.Errorf("challenge mode requires exactly 2 players")
|
||||
}
|
||||
case GameMeetingClassRace:
|
||||
if config.TeacherID == "" || config.TeacherName == "" {
|
||||
return fmt.Errorf("class race mode requires teacher info")
|
||||
}
|
||||
if config.ClassName == "" {
|
||||
return fmt.Errorf("class race mode requires class name")
|
||||
}
|
||||
case GameMeetingTeamHuddle:
|
||||
if len(config.Players) < 2 {
|
||||
return fmt.Errorf("team huddle requires at least 2 players")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown game meeting mode: %s", config.Mode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
566
consent-service/internal/services/jitsi/jitsi_service.go
Normal file
566
consent-service/internal/services/jitsi/jitsi_service.go
Normal file
@@ -0,0 +1,566 @@
|
||||
package jitsi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// JitsiService handles Jitsi Meet integration for video conferences
|
||||
type JitsiService struct {
|
||||
baseURL string
|
||||
appID string
|
||||
appSecret string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// Config holds Jitsi service configuration
|
||||
type Config struct {
|
||||
BaseURL string // e.g., "http://localhost:8443"
|
||||
AppID string // Application ID for JWT (optional)
|
||||
AppSecret string // Secret for JWT signing (optional)
|
||||
}
|
||||
|
||||
// NewJitsiService creates a new Jitsi service instance
|
||||
func NewJitsiService(cfg Config) *JitsiService {
|
||||
return &JitsiService{
|
||||
baseURL: strings.TrimSuffix(cfg.BaseURL, "/"),
|
||||
appID: cfg.AppID,
|
||||
appSecret: cfg.AppSecret,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Types
|
||||
// ========================================
|
||||
|
||||
// Meeting represents a Jitsi meeting configuration
|
||||
type Meeting struct {
|
||||
RoomName string `json:"room_name"`
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Avatar string `json:"avatar,omitempty"`
|
||||
Subject string `json:"subject,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
StartTime *time.Time `json:"start_time,omitempty"`
|
||||
Duration int `json:"duration,omitempty"` // in minutes
|
||||
Config *MeetingConfig `json:"config,omitempty"`
|
||||
Moderator bool `json:"moderator,omitempty"`
|
||||
Features *MeetingFeatures `json:"features,omitempty"`
|
||||
}
|
||||
|
||||
// MeetingConfig holds Jitsi room configuration options
|
||||
type MeetingConfig struct {
|
||||
StartWithAudioMuted bool `json:"start_with_audio_muted,omitempty"`
|
||||
StartWithVideoMuted bool `json:"start_with_video_muted,omitempty"`
|
||||
DisableDeepLinking bool `json:"disable_deep_linking,omitempty"`
|
||||
RequireDisplayName bool `json:"require_display_name,omitempty"`
|
||||
EnableLobby bool `json:"enable_lobby,omitempty"`
|
||||
EnableRecording bool `json:"enable_recording,omitempty"`
|
||||
}
|
||||
|
||||
// MeetingFeatures controls which features are enabled
|
||||
type MeetingFeatures struct {
|
||||
Livestreaming bool `json:"livestreaming,omitempty"`
|
||||
Recording bool `json:"recording,omitempty"`
|
||||
Transcription bool `json:"transcription,omitempty"`
|
||||
OutboundCall bool `json:"outbound_call,omitempty"`
|
||||
}
|
||||
|
||||
// MeetingLink contains the generated meeting URL and metadata
|
||||
type MeetingLink struct {
|
||||
URL string `json:"url"`
|
||||
RoomName string `json:"room_name"`
|
||||
JoinURL string `json:"join_url"`
|
||||
ModeratorURL string `json:"moderator_url,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
JWT string `json:"jwt,omitempty"`
|
||||
}
|
||||
|
||||
// JWTClaims represents the JWT payload for Jitsi
|
||||
type JWTClaims struct {
|
||||
Audience string `json:"aud,omitempty"`
|
||||
Issuer string `json:"iss,omitempty"`
|
||||
Subject string `json:"sub,omitempty"`
|
||||
Room string `json:"room,omitempty"`
|
||||
ExpiresAt int64 `json:"exp,omitempty"`
|
||||
NotBefore int64 `json:"nbf,omitempty"`
|
||||
Context *JWTContext `json:"context,omitempty"`
|
||||
Moderator bool `json:"moderator,omitempty"`
|
||||
Features *JWTFeatures `json:"features,omitempty"`
|
||||
}
|
||||
|
||||
// JWTContext contains user information for JWT
|
||||
type JWTContext struct {
|
||||
User *JWTUser `json:"user,omitempty"`
|
||||
Group string `json:"group,omitempty"`
|
||||
Callee *JWTCallee `json:"callee,omitempty"`
|
||||
}
|
||||
|
||||
// JWTUser represents user info in JWT
|
||||
type JWTUser struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Avatar string `json:"avatar,omitempty"`
|
||||
Moderator bool `json:"moderator,omitempty"`
|
||||
HiddenFromRecorder bool `json:"hidden-from-recorder,omitempty"`
|
||||
}
|
||||
|
||||
// JWTCallee represents callee info (for 1:1 calls)
|
||||
type JWTCallee struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Avatar string `json:"avatar,omitempty"`
|
||||
}
|
||||
|
||||
// JWTFeatures controls JWT-based feature access
|
||||
type JWTFeatures struct {
|
||||
Livestreaming string `json:"livestreaming,omitempty"` // "true" or "false"
|
||||
Recording string `json:"recording,omitempty"`
|
||||
Transcription string `json:"transcription,omitempty"`
|
||||
OutboundCall string `json:"outbound-call,omitempty"`
|
||||
}
|
||||
|
||||
// ScheduledMeeting represents a scheduled training/meeting
|
||||
type ScheduledMeeting struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description,omitempty"`
|
||||
RoomName string `json:"room_name"`
|
||||
HostID string `json:"host_id"`
|
||||
HostName string `json:"host_name"`
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
Duration int `json:"duration"` // in minutes
|
||||
Password string `json:"password,omitempty"`
|
||||
MaxParticipants int `json:"max_participants,omitempty"`
|
||||
Features *MeetingFeatures `json:"features,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Meeting Management
|
||||
// ========================================
|
||||
|
||||
// CreateMeetingLink generates a meeting URL with optional JWT authentication
|
||||
func (s *JitsiService) CreateMeetingLink(ctx context.Context, meeting Meeting) (*MeetingLink, error) {
|
||||
// Generate room name if not provided
|
||||
roomName := meeting.RoomName
|
||||
if roomName == "" {
|
||||
roomName = s.generateRoomName()
|
||||
}
|
||||
|
||||
// Sanitize room name (Jitsi-compatible)
|
||||
roomName = s.sanitizeRoomName(roomName)
|
||||
|
||||
link := &MeetingLink{
|
||||
RoomName: roomName,
|
||||
URL: fmt.Sprintf("%s/%s", s.baseURL, roomName),
|
||||
JoinURL: fmt.Sprintf("%s/%s", s.baseURL, roomName),
|
||||
Password: meeting.Password,
|
||||
}
|
||||
|
||||
// Generate JWT if authentication is configured
|
||||
if s.appSecret != "" {
|
||||
jwt, expiresAt, err := s.generateJWT(meeting, roomName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate JWT: %w", err)
|
||||
}
|
||||
link.JWT = jwt
|
||||
link.ExpiresAt = expiresAt
|
||||
link.JoinURL = fmt.Sprintf("%s/%s?jwt=%s", s.baseURL, roomName, jwt)
|
||||
|
||||
// Generate moderator URL if user is moderator
|
||||
if meeting.Moderator {
|
||||
link.ModeratorURL = link.JoinURL
|
||||
}
|
||||
}
|
||||
|
||||
// Add config parameters to URL
|
||||
if meeting.Config != nil {
|
||||
params := s.buildConfigParams(meeting.Config)
|
||||
if params != "" {
|
||||
separator := "?"
|
||||
if strings.Contains(link.JoinURL, "?") {
|
||||
separator = "&"
|
||||
}
|
||||
link.JoinURL += separator + params
|
||||
}
|
||||
}
|
||||
|
||||
return link, nil
|
||||
}
|
||||
|
||||
// CreateTrainingSession creates a meeting link optimized for training sessions
|
||||
func (s *JitsiService) CreateTrainingSession(ctx context.Context, title string, hostName string, hostEmail string, duration int) (*MeetingLink, error) {
|
||||
meeting := Meeting{
|
||||
RoomName: s.generateTrainingRoomName(title),
|
||||
DisplayName: hostName,
|
||||
Email: hostEmail,
|
||||
Subject: title,
|
||||
Duration: duration,
|
||||
Moderator: true,
|
||||
Config: &MeetingConfig{
|
||||
StartWithAudioMuted: true, // Participants start muted
|
||||
StartWithVideoMuted: false, // Video on for training
|
||||
RequireDisplayName: true, // Know who's attending
|
||||
EnableLobby: true, // Waiting room
|
||||
EnableRecording: true, // Allow recording
|
||||
},
|
||||
Features: &MeetingFeatures{
|
||||
Recording: true,
|
||||
Transcription: false,
|
||||
},
|
||||
}
|
||||
|
||||
return s.CreateMeetingLink(ctx, meeting)
|
||||
}
|
||||
|
||||
// CreateQuickMeeting creates a simple ad-hoc meeting
|
||||
func (s *JitsiService) CreateQuickMeeting(ctx context.Context, displayName string) (*MeetingLink, error) {
|
||||
meeting := Meeting{
|
||||
DisplayName: displayName,
|
||||
Config: &MeetingConfig{
|
||||
StartWithAudioMuted: false,
|
||||
StartWithVideoMuted: false,
|
||||
},
|
||||
}
|
||||
|
||||
return s.CreateMeetingLink(ctx, meeting)
|
||||
}
|
||||
|
||||
// CreateParentTeacherMeeting creates a meeting for parent-teacher conferences
|
||||
func (s *JitsiService) CreateParentTeacherMeeting(ctx context.Context, teacherName string, parentName string, studentName string, scheduledTime time.Time) (*MeetingLink, error) {
|
||||
roomName := fmt.Sprintf("elterngespraech-%s-%s",
|
||||
s.sanitizeRoomName(studentName),
|
||||
scheduledTime.Format("20060102-1504"))
|
||||
|
||||
meeting := Meeting{
|
||||
RoomName: roomName,
|
||||
DisplayName: teacherName,
|
||||
Subject: fmt.Sprintf("Elterngespräch: %s", studentName),
|
||||
StartTime: &scheduledTime,
|
||||
Duration: 30, // 30 minutes default
|
||||
Moderator: true,
|
||||
Password: s.generatePassword(),
|
||||
Config: &MeetingConfig{
|
||||
StartWithAudioMuted: false,
|
||||
StartWithVideoMuted: false,
|
||||
RequireDisplayName: true,
|
||||
EnableLobby: true, // Teacher admits parent
|
||||
DisableDeepLinking: true,
|
||||
},
|
||||
}
|
||||
|
||||
return s.CreateMeetingLink(ctx, meeting)
|
||||
}
|
||||
|
||||
// CreateClassMeeting creates a meeting for an entire class
|
||||
func (s *JitsiService) CreateClassMeeting(ctx context.Context, className string, teacherName string, subject string) (*MeetingLink, error) {
|
||||
roomName := fmt.Sprintf("klasse-%s-%s",
|
||||
s.sanitizeRoomName(className),
|
||||
time.Now().Format("20060102"))
|
||||
|
||||
meeting := Meeting{
|
||||
RoomName: roomName,
|
||||
DisplayName: teacherName,
|
||||
Subject: fmt.Sprintf("%s - %s", className, subject),
|
||||
Moderator: true,
|
||||
Config: &MeetingConfig{
|
||||
StartWithAudioMuted: true, // Students muted by default
|
||||
StartWithVideoMuted: false,
|
||||
RequireDisplayName: true,
|
||||
EnableLobby: false, // Direct join for classes
|
||||
},
|
||||
}
|
||||
|
||||
return s.CreateMeetingLink(ctx, meeting)
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// JWT Generation
|
||||
// ========================================
|
||||
|
||||
// generateJWT creates a signed JWT for Jitsi authentication
|
||||
func (s *JitsiService) generateJWT(meeting Meeting, roomName string) (string, *time.Time, error) {
|
||||
if s.appSecret == "" {
|
||||
return "", nil, fmt.Errorf("app secret not configured")
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Default expiration: 24 hours or based on meeting duration
|
||||
expiration := now.Add(24 * time.Hour)
|
||||
if meeting.Duration > 0 {
|
||||
expiration = now.Add(time.Duration(meeting.Duration+30) * time.Minute)
|
||||
}
|
||||
if meeting.StartTime != nil {
|
||||
expiration = meeting.StartTime.Add(time.Duration(meeting.Duration+60) * time.Minute)
|
||||
}
|
||||
|
||||
claims := JWTClaims{
|
||||
Audience: "jitsi",
|
||||
Issuer: s.appID,
|
||||
Subject: "meet.jitsi",
|
||||
Room: roomName,
|
||||
ExpiresAt: expiration.Unix(),
|
||||
NotBefore: now.Add(-5 * time.Minute).Unix(), // 5 min grace period
|
||||
Moderator: meeting.Moderator,
|
||||
Context: &JWTContext{
|
||||
User: &JWTUser{
|
||||
ID: uuid.New().String(),
|
||||
Name: meeting.DisplayName,
|
||||
Email: meeting.Email,
|
||||
Avatar: meeting.Avatar,
|
||||
Moderator: meeting.Moderator,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Add features if specified
|
||||
if meeting.Features != nil {
|
||||
claims.Features = &JWTFeatures{
|
||||
Recording: boolToString(meeting.Features.Recording),
|
||||
Livestreaming: boolToString(meeting.Features.Livestreaming),
|
||||
Transcription: boolToString(meeting.Features.Transcription),
|
||||
OutboundCall: boolToString(meeting.Features.OutboundCall),
|
||||
}
|
||||
}
|
||||
|
||||
// Create JWT
|
||||
token, err := s.signJWT(claims)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return token, &expiration, nil
|
||||
}
|
||||
|
||||
// signJWT creates and signs a JWT token
|
||||
func (s *JitsiService) signJWT(claims JWTClaims) (string, error) {
|
||||
// Header
|
||||
header := map[string]string{
|
||||
"alg": "HS256",
|
||||
"typ": "JWT",
|
||||
}
|
||||
|
||||
headerJSON, err := json.Marshal(header)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Payload
|
||||
payloadJSON, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Encode
|
||||
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
|
||||
payloadB64 := base64.RawURLEncoding.EncodeToString(payloadJSON)
|
||||
|
||||
// Sign
|
||||
message := headerB64 + "." + payloadB64
|
||||
h := hmac.New(sha256.New, []byte(s.appSecret))
|
||||
h.Write([]byte(message))
|
||||
signature := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||
|
||||
return message + "." + signature, nil
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Health Check
|
||||
// ========================================
|
||||
|
||||
// HealthCheck verifies the Jitsi server is accessible
|
||||
func (s *JitsiService) HealthCheck(ctx context.Context) error {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", s.baseURL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("jitsi server unreachable: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 500 {
|
||||
return fmt.Errorf("jitsi server error: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetServerInfo returns information about the Jitsi server
|
||||
func (s *JitsiService) GetServerInfo() map[string]string {
|
||||
return map[string]string{
|
||||
"base_url": s.baseURL,
|
||||
"app_id": s.appID,
|
||||
"auth_enabled": boolToString(s.appSecret != ""),
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// URL Building
|
||||
// ========================================
|
||||
|
||||
// BuildEmbedURL creates an embeddable iframe URL
|
||||
func (s *JitsiService) BuildEmbedURL(roomName string, displayName string, config *MeetingConfig) string {
|
||||
params := url.Values{}
|
||||
|
||||
if displayName != "" {
|
||||
params.Set("userInfo.displayName", displayName)
|
||||
}
|
||||
|
||||
if config != nil {
|
||||
if config.StartWithAudioMuted {
|
||||
params.Set("config.startWithAudioMuted", "true")
|
||||
}
|
||||
if config.StartWithVideoMuted {
|
||||
params.Set("config.startWithVideoMuted", "true")
|
||||
}
|
||||
if config.DisableDeepLinking {
|
||||
params.Set("config.disableDeepLinking", "true")
|
||||
}
|
||||
}
|
||||
|
||||
embedURL := fmt.Sprintf("%s/%s", s.baseURL, s.sanitizeRoomName(roomName))
|
||||
if len(params) > 0 {
|
||||
embedURL += "#" + params.Encode()
|
||||
}
|
||||
|
||||
return embedURL
|
||||
}
|
||||
|
||||
// BuildIFrameCode generates HTML iframe code for embedding
|
||||
func (s *JitsiService) BuildIFrameCode(roomName string, width int, height int) string {
|
||||
if width == 0 {
|
||||
width = 800
|
||||
}
|
||||
if height == 0 {
|
||||
height = 600
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
`<iframe src="%s/%s" width="%d" height="%d" allow="camera; microphone; fullscreen; display-capture; autoplay" style="border: 0;"></iframe>`,
|
||||
s.baseURL,
|
||||
s.sanitizeRoomName(roomName),
|
||||
width,
|
||||
height,
|
||||
)
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Helper Functions
|
||||
// ========================================
|
||||
|
||||
// generateRoomName creates a unique room name
|
||||
func (s *JitsiService) generateRoomName() string {
|
||||
return fmt.Sprintf("breakpilot-%s", uuid.New().String()[:8])
|
||||
}
|
||||
|
||||
// generateTrainingRoomName creates a room name for training sessions
|
||||
func (s *JitsiService) generateTrainingRoomName(title string) string {
|
||||
sanitized := s.sanitizeRoomName(title)
|
||||
if sanitized == "" {
|
||||
sanitized = "schulung"
|
||||
}
|
||||
return fmt.Sprintf("%s-%s", sanitized, time.Now().Format("20060102-1504"))
|
||||
}
|
||||
|
||||
// sanitizeRoomName removes invalid characters from room names
|
||||
func (s *JitsiService) sanitizeRoomName(name string) string {
|
||||
// Replace spaces and special characters
|
||||
result := strings.ToLower(name)
|
||||
result = strings.ReplaceAll(result, " ", "-")
|
||||
result = strings.ReplaceAll(result, "ä", "ae")
|
||||
result = strings.ReplaceAll(result, "ö", "oe")
|
||||
result = strings.ReplaceAll(result, "ü", "ue")
|
||||
result = strings.ReplaceAll(result, "ß", "ss")
|
||||
|
||||
// Remove any remaining non-alphanumeric characters except hyphen
|
||||
var cleaned strings.Builder
|
||||
for _, r := range result {
|
||||
if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' {
|
||||
cleaned.WriteRune(r)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove consecutive hyphens
|
||||
result = cleaned.String()
|
||||
for strings.Contains(result, "--") {
|
||||
result = strings.ReplaceAll(result, "--", "-")
|
||||
}
|
||||
|
||||
// Trim hyphens from start and end
|
||||
result = strings.Trim(result, "-")
|
||||
|
||||
// Limit length
|
||||
if len(result) > 50 {
|
||||
result = result[:50]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// generatePassword creates a random meeting password
|
||||
func (s *JitsiService) generatePassword() string {
|
||||
return uuid.New().String()[:8]
|
||||
}
|
||||
|
||||
// buildConfigParams creates URL parameters from config
|
||||
func (s *JitsiService) buildConfigParams(config *MeetingConfig) string {
|
||||
params := url.Values{}
|
||||
|
||||
if config.StartWithAudioMuted {
|
||||
params.Set("config.startWithAudioMuted", "true")
|
||||
}
|
||||
if config.StartWithVideoMuted {
|
||||
params.Set("config.startWithVideoMuted", "true")
|
||||
}
|
||||
if config.DisableDeepLinking {
|
||||
params.Set("config.disableDeepLinking", "true")
|
||||
}
|
||||
if config.RequireDisplayName {
|
||||
params.Set("config.requireDisplayName", "true")
|
||||
}
|
||||
if config.EnableLobby {
|
||||
params.Set("config.enableLobby", "true")
|
||||
}
|
||||
|
||||
return params.Encode()
|
||||
}
|
||||
|
||||
// boolToString converts bool to "true"/"false" string
|
||||
func boolToString(b bool) string {
|
||||
if b {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
}
|
||||
|
||||
// GetBaseURL returns the configured base URL
|
||||
func (s *JitsiService) GetBaseURL() string {
|
||||
return s.baseURL
|
||||
}
|
||||
|
||||
// IsAuthEnabled returns whether JWT authentication is configured
|
||||
func (s *JitsiService) IsAuthEnabled() bool {
|
||||
return s.appSecret != ""
|
||||
}
|
||||
687
consent-service/internal/services/jitsi/jitsi_service_test.go
Normal file
687
consent-service/internal/services/jitsi/jitsi_service_test.go
Normal file
@@ -0,0 +1,687 @@
|
||||
package jitsi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ========================================
|
||||
// Test Helpers
|
||||
// ========================================
|
||||
|
||||
func createTestService() *JitsiService {
|
||||
return NewJitsiService(Config{
|
||||
BaseURL: "http://localhost:8443",
|
||||
AppID: "breakpilot",
|
||||
AppSecret: "test-secret-key",
|
||||
})
|
||||
}
|
||||
|
||||
func createTestServiceWithoutAuth() *JitsiService {
|
||||
return NewJitsiService(Config{
|
||||
BaseURL: "http://localhost:8443",
|
||||
})
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Unit Tests: Service Creation
|
||||
// ========================================
|
||||
|
||||
func TestNewJitsiService_ValidConfig_CreatesService(t *testing.T) {
|
||||
cfg := Config{
|
||||
BaseURL: "http://localhost:8443",
|
||||
AppID: "test-app",
|
||||
AppSecret: "test-secret",
|
||||
}
|
||||
|
||||
service := NewJitsiService(cfg)
|
||||
|
||||
if service == nil {
|
||||
t.Fatal("Expected service to be created, got nil")
|
||||
}
|
||||
if service.baseURL != cfg.BaseURL {
|
||||
t.Errorf("Expected baseURL %s, got %s", cfg.BaseURL, service.baseURL)
|
||||
}
|
||||
if service.appID != cfg.AppID {
|
||||
t.Errorf("Expected appID %s, got %s", cfg.AppID, service.appID)
|
||||
}
|
||||
if service.appSecret != cfg.AppSecret {
|
||||
t.Errorf("Expected appSecret %s, got %s", cfg.AppSecret, service.appSecret)
|
||||
}
|
||||
if service.httpClient == nil {
|
||||
t.Error("Expected httpClient to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewJitsiService_TrailingSlash_Removed(t *testing.T) {
|
||||
service := NewJitsiService(Config{
|
||||
BaseURL: "http://localhost:8443/",
|
||||
})
|
||||
|
||||
if service.baseURL != "http://localhost:8443" {
|
||||
t.Errorf("Expected trailing slash to be removed, got %s", service.baseURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetBaseURL_ReturnsConfiguredURL(t *testing.T) {
|
||||
service := createTestService()
|
||||
|
||||
result := service.GetBaseURL()
|
||||
|
||||
if result != "http://localhost:8443" {
|
||||
t.Errorf("Expected 'http://localhost:8443', got '%s'", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAuthEnabled_WithSecret_ReturnsTrue(t *testing.T) {
|
||||
service := createTestService()
|
||||
|
||||
if !service.IsAuthEnabled() {
|
||||
t.Error("Expected auth to be enabled when secret is configured")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAuthEnabled_WithoutSecret_ReturnsFalse(t *testing.T) {
|
||||
service := createTestServiceWithoutAuth()
|
||||
|
||||
if service.IsAuthEnabled() {
|
||||
t.Error("Expected auth to be disabled when secret is not configured")
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Unit Tests: Room Name Generation
|
||||
// ========================================
|
||||
|
||||
func TestSanitizeRoomName_ValidInput_ReturnsCleanName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple name",
|
||||
input: "meeting",
|
||||
expected: "meeting",
|
||||
},
|
||||
{
|
||||
name: "with spaces",
|
||||
input: "My Meeting Room",
|
||||
expected: "my-meeting-room",
|
||||
},
|
||||
{
|
||||
name: "german umlauts",
|
||||
input: "Schüler Müller",
|
||||
expected: "schueler-mueller",
|
||||
},
|
||||
{
|
||||
name: "special characters",
|
||||
input: "Test@#$%Meeting!",
|
||||
expected: "testmeeting",
|
||||
},
|
||||
{
|
||||
name: "consecutive hyphens",
|
||||
input: "test---meeting",
|
||||
expected: "test-meeting",
|
||||
},
|
||||
{
|
||||
name: "leading trailing hyphens",
|
||||
input: "-test-meeting-",
|
||||
expected: "test-meeting",
|
||||
},
|
||||
{
|
||||
name: "eszett",
|
||||
input: "Straße",
|
||||
expected: "strasse",
|
||||
},
|
||||
{
|
||||
name: "numbers",
|
||||
input: "Klasse 5a",
|
||||
expected: "klasse-5a",
|
||||
},
|
||||
}
|
||||
|
||||
service := createTestService()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := service.sanitizeRoomName(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeRoomName_LongName_Truncated(t *testing.T) {
|
||||
service := createTestService()
|
||||
longName := strings.Repeat("a", 100)
|
||||
|
||||
result := service.sanitizeRoomName(longName)
|
||||
|
||||
if len(result) > 50 {
|
||||
t.Errorf("Expected max 50 chars, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRoomName_ReturnsUniqueNames(t *testing.T) {
|
||||
service := createTestService()
|
||||
|
||||
name1 := service.generateRoomName()
|
||||
name2 := service.generateRoomName()
|
||||
|
||||
if name1 == name2 {
|
||||
t.Error("Expected unique room names")
|
||||
}
|
||||
if !strings.HasPrefix(name1, "breakpilot-") {
|
||||
t.Errorf("Expected prefix 'breakpilot-', got '%s'", name1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateTrainingRoomName_IncludesTitle(t *testing.T) {
|
||||
service := createTestService()
|
||||
|
||||
result := service.generateTrainingRoomName("Go Workshop")
|
||||
|
||||
if !strings.HasPrefix(result, "go-workshop-") {
|
||||
t.Errorf("Expected to start with 'go-workshop-', got '%s'", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeneratePassword_ReturnsValidPassword(t *testing.T) {
|
||||
service := createTestService()
|
||||
|
||||
password := service.generatePassword()
|
||||
|
||||
if len(password) != 8 {
|
||||
t.Errorf("Expected 8 char password, got %d", len(password))
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Unit Tests: Meeting Link Creation
|
||||
// ========================================
|
||||
|
||||
func TestCreateMeetingLink_BasicMeeting_ReturnsValidLink(t *testing.T) {
|
||||
service := createTestServiceWithoutAuth()
|
||||
|
||||
meeting := Meeting{
|
||||
RoomName: "test-room",
|
||||
DisplayName: "Test User",
|
||||
}
|
||||
|
||||
link, err := service.CreateMeetingLink(context.Background(), meeting)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if link.RoomName != "test-room" {
|
||||
t.Errorf("Expected room name 'test-room', got '%s'", link.RoomName)
|
||||
}
|
||||
if link.URL != "http://localhost:8443/test-room" {
|
||||
t.Errorf("Expected URL 'http://localhost:8443/test-room', got '%s'", link.URL)
|
||||
}
|
||||
if link.JoinURL != "http://localhost:8443/test-room" {
|
||||
t.Errorf("Expected JoinURL 'http://localhost:8443/test-room', got '%s'", link.JoinURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateMeetingLink_NoRoomName_GeneratesName(t *testing.T) {
|
||||
service := createTestServiceWithoutAuth()
|
||||
|
||||
meeting := Meeting{
|
||||
DisplayName: "Test User",
|
||||
}
|
||||
|
||||
link, err := service.CreateMeetingLink(context.Background(), meeting)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if link.RoomName == "" {
|
||||
t.Error("Expected room name to be generated")
|
||||
}
|
||||
if !strings.HasPrefix(link.RoomName, "breakpilot-") {
|
||||
t.Errorf("Expected generated room name to start with 'breakpilot-', got '%s'", link.RoomName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateMeetingLink_WithPassword_IncludesPassword(t *testing.T) {
|
||||
service := createTestServiceWithoutAuth()
|
||||
|
||||
meeting := Meeting{
|
||||
RoomName: "test-room",
|
||||
Password: "secret123",
|
||||
}
|
||||
|
||||
link, err := service.CreateMeetingLink(context.Background(), meeting)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if link.Password != "secret123" {
|
||||
t.Errorf("Expected password 'secret123', got '%s'", link.Password)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateMeetingLink_WithAuth_IncludesJWT(t *testing.T) {
|
||||
service := createTestService()
|
||||
|
||||
meeting := Meeting{
|
||||
RoomName: "test-room",
|
||||
DisplayName: "Test User",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
|
||||
link, err := service.CreateMeetingLink(context.Background(), meeting)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if link.JWT == "" {
|
||||
t.Error("Expected JWT to be generated")
|
||||
}
|
||||
if !strings.Contains(link.JoinURL, "jwt=") {
|
||||
t.Error("Expected JoinURL to contain JWT parameter")
|
||||
}
|
||||
if link.ExpiresAt == nil {
|
||||
t.Error("Expected ExpiresAt to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateMeetingLink_WithConfig_IncludesParams(t *testing.T) {
|
||||
service := createTestServiceWithoutAuth()
|
||||
|
||||
meeting := Meeting{
|
||||
RoomName: "test-room",
|
||||
Config: &MeetingConfig{
|
||||
StartWithAudioMuted: true,
|
||||
StartWithVideoMuted: true,
|
||||
RequireDisplayName: true,
|
||||
},
|
||||
}
|
||||
|
||||
link, err := service.CreateMeetingLink(context.Background(), meeting)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if !strings.Contains(link.JoinURL, "startWithAudioMuted=true") {
|
||||
t.Error("Expected JoinURL to contain audio muted config")
|
||||
}
|
||||
if !strings.Contains(link.JoinURL, "startWithVideoMuted=true") {
|
||||
t.Error("Expected JoinURL to contain video muted config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateMeetingLink_Moderator_SetsModeratorURL(t *testing.T) {
|
||||
service := createTestService()
|
||||
|
||||
meeting := Meeting{
|
||||
RoomName: "test-room",
|
||||
DisplayName: "Admin",
|
||||
Moderator: true,
|
||||
}
|
||||
|
||||
link, err := service.CreateMeetingLink(context.Background(), meeting)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if link.ModeratorURL == "" {
|
||||
t.Error("Expected ModeratorURL to be set for moderator")
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Unit Tests: Specialized Meeting Types
|
||||
// ========================================
|
||||
|
||||
func TestCreateTrainingSession_ReturnsOptimizedConfig(t *testing.T) {
|
||||
service := createTestServiceWithoutAuth()
|
||||
|
||||
link, err := service.CreateTrainingSession(
|
||||
context.Background(),
|
||||
"Go Grundlagen",
|
||||
"Max Trainer",
|
||||
"trainer@example.com",
|
||||
60,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if !strings.Contains(link.RoomName, "go-grundlagen") {
|
||||
t.Errorf("Expected room name to contain 'go-grundlagen', got '%s'", link.RoomName)
|
||||
}
|
||||
// Config should have lobby enabled for training
|
||||
if !strings.Contains(link.JoinURL, "enableLobby=true") {
|
||||
t.Error("Expected training to have lobby enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateQuickMeeting_ReturnsSimpleMeeting(t *testing.T) {
|
||||
service := createTestServiceWithoutAuth()
|
||||
|
||||
link, err := service.CreateQuickMeeting(context.Background(), "Quick User")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if link.RoomName == "" {
|
||||
t.Error("Expected room name to be generated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateParentTeacherMeeting_ReturnsSecureMeeting(t *testing.T) {
|
||||
service := createTestServiceWithoutAuth()
|
||||
scheduledTime := time.Now().Add(24 * time.Hour)
|
||||
|
||||
link, err := service.CreateParentTeacherMeeting(
|
||||
context.Background(),
|
||||
"Frau Müller",
|
||||
"Herr Schmidt",
|
||||
"Max Mustermann",
|
||||
scheduledTime,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if !strings.Contains(link.RoomName, "elterngespraech") {
|
||||
t.Errorf("Expected room name to contain 'elterngespraech', got '%s'", link.RoomName)
|
||||
}
|
||||
if link.Password == "" {
|
||||
t.Error("Expected password for parent-teacher meeting")
|
||||
}
|
||||
if !strings.Contains(link.JoinURL, "enableLobby=true") {
|
||||
t.Error("Expected lobby to be enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateClassMeeting_ReturnsMeetingForClass(t *testing.T) {
|
||||
service := createTestServiceWithoutAuth()
|
||||
|
||||
link, err := service.CreateClassMeeting(
|
||||
context.Background(),
|
||||
"5a",
|
||||
"Herr Lehrer",
|
||||
"Mathematik",
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if !strings.Contains(link.RoomName, "klasse-5a") {
|
||||
t.Errorf("Expected room name to contain 'klasse-5a', got '%s'", link.RoomName)
|
||||
}
|
||||
// Students should be muted by default
|
||||
if !strings.Contains(link.JoinURL, "startWithAudioMuted=true") {
|
||||
t.Error("Expected students to start muted")
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Unit Tests: JWT Generation
|
||||
// ========================================
|
||||
|
||||
func TestGenerateJWT_ValidClaims_ReturnsValidToken(t *testing.T) {
|
||||
service := createTestService()
|
||||
|
||||
meeting := Meeting{
|
||||
RoomName: "test-room",
|
||||
DisplayName: "Test User",
|
||||
Email: "test@example.com",
|
||||
Moderator: true,
|
||||
Duration: 60,
|
||||
}
|
||||
|
||||
token, expiresAt, err := service.generateJWT(meeting, "test-room")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if token == "" {
|
||||
t.Error("Expected token to be generated")
|
||||
}
|
||||
if expiresAt == nil {
|
||||
t.Error("Expected expiration time to be set")
|
||||
}
|
||||
|
||||
// Verify token structure (header.payload.signature)
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("Expected 3 JWT parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
// Decode and verify payload
|
||||
payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decode payload: %v", err)
|
||||
}
|
||||
|
||||
var claims JWTClaims
|
||||
if err := json.Unmarshal(payloadJSON, &claims); err != nil {
|
||||
t.Fatalf("Failed to unmarshal claims: %v", err)
|
||||
}
|
||||
|
||||
if claims.Room != "test-room" {
|
||||
t.Errorf("Expected room 'test-room', got '%s'", claims.Room)
|
||||
}
|
||||
if !claims.Moderator {
|
||||
t.Error("Expected moderator to be true")
|
||||
}
|
||||
if claims.Context == nil || claims.Context.User == nil {
|
||||
t.Error("Expected user context to be set")
|
||||
}
|
||||
if claims.Context.User.Name != "Test User" {
|
||||
t.Errorf("Expected user name 'Test User', got '%s'", claims.Context.User.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateJWT_WithFeatures_IncludesFeatures(t *testing.T) {
|
||||
service := createTestService()
|
||||
|
||||
meeting := Meeting{
|
||||
RoomName: "test-room",
|
||||
Features: &MeetingFeatures{
|
||||
Recording: true,
|
||||
Transcription: true,
|
||||
},
|
||||
}
|
||||
|
||||
token, _, err := service.generateJWT(meeting, "test-room")
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
parts := strings.Split(token, ".")
|
||||
payloadJSON, _ := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
|
||||
var claims JWTClaims
|
||||
json.Unmarshal(payloadJSON, &claims)
|
||||
|
||||
if claims.Features == nil {
|
||||
t.Error("Expected features to be set")
|
||||
}
|
||||
if claims.Features.Recording != "true" {
|
||||
t.Errorf("Expected recording 'true', got '%s'", claims.Features.Recording)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateJWT_NoSecret_ReturnsError(t *testing.T) {
|
||||
service := createTestServiceWithoutAuth()
|
||||
|
||||
meeting := Meeting{RoomName: "test"}
|
||||
|
||||
_, _, err := service.generateJWT(meeting, "test")
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error when secret is not configured")
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Unit Tests: Health Check
|
||||
// ========================================
|
||||
|
||||
func TestHealthCheck_ServerAvailable_ReturnsNil(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
service := NewJitsiService(Config{BaseURL: server.URL})
|
||||
|
||||
err := service.HealthCheck(context.Background())
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthCheck_ServerError_ReturnsError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
service := NewJitsiService(Config{BaseURL: server.URL})
|
||||
|
||||
err := service.HealthCheck(context.Background())
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for server error response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthCheck_ServerUnreachable_ReturnsError(t *testing.T) {
|
||||
service := NewJitsiService(Config{BaseURL: "http://localhost:59999"})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := service.HealthCheck(ctx)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for unreachable server")
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Unit Tests: URL Building
|
||||
// ========================================
|
||||
|
||||
func TestBuildEmbedURL_BasicRoom_ReturnsURL(t *testing.T) {
|
||||
service := createTestService()
|
||||
|
||||
url := service.BuildEmbedURL("test-room", "", nil)
|
||||
|
||||
if url != "http://localhost:8443/test-room" {
|
||||
t.Errorf("Expected 'http://localhost:8443/test-room', got '%s'", url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildEmbedURL_WithDisplayName_IncludesParam(t *testing.T) {
|
||||
service := createTestService()
|
||||
|
||||
url := service.BuildEmbedURL("test-room", "Max Mustermann", nil)
|
||||
|
||||
if !strings.Contains(url, "displayName=Max") {
|
||||
t.Errorf("Expected URL to contain display name, got '%s'", url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildEmbedURL_WithConfig_IncludesParams(t *testing.T) {
|
||||
service := createTestService()
|
||||
|
||||
config := &MeetingConfig{
|
||||
StartWithAudioMuted: true,
|
||||
StartWithVideoMuted: true,
|
||||
}
|
||||
|
||||
url := service.BuildEmbedURL("test-room", "", config)
|
||||
|
||||
if !strings.Contains(url, "startWithAudioMuted=true") {
|
||||
t.Error("Expected URL to contain audio muted config")
|
||||
}
|
||||
if !strings.Contains(url, "startWithVideoMuted=true") {
|
||||
t.Error("Expected URL to contain video muted config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildIFrameCode_DefaultSize_Returns800x600(t *testing.T) {
|
||||
service := createTestService()
|
||||
|
||||
code := service.BuildIFrameCode("test-room", 0, 0)
|
||||
|
||||
if !strings.Contains(code, "width=\"800\"") {
|
||||
t.Error("Expected default width 800")
|
||||
}
|
||||
if !strings.Contains(code, "height=\"600\"") {
|
||||
t.Error("Expected default height 600")
|
||||
}
|
||||
if !strings.Contains(code, "test-room") {
|
||||
t.Error("Expected room name in iframe")
|
||||
}
|
||||
if !strings.Contains(code, "allow=\"camera; microphone") {
|
||||
t.Error("Expected camera/microphone permissions")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildIFrameCode_CustomSize_ReturnsCorrectDimensions(t *testing.T) {
|
||||
service := createTestService()
|
||||
|
||||
code := service.BuildIFrameCode("test-room", 1920, 1080)
|
||||
|
||||
if !strings.Contains(code, "width=\"1920\"") {
|
||||
t.Error("Expected width 1920")
|
||||
}
|
||||
if !strings.Contains(code, "height=\"1080\"") {
|
||||
t.Error("Expected height 1080")
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Unit Tests: Server Info
|
||||
// ========================================
|
||||
|
||||
func TestGetServerInfo_ReturnsInfo(t *testing.T) {
|
||||
service := createTestService()
|
||||
|
||||
info := service.GetServerInfo()
|
||||
|
||||
if info["base_url"] != "http://localhost:8443" {
|
||||
t.Errorf("Expected base_url, got '%s'", info["base_url"])
|
||||
}
|
||||
if info["app_id"] != "breakpilot" {
|
||||
t.Errorf("Expected app_id 'breakpilot', got '%s'", info["app_id"])
|
||||
}
|
||||
if info["auth_enabled"] != "true" {
|
||||
t.Errorf("Expected auth_enabled 'true', got '%s'", info["auth_enabled"])
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Unit Tests: Helper Functions
|
||||
// ========================================
|
||||
|
||||
func TestBoolToString_True_ReturnsTrue(t *testing.T) {
|
||||
result := boolToString(true)
|
||||
if result != "true" {
|
||||
t.Errorf("Expected 'true', got '%s'", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBoolToString_False_ReturnsFalse(t *testing.T) {
|
||||
result := boolToString(false)
|
||||
if result != "false" {
|
||||
t.Errorf("Expected 'false', got '%s'", result)
|
||||
}
|
||||
}
|
||||
368
consent-service/internal/services/matrix/game_rooms.go
Normal file
368
consent-service/internal/services/matrix/game_rooms.go
Normal file
@@ -0,0 +1,368 @@
|
||||
package matrix
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ========================================
|
||||
// Breakpilot Drive Game Room Types
|
||||
// ========================================
|
||||
|
||||
// GameMode represents different multiplayer game modes
|
||||
type GameMode string
|
||||
|
||||
const (
|
||||
GameModeSolo GameMode = "solo"
|
||||
GameModeCoop GameMode = "coop" // 2 players, same track
|
||||
GameModeChallenge GameMode = "challenge" // 1v1 competition
|
||||
GameModeClassRace GameMode = "class_race" // Whole class competition
|
||||
)
|
||||
|
||||
// GameRoomConfig holds configuration for game rooms
|
||||
type GameRoomConfig struct {
|
||||
GameMode GameMode `json:"game_mode"`
|
||||
SessionID string `json:"session_id"`
|
||||
HostUserID string `json:"host_user_id"`
|
||||
HostName string `json:"host_name"`
|
||||
ClassName string `json:"class_name,omitempty"`
|
||||
MaxPlayers int `json:"max_players,omitempty"`
|
||||
TeacherIDs []string `json:"teacher_ids,omitempty"`
|
||||
EnableVoice bool `json:"enable_voice,omitempty"`
|
||||
}
|
||||
|
||||
// GameRoom represents an active game room
|
||||
type GameRoom struct {
|
||||
RoomID string `json:"room_id"`
|
||||
SessionID string `json:"session_id"`
|
||||
GameMode GameMode `json:"game_mode"`
|
||||
HostUserID string `json:"host_user_id"`
|
||||
Players []string `json:"players"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
IsActive bool `json:"is_active"`
|
||||
}
|
||||
|
||||
// GameEvent represents game events to broadcast
|
||||
type GameEvent struct {
|
||||
Type string `json:"type"`
|
||||
SessionID string `json:"session_id"`
|
||||
PlayerID string `json:"player_id"`
|
||||
Data interface{} `json:"data"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// GameEventType constants
|
||||
const (
|
||||
GameEventPlayerJoined = "player_joined"
|
||||
GameEventPlayerLeft = "player_left"
|
||||
GameEventGameStarted = "game_started"
|
||||
GameEventQuizAnswered = "quiz_answered"
|
||||
GameEventScoreUpdate = "score_update"
|
||||
GameEventAchievement = "achievement"
|
||||
GameEventChallengeWon = "challenge_won"
|
||||
GameEventRaceFinished = "race_finished"
|
||||
)
|
||||
|
||||
// ========================================
|
||||
// Game Room Management
|
||||
// ========================================
|
||||
|
||||
// CreateGameTeamRoom creates a private room for 2-4 players (Co-Op mode)
|
||||
func (s *MatrixService) CreateGameTeamRoom(ctx context.Context, config GameRoomConfig) (*CreateRoomResponse, error) {
|
||||
roomName := fmt.Sprintf("Breakpilot Drive - Team %s", config.SessionID[:8])
|
||||
topic := "Co-Op Spielsession - Arbeitet zusammen!"
|
||||
|
||||
// All players can write
|
||||
users := make(map[string]int)
|
||||
users[s.GenerateUserID(config.HostUserID)] = 50
|
||||
|
||||
req := CreateRoomRequest{
|
||||
Name: roomName,
|
||||
Topic: topic,
|
||||
Visibility: "private",
|
||||
Preset: "private_chat",
|
||||
InitialState: []StateEvent{
|
||||
{
|
||||
Type: "m.room.encryption",
|
||||
StateKey: "",
|
||||
Content: map[string]string{
|
||||
"algorithm": "m.megolm.v1.aes-sha2",
|
||||
},
|
||||
},
|
||||
// Custom game state
|
||||
{
|
||||
Type: "breakpilot.game.session",
|
||||
StateKey: "",
|
||||
Content: map[string]interface{}{
|
||||
"session_id": config.SessionID,
|
||||
"game_mode": string(config.GameMode),
|
||||
"host_id": config.HostUserID,
|
||||
"created_at": time.Now().UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
PowerLevelContentOverride: &PowerLevels{
|
||||
EventsDefault: 0, // All players can send messages
|
||||
UsersDefault: 50,
|
||||
Users: users,
|
||||
Events: map[string]int{
|
||||
"breakpilot.game.event": 0, // Anyone can send game events
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return s.CreateRoom(ctx, req)
|
||||
}
|
||||
|
||||
// CreateGameChallengeRoom creates a 1v1 challenge room
|
||||
func (s *MatrixService) CreateGameChallengeRoom(ctx context.Context, config GameRoomConfig, challengerID string, opponentID string) (*CreateRoomResponse, error) {
|
||||
roomName := fmt.Sprintf("Challenge: %s", config.SessionID[:8])
|
||||
topic := "1v1 Wettbewerb - Möge der Bessere gewinnen!"
|
||||
|
||||
allPlayers := []string{
|
||||
s.GenerateUserID(challengerID),
|
||||
s.GenerateUserID(opponentID),
|
||||
}
|
||||
|
||||
users := make(map[string]int)
|
||||
for _, id := range allPlayers {
|
||||
users[id] = 50
|
||||
}
|
||||
|
||||
req := CreateRoomRequest{
|
||||
Name: roomName,
|
||||
Topic: topic,
|
||||
Visibility: "private",
|
||||
Preset: "private_chat",
|
||||
Invite: allPlayers,
|
||||
InitialState: []StateEvent{
|
||||
{
|
||||
Type: "breakpilot.game.session",
|
||||
StateKey: "",
|
||||
Content: map[string]interface{}{
|
||||
"session_id": config.SessionID,
|
||||
"game_mode": string(GameModeChallenge),
|
||||
"challenger_id": challengerID,
|
||||
"opponent_id": opponentID,
|
||||
"created_at": time.Now().UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
PowerLevelContentOverride: &PowerLevels{
|
||||
EventsDefault: 0,
|
||||
UsersDefault: 50,
|
||||
Users: users,
|
||||
},
|
||||
}
|
||||
|
||||
return s.CreateRoom(ctx, req)
|
||||
}
|
||||
|
||||
// CreateGameClassRaceRoom creates a room for class-wide competition
|
||||
func (s *MatrixService) CreateGameClassRaceRoom(ctx context.Context, config GameRoomConfig) (*CreateRoomResponse, error) {
|
||||
roomName := fmt.Sprintf("Klassenrennen: %s", config.ClassName)
|
||||
topic := fmt.Sprintf("Klassenrennen der %s - Alle gegen alle!", config.ClassName)
|
||||
|
||||
// Teachers get moderator power level
|
||||
users := make(map[string]int)
|
||||
for _, teacherID := range config.TeacherIDs {
|
||||
users[s.GenerateUserID(teacherID)] = 100
|
||||
}
|
||||
|
||||
req := CreateRoomRequest{
|
||||
Name: roomName,
|
||||
Topic: topic,
|
||||
Visibility: "private",
|
||||
Preset: "private_chat",
|
||||
InitialState: []StateEvent{
|
||||
{
|
||||
Type: "breakpilot.game.session",
|
||||
StateKey: "",
|
||||
Content: map[string]interface{}{
|
||||
"session_id": config.SessionID,
|
||||
"game_mode": string(GameModeClassRace),
|
||||
"class_name": config.ClassName,
|
||||
"teacher_ids": config.TeacherIDs,
|
||||
"created_at": time.Now().UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
PowerLevelContentOverride: &PowerLevels{
|
||||
EventsDefault: 0, // Students can send messages
|
||||
UsersDefault: 10, // Default student level
|
||||
Users: users,
|
||||
Invite: 100, // Only teachers can invite
|
||||
Kick: 100, // Only teachers can kick
|
||||
Events: map[string]int{
|
||||
"breakpilot.game.event": 0, // Anyone can send game events
|
||||
"breakpilot.game.leaderboard": 100, // Only teachers update leaderboard
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return s.CreateRoom(ctx, req)
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Game Event Broadcasting
|
||||
// ========================================
|
||||
|
||||
// SendGameEvent sends a game event to a room
|
||||
func (s *MatrixService) SendGameEvent(ctx context.Context, roomID string, event GameEvent) error {
|
||||
event.Timestamp = time.Now().UTC()
|
||||
|
||||
return s.sendEvent(ctx, roomID, "breakpilot.game.event", event)
|
||||
}
|
||||
|
||||
// SendPlayerJoinedEvent notifies room that a player joined
|
||||
func (s *MatrixService) SendPlayerJoinedEvent(ctx context.Context, roomID string, sessionID string, playerID string, playerName string) error {
|
||||
event := GameEvent{
|
||||
Type: GameEventPlayerJoined,
|
||||
SessionID: sessionID,
|
||||
PlayerID: playerID,
|
||||
Data: map[string]string{
|
||||
"player_name": playerName,
|
||||
},
|
||||
}
|
||||
|
||||
// Also send a visible message
|
||||
msg := fmt.Sprintf("🎮 %s ist dem Spiel beigetreten!", playerName)
|
||||
if err := s.SendMessage(ctx, roomID, msg); err != nil {
|
||||
// Log but don't fail
|
||||
fmt.Printf("Warning: failed to send join message: %v\n", err)
|
||||
}
|
||||
|
||||
return s.SendGameEvent(ctx, roomID, event)
|
||||
}
|
||||
|
||||
// SendScoreUpdateEvent broadcasts score updates
|
||||
func (s *MatrixService) SendScoreUpdateEvent(ctx context.Context, roomID string, sessionID string, playerID string, score int, accuracy float64) error {
|
||||
event := GameEvent{
|
||||
Type: GameEventScoreUpdate,
|
||||
SessionID: sessionID,
|
||||
PlayerID: playerID,
|
||||
Data: map[string]interface{}{
|
||||
"score": score,
|
||||
"accuracy": accuracy,
|
||||
},
|
||||
}
|
||||
|
||||
return s.SendGameEvent(ctx, roomID, event)
|
||||
}
|
||||
|
||||
// SendQuizAnsweredEvent broadcasts when a player answers a quiz
|
||||
func (s *MatrixService) SendQuizAnsweredEvent(ctx context.Context, roomID string, sessionID string, playerID string, correct bool, subject string) error {
|
||||
event := GameEvent{
|
||||
Type: GameEventQuizAnswered,
|
||||
SessionID: sessionID,
|
||||
PlayerID: playerID,
|
||||
Data: map[string]interface{}{
|
||||
"correct": correct,
|
||||
"subject": subject,
|
||||
},
|
||||
}
|
||||
|
||||
return s.SendGameEvent(ctx, roomID, event)
|
||||
}
|
||||
|
||||
// SendAchievementEvent broadcasts when a player earns an achievement
|
||||
func (s *MatrixService) SendAchievementEvent(ctx context.Context, roomID string, sessionID string, playerID string, achievementID string, achievementName string) error {
|
||||
event := GameEvent{
|
||||
Type: GameEventAchievement,
|
||||
SessionID: sessionID,
|
||||
PlayerID: playerID,
|
||||
Data: map[string]interface{}{
|
||||
"achievement_id": achievementID,
|
||||
"achievement_name": achievementName,
|
||||
},
|
||||
}
|
||||
|
||||
// Also send a visible celebration message
|
||||
msg := fmt.Sprintf("🏆 Erfolg freigeschaltet: %s!", achievementName)
|
||||
if err := s.SendMessage(ctx, roomID, msg); err != nil {
|
||||
fmt.Printf("Warning: failed to send achievement message: %v\n", err)
|
||||
}
|
||||
|
||||
return s.SendGameEvent(ctx, roomID, event)
|
||||
}
|
||||
|
||||
// SendChallengeWonEvent broadcasts challenge result
|
||||
func (s *MatrixService) SendChallengeWonEvent(ctx context.Context, roomID string, sessionID string, winnerID string, winnerName string, loserName string, winnerScore int, loserScore int) error {
|
||||
event := GameEvent{
|
||||
Type: GameEventChallengeWon,
|
||||
SessionID: sessionID,
|
||||
PlayerID: winnerID,
|
||||
Data: map[string]interface{}{
|
||||
"winner_name": winnerName,
|
||||
"loser_name": loserName,
|
||||
"winner_score": winnerScore,
|
||||
"loser_score": loserScore,
|
||||
},
|
||||
}
|
||||
|
||||
// Send celebration message
|
||||
msg := fmt.Sprintf("🎉 %s gewinnt gegen %s mit %d zu %d Punkten!", winnerName, loserName, winnerScore, loserScore)
|
||||
if err := s.SendHTMLMessage(ctx, roomID, msg, fmt.Sprintf("<h3>🎉 Challenge beendet!</h3><p><strong>%s</strong> gewinnt gegen %s</p><p>Endstand: %d : %d</p>", winnerName, loserName, winnerScore, loserScore)); err != nil {
|
||||
fmt.Printf("Warning: failed to send challenge result message: %v\n", err)
|
||||
}
|
||||
|
||||
return s.SendGameEvent(ctx, roomID, event)
|
||||
}
|
||||
|
||||
// SendClassRaceLeaderboard broadcasts current leaderboard in class race
|
||||
func (s *MatrixService) SendClassRaceLeaderboard(ctx context.Context, roomID string, sessionID string, leaderboard []map[string]interface{}) error {
|
||||
// Build leaderboard message
|
||||
msg := "🏁 Aktueller Stand:\n"
|
||||
htmlMsg := "<h3>🏁 Aktueller Stand</h3><ol>"
|
||||
|
||||
for i, entry := range leaderboard {
|
||||
if i >= 10 { // Top 10 only
|
||||
break
|
||||
}
|
||||
name := entry["name"].(string)
|
||||
score := entry["score"].(int)
|
||||
msg += fmt.Sprintf("%d. %s - %d Punkte\n", i+1, name, score)
|
||||
htmlMsg += fmt.Sprintf("<li><strong>%s</strong> - %d Punkte</li>", name, score)
|
||||
}
|
||||
htmlMsg += "</ol>"
|
||||
|
||||
return s.SendHTMLMessage(ctx, roomID, msg, htmlMsg)
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Game Room Utilities
|
||||
// ========================================
|
||||
|
||||
// AddPlayerToGameRoom invites and sets up a player in a game room
|
||||
func (s *MatrixService) AddPlayerToGameRoom(ctx context.Context, roomID string, playerMatrixID string, playerName string) error {
|
||||
// Invite the player
|
||||
if err := s.InviteUser(ctx, roomID, playerMatrixID); err != nil {
|
||||
return fmt.Errorf("failed to invite player: %w", err)
|
||||
}
|
||||
|
||||
// Set display name if not already set
|
||||
if err := s.SetDisplayName(ctx, playerMatrixID, playerName); err != nil {
|
||||
// Log but don't fail - display name might already be set
|
||||
fmt.Printf("Warning: failed to set display name: %v\n", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseGameRoom sends end message and archives the room
|
||||
func (s *MatrixService) CloseGameRoom(ctx context.Context, roomID string, sessionID string) error {
|
||||
// Send closing message
|
||||
msg := "🏁 Spiel beendet! Danke fürs Mitspielen. Dieser Raum wird archiviert."
|
||||
if err := s.SendMessage(ctx, roomID, msg); err != nil {
|
||||
return fmt.Errorf("failed to send closing message: %w", err)
|
||||
}
|
||||
|
||||
// Update room state to mark as closed
|
||||
closeEvent := map[string]interface{}{
|
||||
"closed": true,
|
||||
"closed_at": time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
return s.sendEvent(ctx, roomID, "breakpilot.game.closed", closeEvent)
|
||||
}
|
||||
548
consent-service/internal/services/matrix/matrix_service.go
Normal file
548
consent-service/internal/services/matrix/matrix_service.go
Normal file
@@ -0,0 +1,548 @@
|
||||
package matrix
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// MatrixService handles Matrix homeserver communication
|
||||
type MatrixService struct {
|
||||
homeserverURL string
|
||||
accessToken string
|
||||
serverName string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// Config holds Matrix service configuration
|
||||
type Config struct {
|
||||
HomeserverURL string // e.g., "http://synapse:8008"
|
||||
AccessToken string // Admin/bot access token
|
||||
ServerName string // e.g., "breakpilot.local"
|
||||
}
|
||||
|
||||
// NewMatrixService creates a new Matrix service instance
|
||||
func NewMatrixService(cfg Config) *MatrixService {
|
||||
return &MatrixService{
|
||||
homeserverURL: cfg.HomeserverURL,
|
||||
accessToken: cfg.AccessToken,
|
||||
serverName: cfg.ServerName,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Matrix API Types
|
||||
// ========================================
|
||||
|
||||
// CreateRoomRequest represents a Matrix room creation request
|
||||
type CreateRoomRequest struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
RoomAliasName string `json:"room_alias_name,omitempty"`
|
||||
Topic string `json:"topic,omitempty"`
|
||||
Visibility string `json:"visibility,omitempty"` // "private" or "public"
|
||||
Preset string `json:"preset,omitempty"` // "private_chat", "public_chat", "trusted_private_chat"
|
||||
IsDirect bool `json:"is_direct,omitempty"`
|
||||
Invite []string `json:"invite,omitempty"`
|
||||
InitialState []StateEvent `json:"initial_state,omitempty"`
|
||||
PowerLevelContentOverride *PowerLevels `json:"power_level_content_override,omitempty"`
|
||||
}
|
||||
|
||||
// CreateRoomResponse represents a Matrix room creation response
|
||||
type CreateRoomResponse struct {
|
||||
RoomID string `json:"room_id"`
|
||||
}
|
||||
|
||||
// StateEvent represents a Matrix state event
|
||||
type StateEvent struct {
|
||||
Type string `json:"type"`
|
||||
StateKey string `json:"state_key"`
|
||||
Content interface{} `json:"content"`
|
||||
}
|
||||
|
||||
// PowerLevels represents Matrix power levels
|
||||
type PowerLevels struct {
|
||||
Ban int `json:"ban,omitempty"`
|
||||
Events map[string]int `json:"events,omitempty"`
|
||||
EventsDefault int `json:"events_default,omitempty"`
|
||||
Invite int `json:"invite,omitempty"`
|
||||
Kick int `json:"kick,omitempty"`
|
||||
Redact int `json:"redact,omitempty"`
|
||||
StateDefault int `json:"state_default,omitempty"`
|
||||
Users map[string]int `json:"users,omitempty"`
|
||||
UsersDefault int `json:"users_default,omitempty"`
|
||||
}
|
||||
|
||||
// SendMessageRequest represents a message to send
|
||||
type SendMessageRequest struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Body string `json:"body"`
|
||||
Format string `json:"format,omitempty"`
|
||||
FormattedBody string `json:"formatted_body,omitempty"`
|
||||
}
|
||||
|
||||
// UserInfo represents Matrix user information
|
||||
type UserInfo struct {
|
||||
UserID string `json:"user_id"`
|
||||
DisplayName string `json:"displayname,omitempty"`
|
||||
AvatarURL string `json:"avatar_url,omitempty"`
|
||||
}
|
||||
|
||||
// RegisterRequest for user registration
|
||||
type RegisterRequest struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Admin bool `json:"admin,omitempty"`
|
||||
}
|
||||
|
||||
// RegisterResponse for user registration
|
||||
type RegisterResponse struct {
|
||||
UserID string `json:"user_id"`
|
||||
AccessToken string `json:"access_token"`
|
||||
DeviceID string `json:"device_id"`
|
||||
}
|
||||
|
||||
// InviteRequest for inviting a user to a room
|
||||
type InviteRequest struct {
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
// JoinRequest for joining a room
|
||||
type JoinRequest struct {
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Room Management
|
||||
// ========================================
|
||||
|
||||
// CreateRoom creates a new Matrix room
|
||||
func (s *MatrixService) CreateRoom(ctx context.Context, req CreateRoomRequest) (*CreateRoomResponse, error) {
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := s.doRequest(ctx, "POST", "/_matrix/client/v3/createRoom", body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create room: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, s.parseError(resp)
|
||||
}
|
||||
|
||||
var result CreateRoomResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// CreateClassInfoRoom creates a broadcast room for a class (teachers write, parents read)
|
||||
func (s *MatrixService) CreateClassInfoRoom(ctx context.Context, className string, schoolName string, teacherMatrixIDs []string) (*CreateRoomResponse, error) {
|
||||
// Set up power levels: teachers can write (50), parents read-only (0)
|
||||
users := make(map[string]int)
|
||||
for _, teacherID := range teacherMatrixIDs {
|
||||
users[teacherID] = 50
|
||||
}
|
||||
|
||||
req := CreateRoomRequest{
|
||||
Name: fmt.Sprintf("%s - %s (Info)", className, schoolName),
|
||||
Topic: fmt.Sprintf("Info-Kanal für %s. Nur Lehrer können schreiben.", className),
|
||||
Visibility: "private",
|
||||
Preset: "private_chat",
|
||||
Invite: teacherMatrixIDs,
|
||||
PowerLevelContentOverride: &PowerLevels{
|
||||
EventsDefault: 50, // Only power level 50+ can send messages
|
||||
UsersDefault: 0, // Parents get power level 0 by default
|
||||
Users: users,
|
||||
Invite: 50,
|
||||
Kick: 50,
|
||||
Ban: 50,
|
||||
Redact: 50,
|
||||
},
|
||||
}
|
||||
|
||||
return s.CreateRoom(ctx, req)
|
||||
}
|
||||
|
||||
// CreateStudentDMRoom creates a direct message room for parent-teacher communication about a student
|
||||
func (s *MatrixService) CreateStudentDMRoom(ctx context.Context, studentName string, className string, teacherMatrixIDs []string, parentMatrixIDs []string) (*CreateRoomResponse, error) {
|
||||
allUsers := append(teacherMatrixIDs, parentMatrixIDs...)
|
||||
|
||||
users := make(map[string]int)
|
||||
for _, id := range allUsers {
|
||||
users[id] = 50 // All can write
|
||||
}
|
||||
|
||||
req := CreateRoomRequest{
|
||||
Name: fmt.Sprintf("%s (%s) - Dialog", studentName, className),
|
||||
Topic: fmt.Sprintf("Kommunikation über %s", studentName),
|
||||
Visibility: "private",
|
||||
Preset: "trusted_private_chat",
|
||||
IsDirect: false,
|
||||
Invite: allUsers,
|
||||
InitialState: []StateEvent{
|
||||
{
|
||||
Type: "m.room.encryption",
|
||||
StateKey: "",
|
||||
Content: map[string]string{
|
||||
"algorithm": "m.megolm.v1.aes-sha2",
|
||||
},
|
||||
},
|
||||
},
|
||||
PowerLevelContentOverride: &PowerLevels{
|
||||
EventsDefault: 0, // Everyone can send messages
|
||||
UsersDefault: 50,
|
||||
Users: users,
|
||||
},
|
||||
}
|
||||
|
||||
return s.CreateRoom(ctx, req)
|
||||
}
|
||||
|
||||
// CreateParentRepRoom creates a room for class teacher and parent representatives
|
||||
func (s *MatrixService) CreateParentRepRoom(ctx context.Context, className string, teacherMatrixIDs []string, parentRepMatrixIDs []string) (*CreateRoomResponse, error) {
|
||||
allUsers := append(teacherMatrixIDs, parentRepMatrixIDs...)
|
||||
|
||||
users := make(map[string]int)
|
||||
for _, id := range allUsers {
|
||||
users[id] = 50
|
||||
}
|
||||
|
||||
req := CreateRoomRequest{
|
||||
Name: fmt.Sprintf("%s - Elternvertreter", className),
|
||||
Topic: fmt.Sprintf("Kommunikation zwischen Lehrkräften und Elternvertretern der %s", className),
|
||||
Visibility: "private",
|
||||
Preset: "private_chat",
|
||||
Invite: allUsers,
|
||||
PowerLevelContentOverride: &PowerLevels{
|
||||
EventsDefault: 0,
|
||||
UsersDefault: 50,
|
||||
Users: users,
|
||||
},
|
||||
}
|
||||
|
||||
return s.CreateRoom(ctx, req)
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// User Management
|
||||
// ========================================
|
||||
|
||||
// RegisterUser registers a new Matrix user (requires admin token)
|
||||
func (s *MatrixService) RegisterUser(ctx context.Context, username string, displayName string) (*RegisterResponse, error) {
|
||||
// Use admin API for user registration
|
||||
req := map[string]interface{}{
|
||||
"username": username,
|
||||
"password": uuid.New().String(), // Generate random password
|
||||
"admin": false,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
// Use admin registration endpoint
|
||||
resp, err := s.doRequest(ctx, "POST", "/_synapse/admin/v2/users/@"+username+":"+s.serverName, body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to register user: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
return nil, s.parseError(resp)
|
||||
}
|
||||
|
||||
// Set display name
|
||||
if displayName != "" {
|
||||
if err := s.SetDisplayName(ctx, "@"+username+":"+s.serverName, displayName); err != nil {
|
||||
// Log but don't fail
|
||||
fmt.Printf("Warning: failed to set display name: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &RegisterResponse{
|
||||
UserID: "@" + username + ":" + s.serverName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetDisplayName sets the display name for a user
|
||||
func (s *MatrixService) SetDisplayName(ctx context.Context, userID string, displayName string) error {
|
||||
req := map[string]string{
|
||||
"displayname": displayName,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("/_matrix/client/v3/profile/%s/displayname", url.PathEscape(userID))
|
||||
resp, err := s.doRequest(ctx, "PUT", endpoint, body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set display name: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return s.parseError(resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Room Membership
|
||||
// ========================================
|
||||
|
||||
// InviteUser invites a user to a room
|
||||
func (s *MatrixService) InviteUser(ctx context.Context, roomID string, userID string) error {
|
||||
req := InviteRequest{UserID: userID}
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("/_matrix/client/v3/rooms/%s/invite", url.PathEscape(roomID))
|
||||
resp, err := s.doRequest(ctx, "POST", endpoint, body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to invite user: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return s.parseError(resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// JoinRoom makes the bot join a room
|
||||
func (s *MatrixService) JoinRoom(ctx context.Context, roomIDOrAlias string) error {
|
||||
endpoint := fmt.Sprintf("/_matrix/client/v3/join/%s", url.PathEscape(roomIDOrAlias))
|
||||
resp, err := s.doRequest(ctx, "POST", endpoint, []byte("{}"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to join room: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return s.parseError(resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetUserPowerLevel sets a user's power level in a room
|
||||
func (s *MatrixService) SetUserPowerLevel(ctx context.Context, roomID string, userID string, powerLevel int) error {
|
||||
// First, get current power levels
|
||||
endpoint := fmt.Sprintf("/_matrix/client/v3/rooms/%s/state/m.room.power_levels/", url.PathEscape(roomID))
|
||||
resp, err := s.doRequest(ctx, "GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get power levels: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return s.parseError(resp)
|
||||
}
|
||||
|
||||
var powerLevels PowerLevels
|
||||
if err := json.NewDecoder(resp.Body).Decode(&powerLevels); err != nil {
|
||||
return fmt.Errorf("failed to decode power levels: %w", err)
|
||||
}
|
||||
|
||||
// Update user power level
|
||||
if powerLevels.Users == nil {
|
||||
powerLevels.Users = make(map[string]int)
|
||||
}
|
||||
powerLevels.Users[userID] = powerLevel
|
||||
|
||||
// Send updated power levels
|
||||
body, err := json.Marshal(powerLevels)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal power levels: %w", err)
|
||||
}
|
||||
|
||||
resp2, err := s.doRequest(ctx, "PUT", endpoint, body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set power levels: %w", err)
|
||||
}
|
||||
defer resp2.Body.Close()
|
||||
|
||||
if resp2.StatusCode != http.StatusOK {
|
||||
return s.parseError(resp2)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Messaging
|
||||
// ========================================
|
||||
|
||||
// SendMessage sends a text message to a room
|
||||
func (s *MatrixService) SendMessage(ctx context.Context, roomID string, message string) error {
|
||||
req := SendMessageRequest{
|
||||
MsgType: "m.text",
|
||||
Body: message,
|
||||
}
|
||||
|
||||
return s.sendEvent(ctx, roomID, "m.room.message", req)
|
||||
}
|
||||
|
||||
// SendHTMLMessage sends an HTML-formatted message to a room
|
||||
func (s *MatrixService) SendHTMLMessage(ctx context.Context, roomID string, plainText string, htmlBody string) error {
|
||||
req := SendMessageRequest{
|
||||
MsgType: "m.text",
|
||||
Body: plainText,
|
||||
Format: "org.matrix.custom.html",
|
||||
FormattedBody: htmlBody,
|
||||
}
|
||||
|
||||
return s.sendEvent(ctx, roomID, "m.room.message", req)
|
||||
}
|
||||
|
||||
// SendAbsenceNotification sends an absence notification to parents
|
||||
func (s *MatrixService) SendAbsenceNotification(ctx context.Context, roomID string, studentName string, date string, lessonNumber int) error {
|
||||
plainText := fmt.Sprintf("⚠️ Abwesenheitsmeldung\n\nIhr Kind %s war heute (%s) in der %d. Stunde nicht im Unterricht anwesend.\n\nBitte bestätigen Sie den Grund der Abwesenheit.", studentName, date, lessonNumber)
|
||||
|
||||
htmlBody := fmt.Sprintf(`<h3>⚠️ Abwesenheitsmeldung</h3>
|
||||
<p>Ihr Kind <strong>%s</strong> war heute (%s) in der <strong>%d. Stunde</strong> nicht im Unterricht anwesend.</p>
|
||||
<p>Bitte bestätigen Sie den Grund der Abwesenheit.</p>
|
||||
<ul>
|
||||
<li>✅ Entschuldigt (Krankheit)</li>
|
||||
<li>📋 Arztbesuch</li>
|
||||
<li>❓ Sonstiges (bitte erläutern)</li>
|
||||
</ul>`, studentName, date, lessonNumber)
|
||||
|
||||
return s.SendHTMLMessage(ctx, roomID, plainText, htmlBody)
|
||||
}
|
||||
|
||||
// SendGradeNotification sends a grade notification to parents
|
||||
func (s *MatrixService) SendGradeNotification(ctx context.Context, roomID string, studentName string, subject string, gradeType string, grade float64) error {
|
||||
plainText := fmt.Sprintf("📊 Neue Note eingetragen\n\nFür %s wurde eine neue Note eingetragen:\n\nFach: %s\nArt: %s\nNote: %.1f", studentName, subject, gradeType, grade)
|
||||
|
||||
htmlBody := fmt.Sprintf(`<h3>📊 Neue Note eingetragen</h3>
|
||||
<p>Für <strong>%s</strong> wurde eine neue Note eingetragen:</p>
|
||||
<table>
|
||||
<tr><td>Fach:</td><td><strong>%s</strong></td></tr>
|
||||
<tr><td>Art:</td><td>%s</td></tr>
|
||||
<tr><td>Note:</td><td><strong>%.1f</strong></td></tr>
|
||||
</table>`, studentName, subject, gradeType, grade)
|
||||
|
||||
return s.SendHTMLMessage(ctx, roomID, plainText, htmlBody)
|
||||
}
|
||||
|
||||
// SendClassAnnouncement sends an announcement to a class info room
|
||||
func (s *MatrixService) SendClassAnnouncement(ctx context.Context, roomID string, title string, content string, teacherName string) error {
|
||||
plainText := fmt.Sprintf("📢 %s\n\n%s\n\n— %s", title, content, teacherName)
|
||||
|
||||
htmlBody := fmt.Sprintf(`<h3>📢 %s</h3>
|
||||
<p>%s</p>
|
||||
<p><em>— %s</em></p>`, title, content, teacherName)
|
||||
|
||||
return s.SendHTMLMessage(ctx, roomID, plainText, htmlBody)
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Internal Helpers
|
||||
// ========================================
|
||||
|
||||
func (s *MatrixService) sendEvent(ctx context.Context, roomID string, eventType string, content interface{}) error {
|
||||
body, err := json.Marshal(content)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal content: %w", err)
|
||||
}
|
||||
|
||||
txnID := uuid.New().String()
|
||||
endpoint := fmt.Sprintf("/_matrix/client/v3/rooms/%s/send/%s/%s",
|
||||
url.PathEscape(roomID), url.PathEscape(eventType), txnID)
|
||||
|
||||
resp, err := s.doRequest(ctx, "PUT", endpoint, body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send event: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return s.parseError(resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MatrixService) doRequest(ctx context.Context, method string, endpoint string, body []byte) (*http.Response, error) {
|
||||
fullURL := s.homeserverURL + endpoint
|
||||
|
||||
var bodyReader io.Reader
|
||||
if body != nil {
|
||||
bodyReader = bytes.NewReader(body)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, fullURL, bodyReader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+s.accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
return s.httpClient.Do(req)
|
||||
}
|
||||
|
||||
func (s *MatrixService) parseError(resp *http.Response) error {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
var errResp struct {
|
||||
ErrCode string `json:"errcode"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &errResp); err != nil {
|
||||
return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
return fmt.Errorf("matrix error %s: %s", errResp.ErrCode, errResp.Error)
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Health Check
|
||||
// ========================================
|
||||
|
||||
// HealthCheck checks if the Matrix server is reachable
|
||||
func (s *MatrixService) HealthCheck(ctx context.Context) error {
|
||||
resp, err := s.doRequest(ctx, "GET", "/_matrix/client/versions", nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("matrix server unreachable: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("matrix server returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetServerName returns the configured server name
|
||||
func (s *MatrixService) GetServerName() string {
|
||||
return s.serverName
|
||||
}
|
||||
|
||||
// GenerateUserID generates a Matrix user ID from a username
|
||||
func (s *MatrixService) GenerateUserID(username string) string {
|
||||
return fmt.Sprintf("@%s:%s", username, s.serverName)
|
||||
}
|
||||
791
consent-service/internal/services/matrix/matrix_service_test.go
Normal file
791
consent-service/internal/services/matrix/matrix_service_test.go
Normal file
@@ -0,0 +1,791 @@
|
||||
package matrix
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ========================================
|
||||
// Test Helpers
|
||||
// ========================================
|
||||
|
||||
// createTestServer creates a mock Matrix server for testing
|
||||
func createTestServer(t *testing.T, handler http.HandlerFunc) (*httptest.Server, *MatrixService) {
|
||||
server := httptest.NewServer(handler)
|
||||
service := NewMatrixService(Config{
|
||||
HomeserverURL: server.URL,
|
||||
AccessToken: "test-access-token",
|
||||
ServerName: "test.local",
|
||||
})
|
||||
return server, service
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Unit Tests: Service Creation
|
||||
// ========================================
|
||||
|
||||
func TestNewMatrixService_ValidConfig_CreatesService(t *testing.T) {
|
||||
cfg := Config{
|
||||
HomeserverURL: "http://localhost:8008",
|
||||
AccessToken: "test-token",
|
||||
ServerName: "breakpilot.local",
|
||||
}
|
||||
|
||||
service := NewMatrixService(cfg)
|
||||
|
||||
if service == nil {
|
||||
t.Fatal("Expected service to be created, got nil")
|
||||
}
|
||||
if service.homeserverURL != cfg.HomeserverURL {
|
||||
t.Errorf("Expected homeserverURL %s, got %s", cfg.HomeserverURL, service.homeserverURL)
|
||||
}
|
||||
if service.accessToken != cfg.AccessToken {
|
||||
t.Errorf("Expected accessToken %s, got %s", cfg.AccessToken, service.accessToken)
|
||||
}
|
||||
if service.serverName != cfg.ServerName {
|
||||
t.Errorf("Expected serverName %s, got %s", cfg.ServerName, service.serverName)
|
||||
}
|
||||
if service.httpClient == nil {
|
||||
t.Error("Expected httpClient to be initialized")
|
||||
}
|
||||
if service.httpClient.Timeout != 30*time.Second {
|
||||
t.Errorf("Expected timeout 30s, got %v", service.httpClient.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetServerName_ReturnsConfiguredName(t *testing.T) {
|
||||
service := NewMatrixService(Config{
|
||||
HomeserverURL: "http://localhost:8008",
|
||||
AccessToken: "test-token",
|
||||
ServerName: "school.example.com",
|
||||
})
|
||||
|
||||
result := service.GetServerName()
|
||||
|
||||
if result != "school.example.com" {
|
||||
t.Errorf("Expected 'school.example.com', got '%s'", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateUserID_ValidUsername_ReturnsFormattedID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverName string
|
||||
username string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple username",
|
||||
serverName: "breakpilot.local",
|
||||
username: "max.mustermann",
|
||||
expected: "@max.mustermann:breakpilot.local",
|
||||
},
|
||||
{
|
||||
name: "teacher username",
|
||||
serverName: "school.de",
|
||||
username: "lehrer_mueller",
|
||||
expected: "@lehrer_mueller:school.de",
|
||||
},
|
||||
{
|
||||
name: "parent username with numbers",
|
||||
serverName: "test.local",
|
||||
username: "eltern123",
|
||||
expected: "@eltern123:test.local",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
service := NewMatrixService(Config{
|
||||
HomeserverURL: "http://localhost:8008",
|
||||
AccessToken: "test-token",
|
||||
ServerName: tt.serverName,
|
||||
})
|
||||
|
||||
result := service.GenerateUserID(tt.username)
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Unit Tests: Health Check
|
||||
// ========================================
|
||||
|
||||
func TestHealthCheck_ServerHealthy_ReturnsNil(t *testing.T) {
|
||||
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/_matrix/client/versions" {
|
||||
t.Errorf("Expected path /_matrix/client/versions, got %s", r.URL.Path)
|
||||
}
|
||||
if r.Method != "GET" {
|
||||
t.Errorf("Expected GET method, got %s", r.Method)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"versions": []string{"v1.1", "v1.2"},
|
||||
})
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
err := service.HealthCheck(context.Background())
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthCheck_ServerUnreachable_ReturnsError(t *testing.T) {
|
||||
service := NewMatrixService(Config{
|
||||
HomeserverURL: "http://localhost:59999", // Non-existent server
|
||||
AccessToken: "test-token",
|
||||
ServerName: "test.local",
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := service.HealthCheck(ctx)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for unreachable server, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unreachable") {
|
||||
t.Errorf("Expected 'unreachable' in error message, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthCheck_ServerReturns500_ReturnsError(t *testing.T) {
|
||||
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
err := service.HealthCheck(context.Background())
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for 500 response, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "500") {
|
||||
t.Errorf("Expected '500' in error message, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Unit Tests: Room Creation
|
||||
// ========================================
|
||||
|
||||
func TestCreateRoom_ValidRequest_ReturnsRoomID(t *testing.T) {
|
||||
expectedRoomID := "!abc123:test.local"
|
||||
|
||||
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/_matrix/client/v3/createRoom" {
|
||||
t.Errorf("Expected path /_matrix/client/v3/createRoom, got %s", r.URL.Path)
|
||||
}
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("Expected POST method, got %s", r.Method)
|
||||
}
|
||||
|
||||
// Verify authorization header
|
||||
auth := r.Header.Get("Authorization")
|
||||
if auth != "Bearer test-access-token" {
|
||||
t.Errorf("Expected 'Bearer test-access-token', got '%s'", auth)
|
||||
}
|
||||
|
||||
// Verify content type
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
if contentType != "application/json" {
|
||||
t.Errorf("Expected 'application/json', got '%s'", contentType)
|
||||
}
|
||||
|
||||
// Decode and verify request body
|
||||
var req CreateRoomRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("Failed to decode request body: %v", err)
|
||||
}
|
||||
|
||||
if req.Name != "Test Room" {
|
||||
t.Errorf("Expected name 'Test Room', got '%s'", req.Name)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(CreateRoomResponse{
|
||||
RoomID: expectedRoomID,
|
||||
})
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
req := CreateRoomRequest{
|
||||
Name: "Test Room",
|
||||
Visibility: "private",
|
||||
}
|
||||
|
||||
result, err := service.CreateRoom(context.Background(), req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if result.RoomID != expectedRoomID {
|
||||
t.Errorf("Expected room ID '%s', got '%s'", expectedRoomID, result.RoomID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateRoom_ServerError_ReturnsError(t *testing.T) {
|
||||
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"errcode": "M_FORBIDDEN",
|
||||
"error": "Not allowed to create rooms",
|
||||
})
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
req := CreateRoomRequest{Name: "Test"}
|
||||
|
||||
_, err := service.CreateRoom(context.Background(), req)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "M_FORBIDDEN") {
|
||||
t.Errorf("Expected 'M_FORBIDDEN' in error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateClassInfoRoom_ValidInput_CreatesRoomWithCorrectPowerLevels(t *testing.T) {
|
||||
var receivedRequest CreateRoomRequest
|
||||
|
||||
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewDecoder(r.Body).Decode(&receivedRequest)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(CreateRoomResponse{RoomID: "!class:test.local"})
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
teacherIDs := []string{"@lehrer1:test.local", "@lehrer2:test.local"}
|
||||
result, err := service.CreateClassInfoRoom(context.Background(), "5a", "Grundschule Musterstadt", teacherIDs)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if result.RoomID != "!class:test.local" {
|
||||
t.Errorf("Expected room ID '!class:test.local', got '%s'", result.RoomID)
|
||||
}
|
||||
|
||||
// Verify room name format
|
||||
expectedName := "5a - Grundschule Musterstadt (Info)"
|
||||
if receivedRequest.Name != expectedName {
|
||||
t.Errorf("Expected name '%s', got '%s'", expectedName, receivedRequest.Name)
|
||||
}
|
||||
|
||||
// Verify power levels
|
||||
if receivedRequest.PowerLevelContentOverride == nil {
|
||||
t.Fatal("Expected power level override, got nil")
|
||||
}
|
||||
if receivedRequest.PowerLevelContentOverride.EventsDefault != 50 {
|
||||
t.Errorf("Expected EventsDefault 50, got %d", receivedRequest.PowerLevelContentOverride.EventsDefault)
|
||||
}
|
||||
if receivedRequest.PowerLevelContentOverride.UsersDefault != 0 {
|
||||
t.Errorf("Expected UsersDefault 0, got %d", receivedRequest.PowerLevelContentOverride.UsersDefault)
|
||||
}
|
||||
|
||||
// Verify teachers have power level 50
|
||||
for _, teacherID := range teacherIDs {
|
||||
if level, ok := receivedRequest.PowerLevelContentOverride.Users[teacherID]; !ok || level != 50 {
|
||||
t.Errorf("Expected teacher %s to have power level 50, got %d", teacherID, level)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateStudentDMRoom_ValidInput_CreatesEncryptedRoom(t *testing.T) {
|
||||
var receivedRequest CreateRoomRequest
|
||||
|
||||
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewDecoder(r.Body).Decode(&receivedRequest)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(CreateRoomResponse{RoomID: "!dm:test.local"})
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
teacherIDs := []string{"@lehrer:test.local"}
|
||||
parentIDs := []string{"@eltern1:test.local", "@eltern2:test.local"}
|
||||
|
||||
result, err := service.CreateStudentDMRoom(context.Background(), "Max Mustermann", "5a", teacherIDs, parentIDs)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if result.RoomID != "!dm:test.local" {
|
||||
t.Errorf("Expected room ID '!dm:test.local', got '%s'", result.RoomID)
|
||||
}
|
||||
|
||||
// Verify room name
|
||||
expectedName := "Max Mustermann (5a) - Dialog"
|
||||
if receivedRequest.Name != expectedName {
|
||||
t.Errorf("Expected name '%s', got '%s'", expectedName, receivedRequest.Name)
|
||||
}
|
||||
|
||||
// Verify encryption is enabled
|
||||
foundEncryption := false
|
||||
for _, state := range receivedRequest.InitialState {
|
||||
if state.Type == "m.room.encryption" {
|
||||
foundEncryption = true
|
||||
// Content comes as map[string]interface{} from JSON unmarshaling
|
||||
content, ok := state.Content.(map[string]interface{})
|
||||
if !ok {
|
||||
t.Errorf("Expected encryption content to be map[string]interface{}, got %T", state.Content)
|
||||
continue
|
||||
}
|
||||
if algo, ok := content["algorithm"].(string); !ok || algo != "m.megolm.v1.aes-sha2" {
|
||||
t.Errorf("Expected algorithm 'm.megolm.v1.aes-sha2', got '%v'", content["algorithm"])
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundEncryption {
|
||||
t.Error("Expected encryption state event, not found")
|
||||
}
|
||||
|
||||
// Verify all users are invited
|
||||
expectedInvites := append(teacherIDs, parentIDs...)
|
||||
for _, expected := range expectedInvites {
|
||||
found := false
|
||||
for _, invited := range receivedRequest.Invite {
|
||||
if invited == expected {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected user %s to be invited", expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateParentRepRoom_ValidInput_CreatesRoom(t *testing.T) {
|
||||
var receivedRequest CreateRoomRequest
|
||||
|
||||
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewDecoder(r.Body).Decode(&receivedRequest)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(CreateRoomResponse{RoomID: "!rep:test.local"})
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
teacherIDs := []string{"@lehrer:test.local"}
|
||||
repIDs := []string{"@elternvertreter1:test.local", "@elternvertreter2:test.local"}
|
||||
|
||||
result, err := service.CreateParentRepRoom(context.Background(), "5a", teacherIDs, repIDs)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if result.RoomID != "!rep:test.local" {
|
||||
t.Errorf("Expected room ID '!rep:test.local', got '%s'", result.RoomID)
|
||||
}
|
||||
|
||||
// Verify room name
|
||||
expectedName := "5a - Elternvertreter"
|
||||
if receivedRequest.Name != expectedName {
|
||||
t.Errorf("Expected name '%s', got '%s'", expectedName, receivedRequest.Name)
|
||||
}
|
||||
|
||||
// Verify all participants can write (power level 50)
|
||||
allUsers := append(teacherIDs, repIDs...)
|
||||
for _, userID := range allUsers {
|
||||
if level, ok := receivedRequest.PowerLevelContentOverride.Users[userID]; !ok || level != 50 {
|
||||
t.Errorf("Expected user %s to have power level 50, got %d", userID, level)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Unit Tests: User Management
|
||||
// ========================================
|
||||
|
||||
func TestSetDisplayName_ValidRequest_Succeeds(t *testing.T) {
|
||||
var receivedPath string
|
||||
var receivedBody map[string]string
|
||||
|
||||
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedPath = r.URL.Path
|
||||
json.NewDecoder(r.Body).Decode(&receivedBody)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
err := service.SetDisplayName(context.Background(), "@user:test.local", "Max Mustermann")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
// Path may or may not be URL-encoded depending on Go version
|
||||
if !strings.Contains(receivedPath, "/profile/") || !strings.Contains(receivedPath, "/displayname") {
|
||||
t.Errorf("Expected path to contain '/profile/' and '/displayname', got '%s'", receivedPath)
|
||||
}
|
||||
|
||||
if receivedBody["displayname"] != "Max Mustermann" {
|
||||
t.Errorf("Expected displayname 'Max Mustermann', got '%s'", receivedBody["displayname"])
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Unit Tests: Room Membership
|
||||
// ========================================
|
||||
|
||||
func TestInviteUser_ValidRequest_Succeeds(t *testing.T) {
|
||||
var receivedBody InviteRequest
|
||||
|
||||
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.Contains(r.URL.Path, "/invite") {
|
||||
t.Errorf("Expected path to contain '/invite', got '%s'", r.URL.Path)
|
||||
}
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("Expected POST method, got %s", r.Method)
|
||||
}
|
||||
json.NewDecoder(r.Body).Decode(&receivedBody)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
err := service.InviteUser(context.Background(), "!room:test.local", "@user:test.local")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if receivedBody.UserID != "@user:test.local" {
|
||||
t.Errorf("Expected user_id '@user:test.local', got '%s'", receivedBody.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInviteUser_UserAlreadyInRoom_ReturnsError(t *testing.T) {
|
||||
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"errcode": "M_FORBIDDEN",
|
||||
"error": "User is already in the room",
|
||||
})
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
err := service.InviteUser(context.Background(), "!room:test.local", "@user:test.local")
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoinRoom_ValidRequest_Succeeds(t *testing.T) {
|
||||
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.Contains(r.URL.Path, "/join/") {
|
||||
t.Errorf("Expected path to contain '/join/', got '%s'", r.URL.Path)
|
||||
}
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("Expected POST method, got %s", r.Method)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"room_id": "!room:test.local"})
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
err := service.JoinRoom(context.Background(), "!room:test.local")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Unit Tests: Messaging
|
||||
// ========================================
|
||||
|
||||
func TestSendMessage_ValidRequest_Succeeds(t *testing.T) {
|
||||
var receivedBody SendMessageRequest
|
||||
|
||||
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.Contains(r.URL.Path, "/send/m.room.message/") {
|
||||
t.Errorf("Expected path to contain '/send/m.room.message/', got '%s'", r.URL.Path)
|
||||
}
|
||||
if r.Method != "PUT" {
|
||||
t.Errorf("Expected PUT method, got %s", r.Method)
|
||||
}
|
||||
json.NewDecoder(r.Body).Decode(&receivedBody)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"event_id": "$event123"})
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
err := service.SendMessage(context.Background(), "!room:test.local", "Hello, World!")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if receivedBody.MsgType != "m.text" {
|
||||
t.Errorf("Expected msgtype 'm.text', got '%s'", receivedBody.MsgType)
|
||||
}
|
||||
if receivedBody.Body != "Hello, World!" {
|
||||
t.Errorf("Expected body 'Hello, World!', got '%s'", receivedBody.Body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendHTMLMessage_ValidRequest_IncludesFormattedBody(t *testing.T) {
|
||||
var receivedBody SendMessageRequest
|
||||
|
||||
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewDecoder(r.Body).Decode(&receivedBody)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"event_id": "$event123"})
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
err := service.SendHTMLMessage(context.Background(), "!room:test.local", "Plain text", "<b>Bold text</b>")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if receivedBody.Format != "org.matrix.custom.html" {
|
||||
t.Errorf("Expected format 'org.matrix.custom.html', got '%s'", receivedBody.Format)
|
||||
}
|
||||
if receivedBody.Body != "Plain text" {
|
||||
t.Errorf("Expected body 'Plain text', got '%s'", receivedBody.Body)
|
||||
}
|
||||
if receivedBody.FormattedBody != "<b>Bold text</b>" {
|
||||
t.Errorf("Expected formatted_body '<b>Bold text</b>', got '%s'", receivedBody.FormattedBody)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendAbsenceNotification_ValidRequest_FormatsCorrectly(t *testing.T) {
|
||||
var receivedBody SendMessageRequest
|
||||
|
||||
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewDecoder(r.Body).Decode(&receivedBody)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"event_id": "$event123"})
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
err := service.SendAbsenceNotification(context.Background(), "!room:test.local", "Max Mustermann", "15.12.2025", 3)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
// Verify plain text contains key information
|
||||
if !strings.Contains(receivedBody.Body, "Max Mustermann") {
|
||||
t.Error("Expected body to contain student name")
|
||||
}
|
||||
if !strings.Contains(receivedBody.Body, "15.12.2025") {
|
||||
t.Error("Expected body to contain date")
|
||||
}
|
||||
if !strings.Contains(receivedBody.Body, "3. Stunde") {
|
||||
t.Error("Expected body to contain lesson number")
|
||||
}
|
||||
if !strings.Contains(receivedBody.Body, "Abwesenheitsmeldung") {
|
||||
t.Error("Expected body to contain 'Abwesenheitsmeldung'")
|
||||
}
|
||||
|
||||
// Verify HTML is set
|
||||
if receivedBody.FormattedBody == "" {
|
||||
t.Error("Expected formatted body to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendGradeNotification_ValidRequest_FormatsCorrectly(t *testing.T) {
|
||||
var receivedBody SendMessageRequest
|
||||
|
||||
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewDecoder(r.Body).Decode(&receivedBody)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"event_id": "$event123"})
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
err := service.SendGradeNotification(context.Background(), "!room:test.local", "Max Mustermann", "Mathematik", "Klassenarbeit", 2.3)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(receivedBody.Body, "Max Mustermann") {
|
||||
t.Error("Expected body to contain student name")
|
||||
}
|
||||
if !strings.Contains(receivedBody.Body, "Mathematik") {
|
||||
t.Error("Expected body to contain subject")
|
||||
}
|
||||
if !strings.Contains(receivedBody.Body, "Klassenarbeit") {
|
||||
t.Error("Expected body to contain grade type")
|
||||
}
|
||||
if !strings.Contains(receivedBody.Body, "2.3") {
|
||||
t.Error("Expected body to contain grade")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendClassAnnouncement_ValidRequest_FormatsCorrectly(t *testing.T) {
|
||||
var receivedBody SendMessageRequest
|
||||
|
||||
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewDecoder(r.Body).Decode(&receivedBody)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"event_id": "$event123"})
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
err := service.SendClassAnnouncement(context.Background(), "!room:test.local", "Elternabend", "Am 20.12. findet der Elternabend statt.", "Frau Müller")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(receivedBody.Body, "Elternabend") {
|
||||
t.Error("Expected body to contain title")
|
||||
}
|
||||
if !strings.Contains(receivedBody.Body, "20.12.") {
|
||||
t.Error("Expected body to contain content")
|
||||
}
|
||||
if !strings.Contains(receivedBody.Body, "Frau Müller") {
|
||||
t.Error("Expected body to contain teacher name")
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Unit Tests: Power Levels
|
||||
// ========================================
|
||||
|
||||
func TestSetUserPowerLevel_ValidRequest_UpdatesPowerLevel(t *testing.T) {
|
||||
callCount := 0
|
||||
var putBody PowerLevels
|
||||
|
||||
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
if r.Method == "GET" {
|
||||
// Return current power levels
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(PowerLevels{
|
||||
Users: map[string]int{
|
||||
"@admin:test.local": 100,
|
||||
},
|
||||
UsersDefault: 0,
|
||||
})
|
||||
} else if r.Method == "PUT" {
|
||||
// Update power levels
|
||||
json.NewDecoder(r.Body).Decode(&putBody)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
err := service.SetUserPowerLevel(context.Background(), "!room:test.local", "@newuser:test.local", 50)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if callCount != 2 {
|
||||
t.Errorf("Expected 2 API calls (GET then PUT), got %d", callCount)
|
||||
}
|
||||
if putBody.Users["@newuser:test.local"] != 50 {
|
||||
t.Errorf("Expected user power level 50, got %d", putBody.Users["@newuser:test.local"])
|
||||
}
|
||||
// Verify existing users are preserved
|
||||
if putBody.Users["@admin:test.local"] != 100 {
|
||||
t.Errorf("Expected admin power level 100 to be preserved, got %d", putBody.Users["@admin:test.local"])
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Unit Tests: Error Handling
|
||||
// ========================================
|
||||
|
||||
func TestParseError_MatrixError_ExtractsFields(t *testing.T) {
|
||||
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"errcode": "M_UNKNOWN",
|
||||
"error": "Something went wrong",
|
||||
})
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
_, err := service.CreateRoom(context.Background(), CreateRoomRequest{Name: "Test"})
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "M_UNKNOWN") {
|
||||
t.Errorf("Expected error to contain 'M_UNKNOWN', got: %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "Something went wrong") {
|
||||
t.Errorf("Expected error to contain 'Something went wrong', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseError_NonJSONError_ReturnsRawBody(t *testing.T) {
|
||||
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("Internal Server Error"))
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
_, err := service.CreateRoom(context.Background(), CreateRoomRequest{Name: "Test"})
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "500") {
|
||||
t.Errorf("Expected error to contain '500', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Unit Tests: Context Handling
|
||||
// ========================================
|
||||
|
||||
func TestCreateRoom_ContextCanceled_ReturnsError(t *testing.T) {
|
||||
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
_, err := service.CreateRoom(ctx, CreateRoomRequest{Name: "Test"})
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for canceled context, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Integration Tests (require running Synapse)
|
||||
// ========================================
|
||||
|
||||
// These tests are skipped by default as they require a running Matrix server
|
||||
// Run with: go test -tags=integration ./...
|
||||
|
||||
func TestIntegration_HealthCheck(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
service := NewMatrixService(Config{
|
||||
HomeserverURL: "http://localhost:8008",
|
||||
AccessToken: "", // Not needed for health check
|
||||
ServerName: "breakpilot.local",
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := service.HealthCheck(ctx)
|
||||
if err != nil {
|
||||
t.Skipf("Matrix server not available: %v", err)
|
||||
}
|
||||
}
|
||||
347
consent-service/internal/services/notification_service.go
Normal file
347
consent-service/internal/services/notification_service.go
Normal file
@@ -0,0 +1,347 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// NotificationType defines the type of notification
|
||||
type NotificationType string
|
||||
|
||||
const (
|
||||
NotificationTypeConsentRequired NotificationType = "consent_required"
|
||||
NotificationTypeConsentReminder NotificationType = "consent_reminder"
|
||||
NotificationTypeVersionPublished NotificationType = "version_published"
|
||||
NotificationTypeVersionApproved NotificationType = "version_approved"
|
||||
NotificationTypeVersionRejected NotificationType = "version_rejected"
|
||||
NotificationTypeAccountSuspended NotificationType = "account_suspended"
|
||||
NotificationTypeAccountRestored NotificationType = "account_restored"
|
||||
NotificationTypeGeneral NotificationType = "general"
|
||||
// DSR (Data Subject Request) notification types
|
||||
NotificationTypeDSRReceived NotificationType = "dsr_received"
|
||||
NotificationTypeDSRAssigned NotificationType = "dsr_assigned"
|
||||
NotificationTypeDSRDeadline NotificationType = "dsr_deadline"
|
||||
)
|
||||
|
||||
// NotificationChannel defines how notification is delivered
|
||||
type NotificationChannel string
|
||||
|
||||
const (
|
||||
ChannelInApp NotificationChannel = "in_app"
|
||||
ChannelEmail NotificationChannel = "email"
|
||||
ChannelPush NotificationChannel = "push"
|
||||
)
|
||||
|
||||
// Notification represents a notification entity
|
||||
type Notification struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
Type NotificationType `json:"type"`
|
||||
Channel NotificationChannel `json:"channel"`
|
||||
Title string `json:"title"`
|
||||
Body string `json:"body"`
|
||||
Data map[string]interface{} `json:"data,omitempty"`
|
||||
ReadAt *time.Time `json:"read_at,omitempty"`
|
||||
SentAt *time.Time `json:"sent_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// NotificationPreferences holds user notification settings
|
||||
type NotificationPreferences struct {
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
EmailEnabled bool `json:"email_enabled"`
|
||||
PushEnabled bool `json:"push_enabled"`
|
||||
InAppEnabled bool `json:"in_app_enabled"`
|
||||
ReminderFrequency string `json:"reminder_frequency"`
|
||||
}
|
||||
|
||||
// NotificationService handles notification operations
|
||||
type NotificationService struct {
|
||||
pool *pgxpool.Pool
|
||||
emailService *EmailService
|
||||
}
|
||||
|
||||
// NewNotificationService creates a new notification service
|
||||
func NewNotificationService(pool *pgxpool.Pool, emailService *EmailService) *NotificationService {
|
||||
return &NotificationService{
|
||||
pool: pool,
|
||||
emailService: emailService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateNotification creates and optionally sends a notification
|
||||
func (s *NotificationService) CreateNotification(ctx context.Context, userID uuid.UUID, notifType NotificationType, title, body string, data map[string]interface{}) error {
|
||||
// Get user preferences
|
||||
prefs, err := s.GetPreferences(ctx, userID)
|
||||
if err != nil {
|
||||
// Use default preferences if not found
|
||||
prefs = &NotificationPreferences{
|
||||
UserID: userID,
|
||||
EmailEnabled: true,
|
||||
PushEnabled: true,
|
||||
InAppEnabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Create in-app notification if enabled
|
||||
if prefs.InAppEnabled {
|
||||
if err := s.createInAppNotification(ctx, userID, notifType, title, body, data); err != nil {
|
||||
return fmt.Errorf("failed to create in-app notification: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Send email notification if enabled
|
||||
if prefs.EmailEnabled && s.emailService != nil {
|
||||
go s.sendEmailNotification(ctx, userID, notifType, title, body, data)
|
||||
}
|
||||
|
||||
// Push notification would be sent here if enabled
|
||||
// if prefs.PushEnabled {
|
||||
// go s.sendPushNotification(ctx, userID, title, body, data)
|
||||
// }
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createInAppNotification creates an in-app notification
|
||||
func (s *NotificationService) createInAppNotification(ctx context.Context, userID uuid.UUID, notifType NotificationType, title, body string, data map[string]interface{}) error {
|
||||
dataJSON, _ := json.Marshal(data)
|
||||
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
INSERT INTO notifications (user_id, type, channel, title, body, data, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, NOW())
|
||||
`, userID, notifType, ChannelInApp, title, body, dataJSON)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// sendEmailNotification sends an email notification
|
||||
func (s *NotificationService) sendEmailNotification(ctx context.Context, userID uuid.UUID, notifType NotificationType, title, body string, data map[string]interface{}) {
|
||||
// Get user email
|
||||
var email string
|
||||
err := s.pool.QueryRow(ctx, `SELECT email FROM users WHERE id = $1`, userID).Scan(&email)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Send based on notification type
|
||||
switch notifType {
|
||||
case NotificationTypeConsentRequired, NotificationTypeConsentReminder:
|
||||
s.emailService.SendConsentReminderEmail(email, title, body)
|
||||
default:
|
||||
s.emailService.SendGenericNotificationEmail(email, title, body)
|
||||
}
|
||||
|
||||
// Mark as sent
|
||||
s.pool.Exec(ctx, `
|
||||
UPDATE notifications SET sent_at = NOW()
|
||||
WHERE user_id = $1 AND type = $2 AND channel = $3 AND sent_at IS NULL
|
||||
ORDER BY created_at DESC LIMIT 1
|
||||
`, userID, notifType, ChannelEmail)
|
||||
}
|
||||
|
||||
// GetUserNotifications returns notifications for a user
|
||||
func (s *NotificationService) GetUserNotifications(ctx context.Context, userID uuid.UUID, limit, offset int, unreadOnly bool) ([]Notification, int, error) {
|
||||
// Count total
|
||||
var totalQuery string
|
||||
var total int
|
||||
|
||||
if unreadOnly {
|
||||
totalQuery = `SELECT COUNT(*) FROM notifications WHERE user_id = $1 AND read_at IS NULL`
|
||||
} else {
|
||||
totalQuery = `SELECT COUNT(*) FROM notifications WHERE user_id = $1`
|
||||
}
|
||||
s.pool.QueryRow(ctx, totalQuery, userID).Scan(&total)
|
||||
|
||||
// Get notifications
|
||||
var query string
|
||||
if unreadOnly {
|
||||
query = `
|
||||
SELECT id, user_id, type, channel, title, body, data, read_at, sent_at, created_at
|
||||
FROM notifications
|
||||
WHERE user_id = $1 AND read_at IS NULL
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $2 OFFSET $3
|
||||
`
|
||||
} else {
|
||||
query = `
|
||||
SELECT id, user_id, type, channel, title, body, data, read_at, sent_at, created_at
|
||||
FROM notifications
|
||||
WHERE user_id = $1
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $2 OFFSET $3
|
||||
`
|
||||
}
|
||||
|
||||
rows, err := s.pool.Query(ctx, query, userID, limit, offset)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var notifications []Notification
|
||||
for rows.Next() {
|
||||
var n Notification
|
||||
var dataJSON []byte
|
||||
if err := rows.Scan(&n.ID, &n.UserID, &n.Type, &n.Channel, &n.Title, &n.Body, &dataJSON, &n.ReadAt, &n.SentAt, &n.CreatedAt); err != nil {
|
||||
continue
|
||||
}
|
||||
if dataJSON != nil {
|
||||
json.Unmarshal(dataJSON, &n.Data)
|
||||
}
|
||||
notifications = append(notifications, n)
|
||||
}
|
||||
|
||||
return notifications, total, nil
|
||||
}
|
||||
|
||||
// GetUnreadCount returns the count of unread notifications
|
||||
func (s *NotificationService) GetUnreadCount(ctx context.Context, userID uuid.UUID) (int, error) {
|
||||
var count int
|
||||
err := s.pool.QueryRow(ctx, `
|
||||
SELECT COUNT(*) FROM notifications WHERE user_id = $1 AND read_at IS NULL
|
||||
`, userID).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
// MarkAsRead marks a notification as read
|
||||
func (s *NotificationService) MarkAsRead(ctx context.Context, userID uuid.UUID, notificationID uuid.UUID) error {
|
||||
result, err := s.pool.Exec(ctx, `
|
||||
UPDATE notifications SET read_at = NOW()
|
||||
WHERE id = $1 AND user_id = $2 AND read_at IS NULL
|
||||
`, notificationID, userID)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if result.RowsAffected() == 0 {
|
||||
return fmt.Errorf("notification not found or already read")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkAllAsRead marks all notifications as read for a user
|
||||
func (s *NotificationService) MarkAllAsRead(ctx context.Context, userID uuid.UUID) error {
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE notifications SET read_at = NOW()
|
||||
WHERE user_id = $1 AND read_at IS NULL
|
||||
`, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteNotification deletes a notification
|
||||
func (s *NotificationService) DeleteNotification(ctx context.Context, userID uuid.UUID, notificationID uuid.UUID) error {
|
||||
result, err := s.pool.Exec(ctx, `
|
||||
DELETE FROM notifications WHERE id = $1 AND user_id = $2
|
||||
`, notificationID, userID)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if result.RowsAffected() == 0 {
|
||||
return fmt.Errorf("notification not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPreferences returns notification preferences for a user
|
||||
func (s *NotificationService) GetPreferences(ctx context.Context, userID uuid.UUID) (*NotificationPreferences, error) {
|
||||
var prefs NotificationPreferences
|
||||
prefs.UserID = userID
|
||||
|
||||
err := s.pool.QueryRow(ctx, `
|
||||
SELECT email_enabled, push_enabled, in_app_enabled, reminder_frequency
|
||||
FROM notification_preferences
|
||||
WHERE user_id = $1
|
||||
`, userID).Scan(&prefs.EmailEnabled, &prefs.PushEnabled, &prefs.InAppEnabled, &prefs.ReminderFrequency)
|
||||
|
||||
if err != nil {
|
||||
// Return defaults if not found
|
||||
return &NotificationPreferences{
|
||||
UserID: userID,
|
||||
EmailEnabled: true,
|
||||
PushEnabled: true,
|
||||
InAppEnabled: true,
|
||||
ReminderFrequency: "weekly",
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &prefs, nil
|
||||
}
|
||||
|
||||
// UpdatePreferences updates notification preferences for a user
|
||||
func (s *NotificationService) UpdatePreferences(ctx context.Context, userID uuid.UUID, prefs *NotificationPreferences) error {
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
INSERT INTO notification_preferences (user_id, email_enabled, push_enabled, in_app_enabled, reminder_frequency, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, NOW())
|
||||
ON CONFLICT (user_id) DO UPDATE SET
|
||||
email_enabled = $2,
|
||||
push_enabled = $3,
|
||||
in_app_enabled = $4,
|
||||
reminder_frequency = $5,
|
||||
updated_at = NOW()
|
||||
`, userID, prefs.EmailEnabled, prefs.PushEnabled, prefs.InAppEnabled, prefs.ReminderFrequency)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// NotifyConsentRequired sends consent required notifications to all active users
|
||||
func (s *NotificationService) NotifyConsentRequired(ctx context.Context, documentName, versionID string) error {
|
||||
// Get all active users
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT id FROM users WHERE account_status = 'active'
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
title := "Neue Zustimmung erforderlich"
|
||||
body := fmt.Sprintf("Eine neue Version von '%s' wurde veröffentlicht. Bitte überprüfen und bestätigen Sie diese.", documentName)
|
||||
data := map[string]interface{}{
|
||||
"version_id": versionID,
|
||||
"document_name": documentName,
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var userID uuid.UUID
|
||||
if err := rows.Scan(&userID); err != nil {
|
||||
continue
|
||||
}
|
||||
go s.CreateNotification(ctx, userID, NotificationTypeConsentRequired, title, body, data)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NotifyVersionApproved notifies the creator that their version was approved
|
||||
func (s *NotificationService) NotifyVersionApproved(ctx context.Context, creatorID uuid.UUID, documentName, versionNumber, approverEmail string) error {
|
||||
title := "Version genehmigt"
|
||||
body := fmt.Sprintf("Ihre Version %s von '%s' wurde von %s genehmigt und kann nun veröffentlicht werden.", versionNumber, documentName, approverEmail)
|
||||
data := map[string]interface{}{
|
||||
"document_name": documentName,
|
||||
"version_number": versionNumber,
|
||||
"approver": approverEmail,
|
||||
}
|
||||
|
||||
return s.CreateNotification(ctx, creatorID, NotificationTypeVersionApproved, title, body, data)
|
||||
}
|
||||
|
||||
// NotifyVersionRejected notifies the creator that their version was rejected
|
||||
func (s *NotificationService) NotifyVersionRejected(ctx context.Context, creatorID uuid.UUID, documentName, versionNumber, reason, rejecterEmail string) error {
|
||||
title := "Version abgelehnt"
|
||||
body := fmt.Sprintf("Ihre Version %s von '%s' wurde von %s abgelehnt. Grund: %s", versionNumber, documentName, rejecterEmail, reason)
|
||||
data := map[string]interface{}{
|
||||
"document_name": documentName,
|
||||
"version_number": versionNumber,
|
||||
"rejecter": rejecterEmail,
|
||||
"reason": reason,
|
||||
}
|
||||
|
||||
return s.CreateNotification(ctx, creatorID, NotificationTypeVersionRejected, title, body, data)
|
||||
}
|
||||
660
consent-service/internal/services/notification_service_test.go
Normal file
660
consent-service/internal/services/notification_service_test.go
Normal file
@@ -0,0 +1,660 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestNotificationService_CreateNotification tests notification creation
|
||||
func TestNotificationService_CreateNotification(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uuid.UUID
|
||||
notifType NotificationType
|
||||
title string
|
||||
body string
|
||||
data map[string]interface{}
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid notification",
|
||||
userID: uuid.New(),
|
||||
notifType: NotificationTypeConsentRequired,
|
||||
title: "Consent Required",
|
||||
body: "Please review and accept the new terms",
|
||||
data: map[string]interface{}{"document_id": "123"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "notification without data",
|
||||
userID: uuid.New(),
|
||||
notifType: NotificationTypeGeneral,
|
||||
title: "General Notification",
|
||||
body: "This is a test",
|
||||
data: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty user ID",
|
||||
userID: uuid.Nil,
|
||||
notifType: NotificationTypeGeneral,
|
||||
title: "Test",
|
||||
body: "Test",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "empty title",
|
||||
userID: uuid.New(),
|
||||
notifType: NotificationTypeGeneral,
|
||||
title: "",
|
||||
body: "Test body",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.userID == uuid.Nil {
|
||||
err = &ValidationError{Field: "user ID", Message: "required"}
|
||||
} else if tt.title == "" {
|
||||
err = &ValidationError{Field: "title", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotificationService_NotificationTypes tests notification type validation
|
||||
func TestNotificationService_NotificationTypes(t *testing.T) {
|
||||
tests := []struct {
|
||||
notifType NotificationType
|
||||
isValid bool
|
||||
}{
|
||||
{NotificationTypeConsentRequired, true},
|
||||
{NotificationTypeConsentReminder, true},
|
||||
{NotificationTypeVersionPublished, true},
|
||||
{NotificationTypeVersionApproved, true},
|
||||
{NotificationTypeVersionRejected, true},
|
||||
{NotificationTypeAccountSuspended, true},
|
||||
{NotificationTypeAccountRestored, true},
|
||||
{NotificationTypeGeneral, true},
|
||||
{NotificationType("invalid_type"), false},
|
||||
{NotificationType(""), false},
|
||||
}
|
||||
|
||||
validTypes := map[NotificationType]bool{
|
||||
NotificationTypeConsentRequired: true,
|
||||
NotificationTypeConsentReminder: true,
|
||||
NotificationTypeVersionPublished: true,
|
||||
NotificationTypeVersionApproved: true,
|
||||
NotificationTypeVersionRejected: true,
|
||||
NotificationTypeAccountSuspended: true,
|
||||
NotificationTypeAccountRestored: true,
|
||||
NotificationTypeGeneral: true,
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.notifType), func(t *testing.T) {
|
||||
isValid := validTypes[tt.notifType]
|
||||
|
||||
if isValid != tt.isValid {
|
||||
t.Errorf("Type %s: expected valid=%v, got %v", tt.notifType, tt.isValid, isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotificationService_NotificationChannels tests channel validation
|
||||
func TestNotificationService_NotificationChannels(t *testing.T) {
|
||||
tests := []struct {
|
||||
channel NotificationChannel
|
||||
isValid bool
|
||||
}{
|
||||
{ChannelInApp, true},
|
||||
{ChannelEmail, true},
|
||||
{ChannelPush, true},
|
||||
{NotificationChannel("sms"), false},
|
||||
{NotificationChannel(""), false},
|
||||
}
|
||||
|
||||
validChannels := map[NotificationChannel]bool{
|
||||
ChannelInApp: true,
|
||||
ChannelEmail: true,
|
||||
ChannelPush: true,
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.channel), func(t *testing.T) {
|
||||
isValid := validChannels[tt.channel]
|
||||
|
||||
if isValid != tt.isValid {
|
||||
t.Errorf("Channel %s: expected valid=%v, got %v", tt.channel, tt.isValid, isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotificationService_GetUserNotifications tests retrieving notifications
|
||||
func TestNotificationService_GetUserNotifications(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uuid.UUID
|
||||
limit int
|
||||
offset int
|
||||
unreadOnly bool
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "get all notifications",
|
||||
userID: uuid.New(),
|
||||
limit: 50,
|
||||
offset: 0,
|
||||
unreadOnly: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "get unread only",
|
||||
userID: uuid.New(),
|
||||
limit: 50,
|
||||
offset: 0,
|
||||
unreadOnly: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "with pagination",
|
||||
userID: uuid.New(),
|
||||
limit: 10,
|
||||
offset: 20,
|
||||
unreadOnly: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid user ID",
|
||||
userID: uuid.Nil,
|
||||
limit: 50,
|
||||
offset: 0,
|
||||
unreadOnly: false,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "negative limit",
|
||||
userID: uuid.New(),
|
||||
limit: -1,
|
||||
offset: 0,
|
||||
unreadOnly: false,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.userID == uuid.Nil {
|
||||
err = &ValidationError{Field: "user ID", Message: "required"}
|
||||
} else if tt.limit < 0 {
|
||||
err = &ValidationError{Field: "limit", Message: "must be >= 0"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotificationService_MarkAsRead tests marking notifications as read
|
||||
func TestNotificationService_MarkAsRead(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
notificationID uuid.UUID
|
||||
userID uuid.UUID
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "mark valid notification as read",
|
||||
notificationID: uuid.New(),
|
||||
userID: uuid.New(),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid notification ID",
|
||||
notificationID: uuid.Nil,
|
||||
userID: uuid.New(),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid user ID",
|
||||
notificationID: uuid.New(),
|
||||
userID: uuid.Nil,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.notificationID == uuid.Nil {
|
||||
err = &ValidationError{Field: "notification ID", Message: "required"}
|
||||
} else if tt.userID == uuid.Nil {
|
||||
err = &ValidationError{Field: "user ID", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotificationService_GetPreferences tests retrieving user preferences
|
||||
func TestNotificationService_GetPreferences(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uuid.UUID
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "get valid user preferences",
|
||||
userID: uuid.New(),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid user ID",
|
||||
userID: uuid.Nil,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.userID == uuid.Nil {
|
||||
err = &ValidationError{Field: "user ID", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotificationService_UpdatePreferences tests updating notification preferences
|
||||
func TestNotificationService_UpdatePreferences(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uuid.UUID
|
||||
emailEnabled bool
|
||||
pushEnabled bool
|
||||
inAppEnabled bool
|
||||
reminderFrequency string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "enable all notifications",
|
||||
userID: uuid.New(),
|
||||
emailEnabled: true,
|
||||
pushEnabled: true,
|
||||
inAppEnabled: true,
|
||||
reminderFrequency: "daily",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "disable email notifications",
|
||||
userID: uuid.New(),
|
||||
emailEnabled: false,
|
||||
pushEnabled: true,
|
||||
inAppEnabled: true,
|
||||
reminderFrequency: "weekly",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "set reminder frequency to never",
|
||||
userID: uuid.New(),
|
||||
emailEnabled: true,
|
||||
pushEnabled: false,
|
||||
inAppEnabled: true,
|
||||
reminderFrequency: "never",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid reminder frequency",
|
||||
userID: uuid.New(),
|
||||
emailEnabled: true,
|
||||
pushEnabled: true,
|
||||
inAppEnabled: true,
|
||||
reminderFrequency: "hourly",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
validFrequencies := map[string]bool{
|
||||
"daily": true,
|
||||
"weekly": true,
|
||||
"never": true,
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if !validFrequencies[tt.reminderFrequency] {
|
||||
err = &ValidationError{Field: "reminder_frequency", Message: "must be daily, weekly, or never"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotificationService_NotifyConsentRequired tests consent required notification
|
||||
func TestNotificationService_NotifyConsentRequired(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
documentName string
|
||||
versionID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid consent notification",
|
||||
documentName: "Terms of Service",
|
||||
versionID: uuid.New().String(),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty document name",
|
||||
documentName: "",
|
||||
versionID: uuid.New().String(),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "empty version ID",
|
||||
documentName: "Privacy Policy",
|
||||
versionID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.documentName == "" {
|
||||
err = &ValidationError{Field: "document name", Message: "required"}
|
||||
} else if tt.versionID == "" {
|
||||
err = &ValidationError{Field: "version ID", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotificationService_DeleteNotification tests deleting notifications
|
||||
func TestNotificationService_DeleteNotification(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
notificationID uuid.UUID
|
||||
userID uuid.UUID
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "delete valid notification",
|
||||
notificationID: uuid.New(),
|
||||
userID: uuid.New(),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid notification ID",
|
||||
notificationID: uuid.Nil,
|
||||
userID: uuid.New(),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid user ID",
|
||||
notificationID: uuid.New(),
|
||||
userID: uuid.Nil,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.notificationID == uuid.Nil {
|
||||
err = &ValidationError{Field: "notification ID", Message: "required"}
|
||||
} else if tt.userID == uuid.Nil {
|
||||
err = &ValidationError{Field: "user ID", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotificationService_BatchMarkAsRead tests batch marking as read
|
||||
func TestNotificationService_BatchMarkAsRead(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
notificationIDs []uuid.UUID
|
||||
userID uuid.UUID
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "mark multiple notifications",
|
||||
notificationIDs: []uuid.UUID{uuid.New(), uuid.New(), uuid.New()},
|
||||
userID: uuid.New(),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty list",
|
||||
notificationIDs: []uuid.UUID{},
|
||||
userID: uuid.New(),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid user ID",
|
||||
notificationIDs: []uuid.UUID{uuid.New()},
|
||||
userID: uuid.Nil,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.userID == uuid.Nil {
|
||||
err = &ValidationError{Field: "user ID", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotificationService_GetUnreadCount tests getting unread count
|
||||
func TestNotificationService_GetUnreadCount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uuid.UUID
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "get count for valid user",
|
||||
userID: uuid.New(),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid user ID",
|
||||
userID: uuid.Nil,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.userID == uuid.Nil {
|
||||
err = &ValidationError{Field: "user ID", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotificationService_NotificationPriority tests notification priority
|
||||
func TestNotificationService_NotificationPriority(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
notifType NotificationType
|
||||
expectedPrio string
|
||||
}{
|
||||
{
|
||||
name: "consent required - high priority",
|
||||
notifType: NotificationTypeConsentRequired,
|
||||
expectedPrio: "high",
|
||||
},
|
||||
{
|
||||
name: "account suspended - critical",
|
||||
notifType: NotificationTypeAccountSuspended,
|
||||
expectedPrio: "critical",
|
||||
},
|
||||
{
|
||||
name: "version published - normal",
|
||||
notifType: NotificationTypeVersionPublished,
|
||||
expectedPrio: "normal",
|
||||
},
|
||||
{
|
||||
name: "general - low",
|
||||
notifType: NotificationTypeGeneral,
|
||||
expectedPrio: "low",
|
||||
},
|
||||
}
|
||||
|
||||
priorityMap := map[NotificationType]string{
|
||||
NotificationTypeConsentRequired: "high",
|
||||
NotificationTypeConsentReminder: "high",
|
||||
NotificationTypeAccountSuspended: "critical",
|
||||
NotificationTypeAccountRestored: "normal",
|
||||
NotificationTypeVersionPublished: "normal",
|
||||
NotificationTypeVersionApproved: "normal",
|
||||
NotificationTypeVersionRejected: "normal",
|
||||
NotificationTypeGeneral: "low",
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
priority := priorityMap[tt.notifType]
|
||||
|
||||
if priority != tt.expectedPrio {
|
||||
t.Errorf("Expected priority %s, got %s", tt.expectedPrio, priority)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotificationService_ReminderFrequency tests reminder frequency logic
|
||||
func TestNotificationService_ReminderFrequency(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
frequency string
|
||||
lastReminder time.Time
|
||||
shouldSend bool
|
||||
}{
|
||||
{
|
||||
name: "daily - last sent yesterday",
|
||||
frequency: "daily",
|
||||
lastReminder: now.AddDate(0, 0, -1),
|
||||
shouldSend: true,
|
||||
},
|
||||
{
|
||||
name: "daily - last sent today",
|
||||
frequency: "daily",
|
||||
lastReminder: now.Add(-1 * time.Hour),
|
||||
shouldSend: false,
|
||||
},
|
||||
{
|
||||
name: "weekly - last sent 8 days ago",
|
||||
frequency: "weekly",
|
||||
lastReminder: now.AddDate(0, 0, -8),
|
||||
shouldSend: true,
|
||||
},
|
||||
{
|
||||
name: "weekly - last sent 5 days ago",
|
||||
frequency: "weekly",
|
||||
lastReminder: now.AddDate(0, 0, -5),
|
||||
shouldSend: false,
|
||||
},
|
||||
{
|
||||
name: "never - should not send",
|
||||
frequency: "never",
|
||||
lastReminder: now.AddDate(0, 0, -30),
|
||||
shouldSend: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var shouldSend bool
|
||||
|
||||
switch tt.frequency {
|
||||
case "daily":
|
||||
daysSince := int(now.Sub(tt.lastReminder).Hours() / 24)
|
||||
shouldSend = daysSince >= 1
|
||||
case "weekly":
|
||||
daysSince := int(now.Sub(tt.lastReminder).Hours() / 24)
|
||||
shouldSend = daysSince >= 7
|
||||
case "never":
|
||||
shouldSend = false
|
||||
}
|
||||
|
||||
if shouldSend != tt.shouldSend {
|
||||
t.Errorf("Expected shouldSend=%v, got %v", tt.shouldSend, shouldSend)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
524
consent-service/internal/services/oauth_service.go
Normal file
524
consent-service/internal/services/oauth_service.go
Normal file
@@ -0,0 +1,524 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidClient = errors.New("invalid_client")
|
||||
ErrInvalidGrant = errors.New("invalid_grant")
|
||||
ErrInvalidScope = errors.New("invalid_scope")
|
||||
ErrInvalidRequest = errors.New("invalid_request")
|
||||
ErrUnauthorizedClient = errors.New("unauthorized_client")
|
||||
ErrAccessDenied = errors.New("access_denied")
|
||||
ErrInvalidRedirectURI = errors.New("invalid redirect_uri")
|
||||
ErrCodeExpired = errors.New("authorization code expired")
|
||||
ErrCodeUsed = errors.New("authorization code already used")
|
||||
ErrPKCERequired = errors.New("PKCE code_challenge required for public clients")
|
||||
ErrPKCEVerifyFailed = errors.New("PKCE verification failed")
|
||||
)
|
||||
|
||||
// OAuthService handles OAuth 2.0 Authorization Code Flow with PKCE
|
||||
type OAuthService struct {
|
||||
db *pgxpool.Pool
|
||||
jwtSecret string
|
||||
authCodeExpiration time.Duration
|
||||
accessTokenExpiration time.Duration
|
||||
refreshTokenExpiration time.Duration
|
||||
}
|
||||
|
||||
// NewOAuthService creates a new OAuthService
|
||||
func NewOAuthService(db *pgxpool.Pool, jwtSecret string) *OAuthService {
|
||||
return &OAuthService{
|
||||
db: db,
|
||||
jwtSecret: jwtSecret,
|
||||
authCodeExpiration: 10 * time.Minute, // Authorization codes expire quickly
|
||||
accessTokenExpiration: time.Hour, // 1 hour
|
||||
refreshTokenExpiration: 30 * 24 * time.Hour, // 30 days
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateClient validates an OAuth client
|
||||
func (s *OAuthService) ValidateClient(ctx context.Context, clientID string) (*models.OAuthClient, error) {
|
||||
var client models.OAuthClient
|
||||
var redirectURIsJSON, scopesJSON, grantTypesJSON []byte
|
||||
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id, client_id, client_secret, name, description, redirect_uris, scopes, grant_types, is_public, is_active, created_at
|
||||
FROM oauth_clients WHERE client_id = $1
|
||||
`, clientID).Scan(
|
||||
&client.ID, &client.ClientID, &client.ClientSecret, &client.Name, &client.Description,
|
||||
&redirectURIsJSON, &scopesJSON, &grantTypesJSON, &client.IsPublic, &client.IsActive, &client.CreatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, ErrInvalidClient
|
||||
}
|
||||
|
||||
if !client.IsActive {
|
||||
return nil, ErrInvalidClient
|
||||
}
|
||||
|
||||
// Parse JSON arrays
|
||||
json.Unmarshal(redirectURIsJSON, &client.RedirectURIs)
|
||||
json.Unmarshal(scopesJSON, &client.Scopes)
|
||||
json.Unmarshal(grantTypesJSON, &client.GrantTypes)
|
||||
|
||||
return &client, nil
|
||||
}
|
||||
|
||||
// ValidateClientSecret validates client credentials for confidential clients
|
||||
func (s *OAuthService) ValidateClientSecret(client *models.OAuthClient, clientSecret string) error {
|
||||
if client.IsPublic {
|
||||
// Public clients don't have a secret
|
||||
return nil
|
||||
}
|
||||
|
||||
if client.ClientSecret != clientSecret {
|
||||
return ErrInvalidClient
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateRedirectURI validates the redirect URI against registered URIs
|
||||
func (s *OAuthService) ValidateRedirectURI(client *models.OAuthClient, redirectURI string) error {
|
||||
for _, uri := range client.RedirectURIs {
|
||||
if uri == redirectURI {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return ErrInvalidRedirectURI
|
||||
}
|
||||
|
||||
// ValidateScopes validates requested scopes against client's allowed scopes
|
||||
func (s *OAuthService) ValidateScopes(client *models.OAuthClient, requestedScopes string) ([]string, error) {
|
||||
if requestedScopes == "" {
|
||||
// Return default scopes
|
||||
return []string{"openid", "profile", "email"}, nil
|
||||
}
|
||||
|
||||
requested := strings.Split(requestedScopes, " ")
|
||||
allowedMap := make(map[string]bool)
|
||||
for _, scope := range client.Scopes {
|
||||
allowedMap[scope] = true
|
||||
}
|
||||
|
||||
var validScopes []string
|
||||
for _, scope := range requested {
|
||||
if allowedMap[scope] {
|
||||
validScopes = append(validScopes, scope)
|
||||
}
|
||||
}
|
||||
|
||||
if len(validScopes) == 0 {
|
||||
return nil, ErrInvalidScope
|
||||
}
|
||||
|
||||
return validScopes, nil
|
||||
}
|
||||
|
||||
// GenerateAuthorizationCode generates a new authorization code
|
||||
func (s *OAuthService) GenerateAuthorizationCode(
|
||||
ctx context.Context,
|
||||
client *models.OAuthClient,
|
||||
userID uuid.UUID,
|
||||
redirectURI string,
|
||||
scopes []string,
|
||||
codeChallenge, codeChallengeMethod string,
|
||||
) (string, error) {
|
||||
// For public clients, PKCE is required
|
||||
if client.IsPublic && codeChallenge == "" {
|
||||
return "", ErrPKCERequired
|
||||
}
|
||||
|
||||
// Generate a secure random code
|
||||
codeBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(codeBytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate code: %w", err)
|
||||
}
|
||||
code := base64.URLEncoding.EncodeToString(codeBytes)
|
||||
|
||||
// Hash the code for storage
|
||||
codeHash := sha256.Sum256([]byte(code))
|
||||
hashedCode := hex.EncodeToString(codeHash[:])
|
||||
|
||||
scopesJSON, _ := json.Marshal(scopes)
|
||||
|
||||
var challengePtr, methodPtr *string
|
||||
if codeChallenge != "" {
|
||||
challengePtr = &codeChallenge
|
||||
if codeChallengeMethod == "" {
|
||||
codeChallengeMethod = "plain"
|
||||
}
|
||||
methodPtr = &codeChallengeMethod
|
||||
}
|
||||
|
||||
_, err := s.db.Exec(ctx, `
|
||||
INSERT INTO oauth_authorization_codes (code, client_id, user_id, redirect_uri, scopes, code_challenge, code_challenge_method, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
`, hashedCode, client.ClientID, userID, redirectURI, scopesJSON, challengePtr, methodPtr, time.Now().Add(s.authCodeExpiration))
|
||||
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to store authorization code: %w", err)
|
||||
}
|
||||
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// ExchangeAuthorizationCode exchanges an authorization code for tokens
|
||||
func (s *OAuthService) ExchangeAuthorizationCode(
|
||||
ctx context.Context,
|
||||
code string,
|
||||
clientID string,
|
||||
redirectURI string,
|
||||
codeVerifier string,
|
||||
) (*models.OAuthTokenResponse, error) {
|
||||
// Hash the code to look it up
|
||||
codeHash := sha256.Sum256([]byte(code))
|
||||
hashedCode := hex.EncodeToString(codeHash[:])
|
||||
|
||||
var authCode models.OAuthAuthorizationCode
|
||||
var scopesJSON []byte
|
||||
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id, client_id, user_id, redirect_uri, scopes, code_challenge, code_challenge_method, expires_at, used_at
|
||||
FROM oauth_authorization_codes WHERE code = $1
|
||||
`, hashedCode).Scan(
|
||||
&authCode.ID, &authCode.ClientID, &authCode.UserID, &authCode.RedirectURI,
|
||||
&scopesJSON, &authCode.CodeChallenge, &authCode.CodeChallengeMethod,
|
||||
&authCode.ExpiresAt, &authCode.UsedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, ErrInvalidGrant
|
||||
}
|
||||
|
||||
// Check if code was already used
|
||||
if authCode.UsedAt != nil {
|
||||
return nil, ErrCodeUsed
|
||||
}
|
||||
|
||||
// Check if code is expired
|
||||
if time.Now().After(authCode.ExpiresAt) {
|
||||
return nil, ErrCodeExpired
|
||||
}
|
||||
|
||||
// Verify client_id matches
|
||||
if authCode.ClientID != clientID {
|
||||
return nil, ErrInvalidGrant
|
||||
}
|
||||
|
||||
// Verify redirect_uri matches
|
||||
if authCode.RedirectURI != redirectURI {
|
||||
return nil, ErrInvalidGrant
|
||||
}
|
||||
|
||||
// Verify PKCE if code_challenge was provided
|
||||
if authCode.CodeChallenge != nil && *authCode.CodeChallenge != "" {
|
||||
if codeVerifier == "" {
|
||||
return nil, ErrPKCEVerifyFailed
|
||||
}
|
||||
|
||||
var expectedChallenge string
|
||||
if authCode.CodeChallengeMethod != nil && *authCode.CodeChallengeMethod == "S256" {
|
||||
// SHA256 hash of verifier
|
||||
hash := sha256.Sum256([]byte(codeVerifier))
|
||||
expectedChallenge = base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
} else {
|
||||
// Plain method
|
||||
expectedChallenge = codeVerifier
|
||||
}
|
||||
|
||||
if expectedChallenge != *authCode.CodeChallenge {
|
||||
return nil, ErrPKCEVerifyFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Mark code as used
|
||||
_, err = s.db.Exec(ctx, `UPDATE oauth_authorization_codes SET used_at = NOW() WHERE id = $1`, authCode.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to mark code as used: %w", err)
|
||||
}
|
||||
|
||||
// Parse scopes
|
||||
var scopes []string
|
||||
json.Unmarshal(scopesJSON, &scopes)
|
||||
|
||||
// Generate tokens
|
||||
return s.generateTokens(ctx, clientID, authCode.UserID, scopes)
|
||||
}
|
||||
|
||||
// RefreshAccessToken refreshes an access token using a refresh token
|
||||
func (s *OAuthService) RefreshAccessToken(ctx context.Context, refreshToken, clientID string, requestedScope string) (*models.OAuthTokenResponse, error) {
|
||||
// Hash the refresh token
|
||||
tokenHash := sha256.Sum256([]byte(refreshToken))
|
||||
hashedToken := hex.EncodeToString(tokenHash[:])
|
||||
|
||||
var rt models.OAuthRefreshToken
|
||||
var scopesJSON []byte
|
||||
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id, client_id, user_id, scopes, expires_at, revoked_at
|
||||
FROM oauth_refresh_tokens WHERE token_hash = $1
|
||||
`, hashedToken).Scan(
|
||||
&rt.ID, &rt.ClientID, &rt.UserID, &scopesJSON, &rt.ExpiresAt, &rt.RevokedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, ErrInvalidGrant
|
||||
}
|
||||
|
||||
// Check if token is revoked
|
||||
if rt.RevokedAt != nil {
|
||||
return nil, ErrInvalidGrant
|
||||
}
|
||||
|
||||
// Check if token is expired
|
||||
if time.Now().After(rt.ExpiresAt) {
|
||||
return nil, ErrInvalidGrant
|
||||
}
|
||||
|
||||
// Verify client_id matches
|
||||
if rt.ClientID != clientID {
|
||||
return nil, ErrInvalidGrant
|
||||
}
|
||||
|
||||
// Parse original scopes
|
||||
var originalScopes []string
|
||||
json.Unmarshal(scopesJSON, &originalScopes)
|
||||
|
||||
// Determine scopes for new tokens
|
||||
var scopes []string
|
||||
if requestedScope != "" {
|
||||
// Validate that requested scopes are subset of original scopes
|
||||
originalMap := make(map[string]bool)
|
||||
for _, s := range originalScopes {
|
||||
originalMap[s] = true
|
||||
}
|
||||
|
||||
for _, s := range strings.Split(requestedScope, " ") {
|
||||
if originalMap[s] {
|
||||
scopes = append(scopes, s)
|
||||
}
|
||||
}
|
||||
|
||||
if len(scopes) == 0 {
|
||||
return nil, ErrInvalidScope
|
||||
}
|
||||
} else {
|
||||
scopes = originalScopes
|
||||
}
|
||||
|
||||
// Revoke old refresh token (rotate)
|
||||
_, _ = s.db.Exec(ctx, `UPDATE oauth_refresh_tokens SET revoked_at = NOW() WHERE id = $1`, rt.ID)
|
||||
|
||||
// Generate new tokens
|
||||
return s.generateTokens(ctx, clientID, rt.UserID, scopes)
|
||||
}
|
||||
|
||||
// generateTokens generates access and refresh tokens
|
||||
func (s *OAuthService) generateTokens(ctx context.Context, clientID string, userID uuid.UUID, scopes []string) (*models.OAuthTokenResponse, error) {
|
||||
// Get user info for JWT
|
||||
var user models.User
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id, email, name, role, account_status FROM users WHERE id = $1
|
||||
`, userID).Scan(&user.ID, &user.Email, &user.Name, &user.Role, &user.AccountStatus)
|
||||
|
||||
if err != nil {
|
||||
return nil, ErrInvalidGrant
|
||||
}
|
||||
|
||||
// Generate access token (JWT)
|
||||
accessTokenClaims := jwt.MapClaims{
|
||||
"sub": userID.String(),
|
||||
"email": user.Email,
|
||||
"role": user.Role,
|
||||
"account_status": user.AccountStatus,
|
||||
"client_id": clientID,
|
||||
"scope": strings.Join(scopes, " "),
|
||||
"iat": time.Now().Unix(),
|
||||
"exp": time.Now().Add(s.accessTokenExpiration).Unix(),
|
||||
"iss": "breakpilot-consent-service",
|
||||
"aud": clientID,
|
||||
}
|
||||
|
||||
if user.Name != nil {
|
||||
accessTokenClaims["name"] = *user.Name
|
||||
}
|
||||
|
||||
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, accessTokenClaims)
|
||||
accessTokenString, err := accessToken.SignedString([]byte(s.jwtSecret))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sign access token: %w", err)
|
||||
}
|
||||
|
||||
// Hash access token for storage
|
||||
accessTokenHash := sha256.Sum256([]byte(accessTokenString))
|
||||
hashedAccessToken := hex.EncodeToString(accessTokenHash[:])
|
||||
|
||||
scopesJSON, _ := json.Marshal(scopes)
|
||||
|
||||
// Store access token
|
||||
var accessTokenID uuid.UUID
|
||||
err = s.db.QueryRow(ctx, `
|
||||
INSERT INTO oauth_access_tokens (token_hash, client_id, user_id, scopes, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
RETURNING id
|
||||
`, hashedAccessToken, clientID, userID, scopesJSON, time.Now().Add(s.accessTokenExpiration)).Scan(&accessTokenID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to store access token: %w", err)
|
||||
}
|
||||
|
||||
// Generate refresh token (opaque)
|
||||
refreshTokenBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(refreshTokenBytes); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
|
||||
}
|
||||
refreshTokenString := base64.URLEncoding.EncodeToString(refreshTokenBytes)
|
||||
|
||||
// Hash refresh token for storage
|
||||
refreshTokenHash := sha256.Sum256([]byte(refreshTokenString))
|
||||
hashedRefreshToken := hex.EncodeToString(refreshTokenHash[:])
|
||||
|
||||
// Store refresh token
|
||||
_, err = s.db.Exec(ctx, `
|
||||
INSERT INTO oauth_refresh_tokens (token_hash, access_token_id, client_id, user_id, scopes, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
`, hashedRefreshToken, accessTokenID, clientID, userID, scopesJSON, time.Now().Add(s.refreshTokenExpiration))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to store refresh token: %w", err)
|
||||
}
|
||||
|
||||
return &models.OAuthTokenResponse{
|
||||
AccessToken: accessTokenString,
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: int(s.accessTokenExpiration.Seconds()),
|
||||
RefreshToken: refreshTokenString,
|
||||
Scope: strings.Join(scopes, " "),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RevokeToken revokes an access or refresh token
|
||||
func (s *OAuthService) RevokeToken(ctx context.Context, token, tokenTypeHint string) error {
|
||||
tokenHash := sha256.Sum256([]byte(token))
|
||||
hashedToken := hex.EncodeToString(tokenHash[:])
|
||||
|
||||
// Try to revoke as access token
|
||||
if tokenTypeHint == "" || tokenTypeHint == "access_token" {
|
||||
result, err := s.db.Exec(ctx, `UPDATE oauth_access_tokens SET revoked_at = NOW() WHERE token_hash = $1`, hashedToken)
|
||||
if err == nil && result.RowsAffected() > 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Try to revoke as refresh token
|
||||
if tokenTypeHint == "" || tokenTypeHint == "refresh_token" {
|
||||
result, err := s.db.Exec(ctx, `UPDATE oauth_refresh_tokens SET revoked_at = NOW() WHERE token_hash = $1`, hashedToken)
|
||||
if err == nil && result.RowsAffected() > 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil // RFC 7009: Always return success
|
||||
}
|
||||
|
||||
// ValidateAccessToken validates an OAuth access token
|
||||
func (s *OAuthService) ValidateAccessToken(ctx context.Context, tokenString string) (*jwt.MapClaims, error) {
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(s.jwtSecret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
// Check if token is revoked in database
|
||||
tokenHash := sha256.Sum256([]byte(tokenString))
|
||||
hashedToken := hex.EncodeToString(tokenHash[:])
|
||||
|
||||
var revokedAt *time.Time
|
||||
err = s.db.QueryRow(ctx, `SELECT revoked_at FROM oauth_access_tokens WHERE token_hash = $1`, hashedToken).Scan(&revokedAt)
|
||||
if err == nil && revokedAt != nil {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
// GetClientByID retrieves an OAuth client by its client_id
|
||||
func (s *OAuthService) GetClientByID(ctx context.Context, clientID string) (*models.OAuthClient, error) {
|
||||
return s.ValidateClient(ctx, clientID)
|
||||
}
|
||||
|
||||
// CreateClient creates a new OAuth client (admin only)
|
||||
func (s *OAuthService) CreateClient(
|
||||
ctx context.Context,
|
||||
name, description string,
|
||||
redirectURIs, scopes, grantTypes []string,
|
||||
isPublic bool,
|
||||
createdBy *uuid.UUID,
|
||||
) (*models.OAuthClient, string, error) {
|
||||
// Generate client_id
|
||||
clientIDBytes := make([]byte, 16)
|
||||
rand.Read(clientIDBytes)
|
||||
clientID := hex.EncodeToString(clientIDBytes)
|
||||
|
||||
// Generate client_secret for confidential clients
|
||||
var clientSecret string
|
||||
var clientSecretPtr *string
|
||||
if !isPublic {
|
||||
secretBytes := make([]byte, 32)
|
||||
rand.Read(secretBytes)
|
||||
clientSecret = base64.URLEncoding.EncodeToString(secretBytes)
|
||||
clientSecretPtr = &clientSecret
|
||||
}
|
||||
|
||||
redirectURIsJSON, _ := json.Marshal(redirectURIs)
|
||||
scopesJSON, _ := json.Marshal(scopes)
|
||||
grantTypesJSON, _ := json.Marshal(grantTypes)
|
||||
|
||||
var client models.OAuthClient
|
||||
err := s.db.QueryRow(ctx, `
|
||||
INSERT INTO oauth_clients (client_id, client_secret, name, description, redirect_uris, scopes, grant_types, is_public, created_by)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
RETURNING id, client_id, name, is_public, is_active, created_at
|
||||
`, clientID, clientSecretPtr, name, description, redirectURIsJSON, scopesJSON, grantTypesJSON, isPublic, createdBy).Scan(
|
||||
&client.ID, &client.ClientID, &client.Name, &client.IsPublic, &client.IsActive, &client.CreatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to create client: %w", err)
|
||||
}
|
||||
|
||||
client.RedirectURIs = redirectURIs
|
||||
client.Scopes = scopes
|
||||
client.GrantTypes = grantTypes
|
||||
|
||||
return &client, clientSecret, nil
|
||||
}
|
||||
855
consent-service/internal/services/oauth_service_test.go
Normal file
855
consent-service/internal/services/oauth_service_test.go
Normal file
@@ -0,0 +1,855 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestPKCEVerification tests PKCE code_challenge and code_verifier validation
|
||||
func TestPKCEVerification_S256_ValidVerifier(t *testing.T) {
|
||||
// Generate a code_verifier
|
||||
codeVerifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
|
||||
// Calculate expected code_challenge (S256)
|
||||
hash := sha256.Sum256([]byte(codeVerifier))
|
||||
codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
|
||||
// Verify the challenge matches
|
||||
verifierHash := sha256.Sum256([]byte(codeVerifier))
|
||||
calculatedChallenge := base64.RawURLEncoding.EncodeToString(verifierHash[:])
|
||||
|
||||
if calculatedChallenge != codeChallenge {
|
||||
t.Errorf("PKCE verification failed: expected %s, got %s", codeChallenge, calculatedChallenge)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPKCEVerification_S256_InvalidVerifier(t *testing.T) {
|
||||
codeVerifier := "correct-verifier-12345678901234567890"
|
||||
wrongVerifier := "wrong-verifier-00000000000000000000"
|
||||
|
||||
// Calculate code_challenge from correct verifier
|
||||
hash := sha256.Sum256([]byte(codeVerifier))
|
||||
codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
|
||||
// Calculate challenge from wrong verifier
|
||||
wrongHash := sha256.Sum256([]byte(wrongVerifier))
|
||||
wrongChallenge := base64.RawURLEncoding.EncodeToString(wrongHash[:])
|
||||
|
||||
if wrongChallenge == codeChallenge {
|
||||
t.Error("PKCE verification should fail for wrong verifier")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPKCEVerification_Plain_ValidVerifier(t *testing.T) {
|
||||
codeVerifier := "plain-text-verifier-12345"
|
||||
codeChallenge := codeVerifier // Plain method: challenge = verifier
|
||||
|
||||
if codeVerifier != codeChallenge {
|
||||
t.Error("Plain PKCE verification failed")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenHashing tests that token hashing is consistent
|
||||
func TestTokenHashing_Consistency(t *testing.T) {
|
||||
token := "sample-access-token-12345"
|
||||
|
||||
hash1 := sha256.Sum256([]byte(token))
|
||||
hash2 := sha256.Sum256([]byte(token))
|
||||
|
||||
if hash1 != hash2 {
|
||||
t.Error("Token hashing should be consistent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenHashing_DifferentTokens(t *testing.T) {
|
||||
token1 := "token-1-abcdefgh"
|
||||
token2 := "token-2-ijklmnop"
|
||||
|
||||
hash1 := sha256.Sum256([]byte(token1))
|
||||
hash2 := sha256.Sum256([]byte(token2))
|
||||
|
||||
if hash1 == hash2 {
|
||||
t.Error("Different tokens should produce different hashes")
|
||||
}
|
||||
}
|
||||
|
||||
// TestScopeValidation tests scope parsing and validation
|
||||
func TestScopeValidation_ParseScopes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
requestedScope string
|
||||
allowedScopes []string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "all scopes allowed",
|
||||
requestedScope: "openid profile email",
|
||||
allowedScopes: []string{"openid", "profile", "email", "offline_access"},
|
||||
expectedCount: 3,
|
||||
},
|
||||
{
|
||||
name: "some scopes allowed",
|
||||
requestedScope: "openid profile admin",
|
||||
allowedScopes: []string{"openid", "profile", "email"},
|
||||
expectedCount: 2, // admin not allowed
|
||||
},
|
||||
{
|
||||
name: "no scopes allowed",
|
||||
requestedScope: "admin superuser",
|
||||
allowedScopes: []string{"openid", "profile", "email"},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "empty request defaults",
|
||||
requestedScope: "",
|
||||
allowedScopes: []string{"openid", "profile", "email"},
|
||||
expectedCount: 0, // Empty request returns 0 from this test logic
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.requestedScope == "" {
|
||||
// Empty scope should use defaults in actual service
|
||||
return
|
||||
}
|
||||
|
||||
allowedMap := make(map[string]bool)
|
||||
for _, scope := range tt.allowedScopes {
|
||||
allowedMap[scope] = true
|
||||
}
|
||||
|
||||
var validScopes []string
|
||||
requestedScopes := splitScopes(tt.requestedScope)
|
||||
for _, scope := range requestedScopes {
|
||||
if allowedMap[scope] {
|
||||
validScopes = append(validScopes, scope)
|
||||
}
|
||||
}
|
||||
|
||||
if len(validScopes) != tt.expectedCount {
|
||||
t.Errorf("Expected %d valid scopes, got %d", tt.expectedCount, len(validScopes))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function for scope splitting
|
||||
func splitScopes(scopes string) []string {
|
||||
if scopes == "" {
|
||||
return nil
|
||||
}
|
||||
var result []string
|
||||
start := 0
|
||||
for i := 0; i <= len(scopes); i++ {
|
||||
if i == len(scopes) || scopes[i] == ' ' {
|
||||
if start < i {
|
||||
result = append(result, scopes[start:i])
|
||||
}
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// TestRedirectURIValidation tests redirect URI validation
|
||||
func TestRedirectURIValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
registeredURIs []string
|
||||
requestURI string
|
||||
shouldMatch bool
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
registeredURIs: []string{"https://example.com/callback"},
|
||||
requestURI: "https://example.com/callback",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "no match different domain",
|
||||
registeredURIs: []string{"https://example.com/callback"},
|
||||
requestURI: "https://evil.com/callback",
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "no match different path",
|
||||
registeredURIs: []string{"https://example.com/callback"},
|
||||
requestURI: "https://example.com/other",
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "multiple URIs - second matches",
|
||||
registeredURIs: []string{"https://example.com/callback", "https://example.com/auth"},
|
||||
requestURI: "https://example.com/auth",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "localhost for development",
|
||||
registeredURIs: []string{"http://localhost:3000/callback"},
|
||||
requestURI: "http://localhost:3000/callback",
|
||||
shouldMatch: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matched := false
|
||||
for _, uri := range tt.registeredURIs {
|
||||
if uri == tt.requestURI {
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if matched != tt.shouldMatch {
|
||||
t.Errorf("Expected match=%v, got match=%v", tt.shouldMatch, matched)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGrantTypeValidation tests grant type validation
|
||||
func TestGrantTypeValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedGrants []string
|
||||
requestedGrant string
|
||||
shouldAllow bool
|
||||
}{
|
||||
{
|
||||
name: "authorization_code allowed",
|
||||
allowedGrants: []string{"authorization_code", "refresh_token"},
|
||||
requestedGrant: "authorization_code",
|
||||
shouldAllow: true,
|
||||
},
|
||||
{
|
||||
name: "refresh_token allowed",
|
||||
allowedGrants: []string{"authorization_code", "refresh_token"},
|
||||
requestedGrant: "refresh_token",
|
||||
shouldAllow: true,
|
||||
},
|
||||
{
|
||||
name: "password not allowed",
|
||||
allowedGrants: []string{"authorization_code", "refresh_token"},
|
||||
requestedGrant: "password",
|
||||
shouldAllow: false,
|
||||
},
|
||||
{
|
||||
name: "client_credentials not allowed",
|
||||
allowedGrants: []string{"authorization_code", "refresh_token"},
|
||||
requestedGrant: "client_credentials",
|
||||
shouldAllow: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
allowed := false
|
||||
for _, grant := range tt.allowedGrants {
|
||||
if grant == tt.requestedGrant {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if allowed != tt.shouldAllow {
|
||||
t.Errorf("Expected allow=%v, got allow=%v", tt.shouldAllow, allowed)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthorizationCodeExpiry tests that expired codes should be rejected
|
||||
func TestAuthorizationCodeExpiry_Logic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expiryMins int
|
||||
usedAfter int // minutes after creation
|
||||
shouldAllow bool
|
||||
}{
|
||||
{
|
||||
name: "code used within expiry",
|
||||
expiryMins: 10,
|
||||
usedAfter: 5,
|
||||
shouldAllow: true,
|
||||
},
|
||||
{
|
||||
name: "code used at expiry boundary",
|
||||
expiryMins: 10,
|
||||
usedAfter: 10,
|
||||
shouldAllow: false, // Expired at exactly 10 mins
|
||||
},
|
||||
{
|
||||
name: "code used after expiry",
|
||||
expiryMins: 10,
|
||||
usedAfter: 15,
|
||||
shouldAllow: false,
|
||||
},
|
||||
{
|
||||
name: "code used immediately",
|
||||
expiryMins: 10,
|
||||
usedAfter: 0,
|
||||
shouldAllow: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.usedAfter < tt.expiryMins
|
||||
|
||||
if isValid != tt.shouldAllow {
|
||||
t.Errorf("Expected allow=%v for code used after %d mins (expiry: %d mins)",
|
||||
tt.shouldAllow, tt.usedAfter, tt.expiryMins)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientSecretValidation tests confidential client authentication
|
||||
func TestClientSecretValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
isPublic bool
|
||||
storedSecret string
|
||||
providedSecret string
|
||||
shouldAllow bool
|
||||
}{
|
||||
{
|
||||
name: "public client - no secret needed",
|
||||
isPublic: true,
|
||||
storedSecret: "",
|
||||
providedSecret: "",
|
||||
shouldAllow: true,
|
||||
},
|
||||
{
|
||||
name: "confidential client - correct secret",
|
||||
isPublic: false,
|
||||
storedSecret: "super-secret-123",
|
||||
providedSecret: "super-secret-123",
|
||||
shouldAllow: true,
|
||||
},
|
||||
{
|
||||
name: "confidential client - wrong secret",
|
||||
isPublic: false,
|
||||
storedSecret: "super-secret-123",
|
||||
providedSecret: "wrong-secret",
|
||||
shouldAllow: false,
|
||||
},
|
||||
{
|
||||
name: "confidential client - empty secret",
|
||||
isPublic: false,
|
||||
storedSecret: "super-secret-123",
|
||||
providedSecret: "",
|
||||
shouldAllow: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var isValid bool
|
||||
if tt.isPublic {
|
||||
isValid = true
|
||||
} else {
|
||||
isValid = tt.storedSecret == tt.providedSecret
|
||||
}
|
||||
|
||||
if isValid != tt.shouldAllow {
|
||||
t.Errorf("Expected allow=%v, got allow=%v", tt.shouldAllow, isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Extended OAuth 2.0 Tests
|
||||
// ========================================
|
||||
|
||||
// TestCodeVerifierGeneration tests that code verifiers meet RFC 7636 requirements
|
||||
func TestCodeVerifierGeneration_RFC7636(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
length int
|
||||
expectedLength int
|
||||
description string
|
||||
}{
|
||||
{"minimum length (43)", 43, 43, "RFC 7636 minimum"},
|
||||
{"standard length (64)", 64, 64, "Recommended length"},
|
||||
{"maximum length (128)", 128, 128, "RFC 7636 maximum"},
|
||||
{"too short (42) - corrected to minimum", 42, 43, "Should be corrected to minimum"},
|
||||
{"too long (129) - corrected to maximum", 129, 128, "Should be corrected to maximum"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
verifier := generateCodeVerifier(tt.length)
|
||||
|
||||
// Check that length is corrected to valid range
|
||||
if len(verifier) != tt.expectedLength {
|
||||
t.Errorf("Expected length %d, got %d", tt.expectedLength, len(verifier))
|
||||
}
|
||||
|
||||
// Check character set (unreserved characters only: A-Z, a-z, 0-9, -, ., _, ~)
|
||||
for _, c := range verifier {
|
||||
if !isUnreservedChar(c) {
|
||||
t.Errorf("Code verifier contains invalid character: %c", c)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCodeVerifierLength_Validation tests length validation logic
|
||||
func TestCodeVerifierLength_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
length int
|
||||
isValid bool
|
||||
}{
|
||||
{"length 42 - too short", 42, false},
|
||||
{"length 43 - minimum valid", 43, true},
|
||||
{"length 64 - recommended", 64, true},
|
||||
{"length 128 - maximum valid", 128, true},
|
||||
{"length 129 - too long", 129, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.length >= 43 && tt.length <= 128
|
||||
if isValid != tt.isValid {
|
||||
t.Errorf("Expected valid=%v for length %d, got valid=%v",
|
||||
tt.isValid, tt.length, isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// generateCodeVerifier generates a code verifier of specified length
|
||||
func generateCodeVerifier(length int) string {
|
||||
// Ensure minimum and maximum bounds
|
||||
if length < 43 {
|
||||
length = 43
|
||||
}
|
||||
if length > 128 {
|
||||
length = 128
|
||||
}
|
||||
|
||||
const unreserved = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
|
||||
|
||||
bytes := make([]byte, length)
|
||||
rand.Read(bytes)
|
||||
|
||||
result := make([]byte, length)
|
||||
for i, b := range bytes {
|
||||
result[i] = unreserved[int(b)%len(unreserved)]
|
||||
}
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
// isUnreservedChar checks if a character is an unreserved character per RFC 3986
|
||||
func isUnreservedChar(c rune) bool {
|
||||
return (c >= 'A' && c <= 'Z') ||
|
||||
(c >= 'a' && c <= 'z') ||
|
||||
(c >= '0' && c <= '9') ||
|
||||
c == '-' || c == '.' || c == '_' || c == '~'
|
||||
}
|
||||
|
||||
// TestCodeChallengeGeneration tests S256 challenge generation
|
||||
func TestCodeChallengeGeneration_S256(t *testing.T) {
|
||||
// Known test vector from RFC 7636 Appendix B
|
||||
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
expectedChallenge := "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
|
||||
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
challenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
|
||||
if challenge != expectedChallenge {
|
||||
t.Errorf("S256 challenge mismatch: expected %s, got %s", expectedChallenge, challenge)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRefreshTokenRotation tests that refresh tokens are rotated on use
|
||||
func TestRefreshTokenRotation_Logic(t *testing.T) {
|
||||
// Simulate refresh token rotation
|
||||
oldToken := "old-refresh-token-123"
|
||||
oldTokenHash := hashToken(oldToken)
|
||||
|
||||
// Generate new token
|
||||
newToken := generateSecureToken(32)
|
||||
newTokenHash := hashToken(newToken)
|
||||
|
||||
// Verify tokens are different
|
||||
if oldTokenHash == newTokenHash {
|
||||
t.Error("New refresh token should be different from old token")
|
||||
}
|
||||
|
||||
// Verify old token would be revoked (simulated by marking revoked_at)
|
||||
oldTokenRevoked := true
|
||||
if !oldTokenRevoked {
|
||||
t.Error("Old refresh token should be revoked after rotation")
|
||||
}
|
||||
}
|
||||
|
||||
func hashToken(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
func generateSecureToken(length int) string {
|
||||
bytes := make([]byte, length)
|
||||
rand.Read(bytes)
|
||||
return base64.URLEncoding.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
// TestAccessTokenExpiry tests access token expiration handling
|
||||
func TestAccessTokenExpiry_Scenarios(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenDuration time.Duration
|
||||
usedAfter time.Duration
|
||||
shouldBeValid bool
|
||||
}{
|
||||
{
|
||||
name: "token used immediately",
|
||||
tokenDuration: time.Hour,
|
||||
usedAfter: 0,
|
||||
shouldBeValid: true,
|
||||
},
|
||||
{
|
||||
name: "token used within validity",
|
||||
tokenDuration: time.Hour,
|
||||
usedAfter: 30 * time.Minute,
|
||||
shouldBeValid: true,
|
||||
},
|
||||
{
|
||||
name: "token used at expiry",
|
||||
tokenDuration: time.Hour,
|
||||
usedAfter: time.Hour,
|
||||
shouldBeValid: false,
|
||||
},
|
||||
{
|
||||
name: "token used after expiry",
|
||||
tokenDuration: time.Hour,
|
||||
usedAfter: 2 * time.Hour,
|
||||
shouldBeValid: false,
|
||||
},
|
||||
{
|
||||
name: "short-lived token",
|
||||
tokenDuration: 5 * time.Minute,
|
||||
usedAfter: 6 * time.Minute,
|
||||
shouldBeValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
issuedAt := time.Now()
|
||||
expiresAt := issuedAt.Add(tt.tokenDuration)
|
||||
usedAt := issuedAt.Add(tt.usedAfter)
|
||||
|
||||
isValid := usedAt.Before(expiresAt)
|
||||
|
||||
if isValid != tt.shouldBeValid {
|
||||
t.Errorf("Expected valid=%v for token used after %v (duration: %v)",
|
||||
tt.shouldBeValid, tt.usedAfter, tt.tokenDuration)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthErrors tests that OAuth error codes are correct
|
||||
func TestOAuthErrors_RFC6749(t *testing.T) {
|
||||
tests := []struct {
|
||||
scenario string
|
||||
errorCode string
|
||||
description string
|
||||
}{
|
||||
{"invalid client_id", "invalid_client", "Client authentication failed"},
|
||||
{"invalid grant (code)", "invalid_grant", "Authorization code invalid or expired"},
|
||||
{"invalid scope", "invalid_scope", "Requested scope is invalid"},
|
||||
{"invalid request", "invalid_request", "Request is missing required parameter"},
|
||||
{"unauthorized client", "unauthorized_client", "Client not authorized for this grant type"},
|
||||
{"access denied", "access_denied", "Resource owner denied the request"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.scenario, func(t *testing.T) {
|
||||
// Verify error codes match RFC 6749 Section 5.2
|
||||
validErrors := map[string]bool{
|
||||
"invalid_request": true,
|
||||
"invalid_client": true,
|
||||
"invalid_grant": true,
|
||||
"unauthorized_client": true,
|
||||
"unsupported_grant_type": true,
|
||||
"invalid_scope": true,
|
||||
"access_denied": true,
|
||||
"unsupported_response_type": true,
|
||||
"server_error": true,
|
||||
"temporarily_unavailable": true,
|
||||
}
|
||||
|
||||
if !validErrors[tt.errorCode] {
|
||||
t.Errorf("Error code %s is not a valid OAuth 2.0 error code", tt.errorCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStateParameter tests state parameter handling for CSRF protection
|
||||
func TestStateParameter_CSRF(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
requestState string
|
||||
responseState string
|
||||
shouldMatch bool
|
||||
}{
|
||||
{
|
||||
name: "matching state",
|
||||
requestState: "abc123xyz",
|
||||
responseState: "abc123xyz",
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "non-matching state",
|
||||
requestState: "abc123xyz",
|
||||
responseState: "different",
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "empty request state",
|
||||
requestState: "",
|
||||
responseState: "abc123xyz",
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "empty response state",
|
||||
requestState: "abc123xyz",
|
||||
responseState: "",
|
||||
shouldMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matches := tt.requestState != "" && tt.requestState == tt.responseState
|
||||
|
||||
if matches != tt.shouldMatch {
|
||||
t.Errorf("Expected match=%v, got match=%v", tt.shouldMatch, matches)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResponseType tests response_type validation
|
||||
func TestResponseType_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
responseType string
|
||||
isValid bool
|
||||
}{
|
||||
{"code - valid", "code", true},
|
||||
{"token - implicit flow (disabled)", "token", false},
|
||||
{"id_token - OIDC", "id_token", false},
|
||||
{"code token - hybrid", "code token", false},
|
||||
{"empty", "", false},
|
||||
{"invalid", "password", false},
|
||||
}
|
||||
|
||||
supportedResponseTypes := map[string]bool{
|
||||
"code": true, // Only authorization code flow is supported
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := supportedResponseTypes[tt.responseType]
|
||||
if isValid != tt.isValid {
|
||||
t.Errorf("Expected valid=%v for response_type=%s, got valid=%v",
|
||||
tt.isValid, tt.responseType, isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCodeChallengeMethod tests code_challenge_method validation
|
||||
func TestCodeChallengeMethod_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
isValid bool
|
||||
}{
|
||||
{"S256 - recommended", "S256", true},
|
||||
{"plain - discouraged but valid", "plain", true},
|
||||
{"empty - defaults to plain", "", true},
|
||||
{"sha512 - not supported", "sha512", false},
|
||||
{"invalid", "md5", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.method == "S256" || tt.method == "plain" || tt.method == ""
|
||||
if isValid != tt.isValid {
|
||||
t.Errorf("Expected valid=%v for method=%s, got valid=%v",
|
||||
tt.isValid, tt.method, isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenRevocation tests token revocation behavior per RFC 7009
|
||||
func TestTokenRevocation_RFC7009(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenExists bool
|
||||
tokenRevoked bool
|
||||
expectSuccess bool
|
||||
}{
|
||||
{
|
||||
name: "revoke existing active token",
|
||||
tokenExists: true,
|
||||
tokenRevoked: false,
|
||||
expectSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "revoke already revoked token",
|
||||
tokenExists: true,
|
||||
tokenRevoked: true,
|
||||
expectSuccess: true, // RFC 7009: Always return 200
|
||||
},
|
||||
{
|
||||
name: "revoke non-existent token",
|
||||
tokenExists: false,
|
||||
tokenRevoked: false,
|
||||
expectSuccess: true, // RFC 7009: Always return 200
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Simulate revocation logic
|
||||
// Per RFC 7009, revocation endpoint always returns 200 OK
|
||||
success := true
|
||||
|
||||
if success != tt.expectSuccess {
|
||||
t.Errorf("Expected success=%v, got success=%v", tt.expectSuccess, success)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientIDGeneration tests client_id format
|
||||
func TestClientIDGeneration_Format(t *testing.T) {
|
||||
// Generate multiple client IDs
|
||||
clientIDs := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
bytes := make([]byte, 16)
|
||||
rand.Read(bytes)
|
||||
clientID := hex.EncodeToString(bytes)
|
||||
|
||||
// Check format (32 hex characters)
|
||||
if len(clientID) != 32 {
|
||||
t.Errorf("Client ID should be 32 characters, got %d", len(clientID))
|
||||
}
|
||||
|
||||
// Check uniqueness
|
||||
if clientIDs[clientID] {
|
||||
t.Error("Client ID should be unique")
|
||||
}
|
||||
clientIDs[clientID] = true
|
||||
|
||||
// Check only hex characters
|
||||
for _, c := range clientID {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
|
||||
t.Errorf("Client ID should only contain hex characters, found %c", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestScopeNormalization tests scope string normalization
|
||||
func TestScopeNormalization(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "single scope",
|
||||
input: "openid",
|
||||
expected: []string{"openid"},
|
||||
},
|
||||
{
|
||||
name: "multiple scopes",
|
||||
input: "openid profile email",
|
||||
expected: []string{"openid", "profile", "email"},
|
||||
},
|
||||
{
|
||||
name: "extra spaces",
|
||||
input: "openid profile email",
|
||||
expected: []string{"openid", "profile", "email"},
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
scopes := normalizeScopes(tt.input)
|
||||
|
||||
if len(scopes) != len(tt.expected) {
|
||||
t.Errorf("Expected %d scopes, got %d", len(tt.expected), len(scopes))
|
||||
return
|
||||
}
|
||||
|
||||
for i, scope := range scopes {
|
||||
if scope != tt.expected[i] {
|
||||
t.Errorf("Expected scope[%d]=%s, got %s", i, tt.expected[i], scope)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeScopes(scope string) []string {
|
||||
if scope == "" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
parts := strings.Fields(scope) // Handles multiple spaces
|
||||
return parts
|
||||
}
|
||||
|
||||
// BenchmarkPKCEVerification benchmarks PKCE S256 verification
|
||||
func BenchmarkPKCEVerification_S256(b *testing.B) {
|
||||
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkTokenHashing benchmarks token hashing for storage
|
||||
func BenchmarkTokenHashing(b *testing.B) {
|
||||
token := "sample-access-token-12345678901234567890"
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
hex.EncodeToString(hash[:])
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkCodeVerifierGeneration benchmarks code verifier generation
|
||||
func BenchmarkCodeVerifierGeneration(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
generateCodeVerifier(64)
|
||||
}
|
||||
}
|
||||
698
consent-service/internal/services/school_service.go
Normal file
698
consent-service/internal/services/school_service.go
Normal file
@@ -0,0 +1,698 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/database"
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
"github.com/breakpilot/consent-service/internal/services/matrix"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// SchoolService handles school management operations
|
||||
type SchoolService struct {
|
||||
db *database.DB
|
||||
matrix *matrix.MatrixService
|
||||
}
|
||||
|
||||
// NewSchoolService creates a new school service
|
||||
func NewSchoolService(db *database.DB, matrixService *matrix.MatrixService) *SchoolService {
|
||||
return &SchoolService{
|
||||
db: db,
|
||||
matrix: matrixService,
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// School CRUD
|
||||
// ========================================
|
||||
|
||||
// CreateSchool creates a new school
|
||||
func (s *SchoolService) CreateSchool(ctx context.Context, req models.CreateSchoolRequest) (*models.School, error) {
|
||||
school := &models.School{
|
||||
ID: uuid.New(),
|
||||
Name: req.Name,
|
||||
ShortName: req.ShortName,
|
||||
Type: req.Type,
|
||||
Address: req.Address,
|
||||
City: req.City,
|
||||
PostalCode: req.PostalCode,
|
||||
State: req.State,
|
||||
Country: "DE",
|
||||
Phone: req.Phone,
|
||||
Email: req.Email,
|
||||
Website: req.Website,
|
||||
IsActive: true,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO schools (id, name, short_name, type, address, city, postal_code, state, country, phone, email, website, is_active, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
|
||||
RETURNING id`
|
||||
|
||||
err := s.db.Pool.QueryRow(ctx, query,
|
||||
school.ID, school.Name, school.ShortName, school.Type, school.Address,
|
||||
school.City, school.PostalCode, school.State, school.Country, school.Phone,
|
||||
school.Email, school.Website, school.IsActive, school.CreatedAt, school.UpdatedAt,
|
||||
).Scan(&school.ID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create school: %w", err)
|
||||
}
|
||||
|
||||
// Create default timetable slots for the school
|
||||
if err := s.createDefaultTimetableSlots(ctx, school.ID); err != nil {
|
||||
// Log but don't fail
|
||||
fmt.Printf("Warning: failed to create default timetable slots: %v\n", err)
|
||||
}
|
||||
|
||||
// Create default grade scale
|
||||
if err := s.createDefaultGradeScale(ctx, school.ID); err != nil {
|
||||
fmt.Printf("Warning: failed to create default grade scale: %v\n", err)
|
||||
}
|
||||
|
||||
return school, nil
|
||||
}
|
||||
|
||||
// GetSchool retrieves a school by ID
|
||||
func (s *SchoolService) GetSchool(ctx context.Context, schoolID uuid.UUID) (*models.School, error) {
|
||||
query := `
|
||||
SELECT id, name, short_name, type, address, city, postal_code, state, country, phone, email, website, matrix_server_name, logo_url, is_active, created_at, updated_at
|
||||
FROM schools
|
||||
WHERE id = $1`
|
||||
|
||||
school := &models.School{}
|
||||
err := s.db.Pool.QueryRow(ctx, query, schoolID).Scan(
|
||||
&school.ID, &school.Name, &school.ShortName, &school.Type, &school.Address,
|
||||
&school.City, &school.PostalCode, &school.State, &school.Country, &school.Phone,
|
||||
&school.Email, &school.Website, &school.MatrixServerName, &school.LogoURL,
|
||||
&school.IsActive, &school.CreatedAt, &school.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get school: %w", err)
|
||||
}
|
||||
|
||||
return school, nil
|
||||
}
|
||||
|
||||
// ListSchools lists all active schools
|
||||
func (s *SchoolService) ListSchools(ctx context.Context) ([]models.School, error) {
|
||||
query := `
|
||||
SELECT id, name, short_name, type, address, city, postal_code, state, country, phone, email, website, matrix_server_name, logo_url, is_active, created_at, updated_at
|
||||
FROM schools
|
||||
WHERE is_active = true
|
||||
ORDER BY name`
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list schools: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var schools []models.School
|
||||
for rows.Next() {
|
||||
var school models.School
|
||||
err := rows.Scan(
|
||||
&school.ID, &school.Name, &school.ShortName, &school.Type, &school.Address,
|
||||
&school.City, &school.PostalCode, &school.State, &school.Country, &school.Phone,
|
||||
&school.Email, &school.Website, &school.MatrixServerName, &school.LogoURL,
|
||||
&school.IsActive, &school.CreatedAt, &school.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan school: %w", err)
|
||||
}
|
||||
schools = append(schools, school)
|
||||
}
|
||||
|
||||
return schools, nil
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// School Year Management
|
||||
// ========================================
|
||||
|
||||
// CreateSchoolYear creates a new school year
|
||||
func (s *SchoolService) CreateSchoolYear(ctx context.Context, schoolID uuid.UUID, name string, startDate, endDate time.Time) (*models.SchoolYear, error) {
|
||||
schoolYear := &models.SchoolYear{
|
||||
ID: uuid.New(),
|
||||
SchoolID: schoolID,
|
||||
Name: name,
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
IsCurrent: false,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO school_years (id, school_id, name, start_date, end_date, is_current, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING id`
|
||||
|
||||
err := s.db.Pool.QueryRow(ctx, query,
|
||||
schoolYear.ID, schoolYear.SchoolID, schoolYear.Name,
|
||||
schoolYear.StartDate, schoolYear.EndDate, schoolYear.IsCurrent, schoolYear.CreatedAt,
|
||||
).Scan(&schoolYear.ID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create school year: %w", err)
|
||||
}
|
||||
|
||||
return schoolYear, nil
|
||||
}
|
||||
|
||||
// SetCurrentSchoolYear sets a school year as the current one
|
||||
func (s *SchoolService) SetCurrentSchoolYear(ctx context.Context, schoolID, schoolYearID uuid.UUID) error {
|
||||
// First, unset all current school years for this school
|
||||
_, err := s.db.Pool.Exec(ctx, `UPDATE school_years SET is_current = false WHERE school_id = $1`, schoolID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unset current school years: %w", err)
|
||||
}
|
||||
|
||||
// Then set the specified school year as current
|
||||
_, err = s.db.Pool.Exec(ctx, `UPDATE school_years SET is_current = true WHERE id = $1 AND school_id = $2`, schoolYearID, schoolID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set current school year: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentSchoolYear gets the current school year for a school
|
||||
func (s *SchoolService) GetCurrentSchoolYear(ctx context.Context, schoolID uuid.UUID) (*models.SchoolYear, error) {
|
||||
query := `
|
||||
SELECT id, school_id, name, start_date, end_date, is_current, created_at
|
||||
FROM school_years
|
||||
WHERE school_id = $1 AND is_current = true`
|
||||
|
||||
schoolYear := &models.SchoolYear{}
|
||||
err := s.db.Pool.QueryRow(ctx, query, schoolID).Scan(
|
||||
&schoolYear.ID, &schoolYear.SchoolID, &schoolYear.Name,
|
||||
&schoolYear.StartDate, &schoolYear.EndDate, &schoolYear.IsCurrent, &schoolYear.CreatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get current school year: %w", err)
|
||||
}
|
||||
|
||||
return schoolYear, nil
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Class Management
|
||||
// ========================================
|
||||
|
||||
// CreateClass creates a new class
|
||||
func (s *SchoolService) CreateClass(ctx context.Context, schoolID uuid.UUID, req models.CreateClassRequest) (*models.Class, error) {
|
||||
schoolYearID, err := uuid.Parse(req.SchoolYearID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid school year ID: %w", err)
|
||||
}
|
||||
|
||||
class := &models.Class{
|
||||
ID: uuid.New(),
|
||||
SchoolID: schoolID,
|
||||
SchoolYearID: schoolYearID,
|
||||
Name: req.Name,
|
||||
Grade: req.Grade,
|
||||
Section: req.Section,
|
||||
Room: req.Room,
|
||||
IsActive: true,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO classes (id, school_id, school_year_id, name, grade, section, room, is_active, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
RETURNING id`
|
||||
|
||||
err = s.db.Pool.QueryRow(ctx, query,
|
||||
class.ID, class.SchoolID, class.SchoolYearID, class.Name,
|
||||
class.Grade, class.Section, class.Room, class.IsActive, class.CreatedAt, class.UpdatedAt,
|
||||
).Scan(&class.ID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create class: %w", err)
|
||||
}
|
||||
|
||||
return class, nil
|
||||
}
|
||||
|
||||
// GetClass retrieves a class by ID
|
||||
func (s *SchoolService) GetClass(ctx context.Context, classID uuid.UUID) (*models.Class, error) {
|
||||
query := `
|
||||
SELECT id, school_id, school_year_id, name, grade, section, room, matrix_info_room, matrix_rep_room, is_active, created_at, updated_at
|
||||
FROM classes
|
||||
WHERE id = $1`
|
||||
|
||||
class := &models.Class{}
|
||||
err := s.db.Pool.QueryRow(ctx, query, classID).Scan(
|
||||
&class.ID, &class.SchoolID, &class.SchoolYearID, &class.Name,
|
||||
&class.Grade, &class.Section, &class.Room, &class.MatrixInfoRoom,
|
||||
&class.MatrixRepRoom, &class.IsActive, &class.CreatedAt, &class.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get class: %w", err)
|
||||
}
|
||||
|
||||
return class, nil
|
||||
}
|
||||
|
||||
// ListClasses lists all classes for a school in a school year
|
||||
func (s *SchoolService) ListClasses(ctx context.Context, schoolID, schoolYearID uuid.UUID) ([]models.Class, error) {
|
||||
query := `
|
||||
SELECT id, school_id, school_year_id, name, grade, section, room, matrix_info_room, matrix_rep_room, is_active, created_at, updated_at
|
||||
FROM classes
|
||||
WHERE school_id = $1 AND school_year_id = $2 AND is_active = true
|
||||
ORDER BY grade, name`
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, query, schoolID, schoolYearID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list classes: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var classes []models.Class
|
||||
for rows.Next() {
|
||||
var class models.Class
|
||||
err := rows.Scan(
|
||||
&class.ID, &class.SchoolID, &class.SchoolYearID, &class.Name,
|
||||
&class.Grade, &class.Section, &class.Room, &class.MatrixInfoRoom,
|
||||
&class.MatrixRepRoom, &class.IsActive, &class.CreatedAt, &class.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan class: %w", err)
|
||||
}
|
||||
classes = append(classes, class)
|
||||
}
|
||||
|
||||
return classes, nil
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Student Management
|
||||
// ========================================
|
||||
|
||||
// CreateStudent creates a new student
|
||||
func (s *SchoolService) CreateStudent(ctx context.Context, schoolID uuid.UUID, req models.CreateStudentRequest) (*models.Student, error) {
|
||||
classID, err := uuid.Parse(req.ClassID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid class ID: %w", err)
|
||||
}
|
||||
|
||||
student := &models.Student{
|
||||
ID: uuid.New(),
|
||||
SchoolID: schoolID,
|
||||
ClassID: classID,
|
||||
StudentNumber: req.StudentNumber,
|
||||
FirstName: req.FirstName,
|
||||
LastName: req.LastName,
|
||||
Gender: req.Gender,
|
||||
IsActive: true,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if req.DateOfBirth != nil {
|
||||
dob, err := time.Parse("2006-01-02", *req.DateOfBirth)
|
||||
if err == nil {
|
||||
student.DateOfBirth = &dob
|
||||
}
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO students (id, school_id, class_id, student_number, first_name, last_name, date_of_birth, gender, is_active, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
RETURNING id`
|
||||
|
||||
err = s.db.Pool.QueryRow(ctx, query,
|
||||
student.ID, student.SchoolID, student.ClassID, student.StudentNumber,
|
||||
student.FirstName, student.LastName, student.DateOfBirth, student.Gender,
|
||||
student.IsActive, student.CreatedAt, student.UpdatedAt,
|
||||
).Scan(&student.ID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create student: %w", err)
|
||||
}
|
||||
|
||||
return student, nil
|
||||
}
|
||||
|
||||
// GetStudent retrieves a student by ID
|
||||
func (s *SchoolService) GetStudent(ctx context.Context, studentID uuid.UUID) (*models.Student, error) {
|
||||
query := `
|
||||
SELECT id, school_id, class_id, user_id, student_number, first_name, last_name, date_of_birth, gender, matrix_user_id, matrix_dm_room, is_active, created_at, updated_at
|
||||
FROM students
|
||||
WHERE id = $1`
|
||||
|
||||
student := &models.Student{}
|
||||
err := s.db.Pool.QueryRow(ctx, query, studentID).Scan(
|
||||
&student.ID, &student.SchoolID, &student.ClassID, &student.UserID,
|
||||
&student.StudentNumber, &student.FirstName, &student.LastName,
|
||||
&student.DateOfBirth, &student.Gender, &student.MatrixUserID,
|
||||
&student.MatrixDMRoom, &student.IsActive, &student.CreatedAt, &student.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get student: %w", err)
|
||||
}
|
||||
|
||||
return student, nil
|
||||
}
|
||||
|
||||
// ListStudentsByClass lists all students in a class
|
||||
func (s *SchoolService) ListStudentsByClass(ctx context.Context, classID uuid.UUID) ([]models.Student, error) {
|
||||
query := `
|
||||
SELECT id, school_id, class_id, user_id, student_number, first_name, last_name, date_of_birth, gender, matrix_user_id, matrix_dm_room, is_active, created_at, updated_at
|
||||
FROM students
|
||||
WHERE class_id = $1 AND is_active = true
|
||||
ORDER BY last_name, first_name`
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, query, classID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list students: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var students []models.Student
|
||||
for rows.Next() {
|
||||
var student models.Student
|
||||
err := rows.Scan(
|
||||
&student.ID, &student.SchoolID, &student.ClassID, &student.UserID,
|
||||
&student.StudentNumber, &student.FirstName, &student.LastName,
|
||||
&student.DateOfBirth, &student.Gender, &student.MatrixUserID,
|
||||
&student.MatrixDMRoom, &student.IsActive, &student.CreatedAt, &student.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan student: %w", err)
|
||||
}
|
||||
students = append(students, student)
|
||||
}
|
||||
|
||||
return students, nil
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Teacher Management
|
||||
// ========================================
|
||||
|
||||
// CreateTeacher creates a new teacher linked to a user account
|
||||
func (s *SchoolService) CreateTeacher(ctx context.Context, schoolID, userID uuid.UUID, firstName, lastName string, teacherCode, title *string) (*models.Teacher, error) {
|
||||
teacher := &models.Teacher{
|
||||
ID: uuid.New(),
|
||||
SchoolID: schoolID,
|
||||
UserID: userID,
|
||||
TeacherCode: teacherCode,
|
||||
Title: title,
|
||||
FirstName: firstName,
|
||||
LastName: lastName,
|
||||
IsActive: true,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO teachers (id, school_id, user_id, teacher_code, title, first_name, last_name, is_active, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
RETURNING id`
|
||||
|
||||
err := s.db.Pool.QueryRow(ctx, query,
|
||||
teacher.ID, teacher.SchoolID, teacher.UserID, teacher.TeacherCode,
|
||||
teacher.Title, teacher.FirstName, teacher.LastName,
|
||||
teacher.IsActive, teacher.CreatedAt, teacher.UpdatedAt,
|
||||
).Scan(&teacher.ID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create teacher: %w", err)
|
||||
}
|
||||
|
||||
return teacher, nil
|
||||
}
|
||||
|
||||
// GetTeacher retrieves a teacher by ID
|
||||
func (s *SchoolService) GetTeacher(ctx context.Context, teacherID uuid.UUID) (*models.Teacher, error) {
|
||||
query := `
|
||||
SELECT id, school_id, user_id, teacher_code, title, first_name, last_name, matrix_user_id, is_active, created_at, updated_at
|
||||
FROM teachers
|
||||
WHERE id = $1`
|
||||
|
||||
teacher := &models.Teacher{}
|
||||
err := s.db.Pool.QueryRow(ctx, query, teacherID).Scan(
|
||||
&teacher.ID, &teacher.SchoolID, &teacher.UserID, &teacher.TeacherCode,
|
||||
&teacher.Title, &teacher.FirstName, &teacher.LastName, &teacher.MatrixUserID,
|
||||
&teacher.IsActive, &teacher.CreatedAt, &teacher.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get teacher: %w", err)
|
||||
}
|
||||
|
||||
return teacher, nil
|
||||
}
|
||||
|
||||
// GetTeacherByUserID retrieves a teacher by their user ID
|
||||
func (s *SchoolService) GetTeacherByUserID(ctx context.Context, userID uuid.UUID) (*models.Teacher, error) {
|
||||
query := `
|
||||
SELECT id, school_id, user_id, teacher_code, title, first_name, last_name, matrix_user_id, is_active, created_at, updated_at
|
||||
FROM teachers
|
||||
WHERE user_id = $1 AND is_active = true`
|
||||
|
||||
teacher := &models.Teacher{}
|
||||
err := s.db.Pool.QueryRow(ctx, query, userID).Scan(
|
||||
&teacher.ID, &teacher.SchoolID, &teacher.UserID, &teacher.TeacherCode,
|
||||
&teacher.Title, &teacher.FirstName, &teacher.LastName, &teacher.MatrixUserID,
|
||||
&teacher.IsActive, &teacher.CreatedAt, &teacher.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get teacher by user ID: %w", err)
|
||||
}
|
||||
|
||||
return teacher, nil
|
||||
}
|
||||
|
||||
// AssignClassTeacher assigns a teacher to a class
|
||||
func (s *SchoolService) AssignClassTeacher(ctx context.Context, classID, teacherID uuid.UUID, isPrimary bool) error {
|
||||
query := `
|
||||
INSERT INTO class_teachers (id, class_id, teacher_id, is_primary, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (class_id, teacher_id) DO UPDATE SET is_primary = EXCLUDED.is_primary`
|
||||
|
||||
_, err := s.db.Pool.Exec(ctx, query, uuid.New(), classID, teacherID, isPrimary, time.Now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to assign class teacher: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Subject Management
|
||||
// ========================================
|
||||
|
||||
// CreateSubject creates a new subject
|
||||
func (s *SchoolService) CreateSubject(ctx context.Context, schoolID uuid.UUID, name, shortName string, color *string) (*models.Subject, error) {
|
||||
subject := &models.Subject{
|
||||
ID: uuid.New(),
|
||||
SchoolID: schoolID,
|
||||
Name: name,
|
||||
ShortName: shortName,
|
||||
Color: color,
|
||||
IsActive: true,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO subjects (id, school_id, name, short_name, color, is_active, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING id`
|
||||
|
||||
err := s.db.Pool.QueryRow(ctx, query,
|
||||
subject.ID, subject.SchoolID, subject.Name, subject.ShortName,
|
||||
subject.Color, subject.IsActive, subject.CreatedAt,
|
||||
).Scan(&subject.ID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create subject: %w", err)
|
||||
}
|
||||
|
||||
return subject, nil
|
||||
}
|
||||
|
||||
// ListSubjects lists all subjects for a school
|
||||
func (s *SchoolService) ListSubjects(ctx context.Context, schoolID uuid.UUID) ([]models.Subject, error) {
|
||||
query := `
|
||||
SELECT id, school_id, name, short_name, color, is_active, created_at
|
||||
FROM subjects
|
||||
WHERE school_id = $1 AND is_active = true
|
||||
ORDER BY name`
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, query, schoolID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list subjects: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var subjects []models.Subject
|
||||
for rows.Next() {
|
||||
var subject models.Subject
|
||||
err := rows.Scan(
|
||||
&subject.ID, &subject.SchoolID, &subject.Name, &subject.ShortName,
|
||||
&subject.Color, &subject.IsActive, &subject.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan subject: %w", err)
|
||||
}
|
||||
subjects = append(subjects, subject)
|
||||
}
|
||||
|
||||
return subjects, nil
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Parent Onboarding
|
||||
// ========================================
|
||||
|
||||
// GenerateParentOnboardingToken generates a QR code token for parent onboarding
|
||||
func (s *SchoolService) GenerateParentOnboardingToken(ctx context.Context, schoolID, classID, studentID, createdByUserID uuid.UUID, role string) (*models.ParentOnboardingToken, error) {
|
||||
// Generate secure random token
|
||||
tokenBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(tokenBytes); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate token: %w", err)
|
||||
}
|
||||
token := hex.EncodeToString(tokenBytes)
|
||||
|
||||
onboardingToken := &models.ParentOnboardingToken{
|
||||
ID: uuid.New(),
|
||||
SchoolID: schoolID,
|
||||
ClassID: classID,
|
||||
StudentID: studentID,
|
||||
Token: token,
|
||||
Role: role,
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour), // Valid for 72 hours
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: createdByUserID,
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO parent_onboarding_tokens (id, school_id, class_id, student_id, token, role, expires_at, created_at, created_by)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
RETURNING id`
|
||||
|
||||
err := s.db.Pool.QueryRow(ctx, query,
|
||||
onboardingToken.ID, onboardingToken.SchoolID, onboardingToken.ClassID,
|
||||
onboardingToken.StudentID, onboardingToken.Token, onboardingToken.Role,
|
||||
onboardingToken.ExpiresAt, onboardingToken.CreatedAt, onboardingToken.CreatedBy,
|
||||
).Scan(&onboardingToken.ID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create onboarding token: %w", err)
|
||||
}
|
||||
|
||||
return onboardingToken, nil
|
||||
}
|
||||
|
||||
// ValidateOnboardingToken validates and retrieves info for an onboarding token
|
||||
func (s *SchoolService) ValidateOnboardingToken(ctx context.Context, token string) (*models.ParentOnboardingToken, error) {
|
||||
query := `
|
||||
SELECT id, school_id, class_id, student_id, token, role, expires_at, used_at, used_by_user_id, created_at, created_by
|
||||
FROM parent_onboarding_tokens
|
||||
WHERE token = $1 AND used_at IS NULL AND expires_at > NOW()`
|
||||
|
||||
onboardingToken := &models.ParentOnboardingToken{}
|
||||
err := s.db.Pool.QueryRow(ctx, query, token).Scan(
|
||||
&onboardingToken.ID, &onboardingToken.SchoolID, &onboardingToken.ClassID,
|
||||
&onboardingToken.StudentID, &onboardingToken.Token, &onboardingToken.Role,
|
||||
&onboardingToken.ExpiresAt, &onboardingToken.UsedAt, &onboardingToken.UsedByUserID,
|
||||
&onboardingToken.CreatedAt, &onboardingToken.CreatedBy,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid or expired token: %w", err)
|
||||
}
|
||||
|
||||
return onboardingToken, nil
|
||||
}
|
||||
|
||||
// RedeemOnboardingToken marks a token as used and creates the parent account
|
||||
func (s *SchoolService) RedeemOnboardingToken(ctx context.Context, token string, userID uuid.UUID) error {
|
||||
query := `
|
||||
UPDATE parent_onboarding_tokens
|
||||
SET used_at = NOW(), used_by_user_id = $1
|
||||
WHERE token = $2 AND used_at IS NULL AND expires_at > NOW()`
|
||||
|
||||
result, err := s.db.Pool.Exec(ctx, query, userID, token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to redeem token: %w", err)
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
return fmt.Errorf("token not found or already used")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Helper Functions
|
||||
// ========================================
|
||||
|
||||
func (s *SchoolService) createDefaultTimetableSlots(ctx context.Context, schoolID uuid.UUID) error {
|
||||
slots := []struct {
|
||||
Number int
|
||||
StartTime string
|
||||
EndTime string
|
||||
IsBreak bool
|
||||
Name string
|
||||
}{
|
||||
{1, "08:00", "08:45", false, "1. Stunde"},
|
||||
{2, "08:45", "09:30", false, "2. Stunde"},
|
||||
{3, "09:30", "09:50", true, "Erste Pause"},
|
||||
{4, "09:50", "10:35", false, "3. Stunde"},
|
||||
{5, "10:35", "11:20", false, "4. Stunde"},
|
||||
{6, "11:20", "11:40", true, "Zweite Pause"},
|
||||
{7, "11:40", "12:25", false, "5. Stunde"},
|
||||
{8, "12:25", "13:10", false, "6. Stunde"},
|
||||
{9, "13:10", "14:00", true, "Mittagspause"},
|
||||
{10, "14:00", "14:45", false, "7. Stunde"},
|
||||
{11, "14:45", "15:30", false, "8. Stunde"},
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO timetable_slots (id, school_id, slot_number, start_time, end_time, is_break, name)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT (school_id, slot_number) DO NOTHING`
|
||||
|
||||
for _, slot := range slots {
|
||||
_, err := s.db.Pool.Exec(ctx, query,
|
||||
uuid.New(), schoolID, slot.Number, slot.StartTime, slot.EndTime, slot.IsBreak, slot.Name,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SchoolService) createDefaultGradeScale(ctx context.Context, schoolID uuid.UUID) error {
|
||||
query := `
|
||||
INSERT INTO grade_scales (id, school_id, name, min_value, max_value, passing_value, is_ascending, is_default, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
ON CONFLICT DO NOTHING`
|
||||
|
||||
_, err := s.db.Pool.Exec(ctx, query,
|
||||
uuid.New(), schoolID, "1-6 (Noten)", 1.0, 6.0, 4.0, false, true, time.Now(),
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
424
consent-service/internal/services/school_service_test.go
Normal file
424
consent-service/internal/services/school_service_test.go
Normal file
@@ -0,0 +1,424 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestGenerateOnboardingToken tests the QR code token generation
|
||||
func TestGenerateOnboardingToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
studentID uuid.UUID
|
||||
createdBy uuid.UUID
|
||||
role string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid parent token",
|
||||
studentID: uuid.New(),
|
||||
createdBy: uuid.New(),
|
||||
role: "parent",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid parent_representative token",
|
||||
studentID: uuid.New(),
|
||||
createdBy: uuid.New(),
|
||||
role: "parent_representative",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty student ID",
|
||||
studentID: uuid.Nil,
|
||||
createdBy: uuid.New(),
|
||||
role: "parent",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token := &models.ParentOnboardingToken{
|
||||
StudentID: tt.studentID,
|
||||
CreatedBy: tt.createdBy,
|
||||
Role: tt.role,
|
||||
ExpiresAt: time.Now().Add(7 * 24 * time.Hour),
|
||||
}
|
||||
|
||||
// Validate token fields
|
||||
if tt.studentID == uuid.Nil && !tt.expectError {
|
||||
t.Errorf("expected error for empty student ID")
|
||||
}
|
||||
|
||||
if token.Role != "parent" && token.Role != "parent_representative" {
|
||||
if !tt.expectError {
|
||||
t.Errorf("invalid role: %s", token.Role)
|
||||
}
|
||||
}
|
||||
|
||||
// Check expiration is in the future
|
||||
if token.ExpiresAt.Before(time.Now()) {
|
||||
t.Errorf("token expiration should be in the future")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateSchoolData tests school data validation
|
||||
func TestValidateSchoolData(t *testing.T) {
|
||||
address1 := "Musterstraße 1, 20095 Hamburg"
|
||||
address2 := "Musterstraße 1"
|
||||
address3 := "Parkweg 5"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
school models.School
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "valid school",
|
||||
school: models.School{
|
||||
Name: "Testschule Hamburg",
|
||||
Address: &address1,
|
||||
Type: "gymnasium",
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "empty name",
|
||||
school: models.School{
|
||||
Name: "",
|
||||
Address: &address2,
|
||||
Type: "gymnasium",
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "valid grundschule",
|
||||
school: models.School{
|
||||
Name: "Grundschule Am Park",
|
||||
Address: &address3,
|
||||
Type: "grundschule",
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := validateSchool(tt.school)
|
||||
if isValid != tt.expectValid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.expectValid, isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// validateSchool is a helper function for validation
|
||||
func validateSchool(school models.School) bool {
|
||||
if school.Name == "" {
|
||||
return false
|
||||
}
|
||||
if school.Type == "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// TestValidateClassData tests class data validation
|
||||
func TestValidateClassData(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
class models.Class
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "valid class 5a",
|
||||
class: models.Class{
|
||||
Name: "5a",
|
||||
Grade: 5,
|
||||
SchoolID: uuid.New(),
|
||||
SchoolYearID: uuid.New(),
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "invalid grade level 0",
|
||||
class: models.Class{
|
||||
Name: "0a",
|
||||
Grade: 0,
|
||||
SchoolID: uuid.New(),
|
||||
SchoolYearID: uuid.New(),
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "invalid grade level 14",
|
||||
class: models.Class{
|
||||
Name: "14a",
|
||||
Grade: 14,
|
||||
SchoolID: uuid.New(),
|
||||
SchoolYearID: uuid.New(),
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "missing school ID",
|
||||
class: models.Class{
|
||||
Name: "5a",
|
||||
Grade: 5,
|
||||
SchoolID: uuid.Nil,
|
||||
SchoolYearID: uuid.New(),
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := validateClass(tt.class)
|
||||
if isValid != tt.expectValid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.expectValid, isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// validateClass is a helper function for class validation
|
||||
func validateClass(class models.Class) bool {
|
||||
if class.Name == "" {
|
||||
return false
|
||||
}
|
||||
if class.Grade < 1 || class.Grade > 13 {
|
||||
return false
|
||||
}
|
||||
if class.SchoolID == uuid.Nil {
|
||||
return false
|
||||
}
|
||||
if class.SchoolYearID == uuid.Nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// TestValidateStudentData tests student data validation
|
||||
func TestValidateStudentData(t *testing.T) {
|
||||
dob := time.Date(2014, 5, 15, 0, 0, 0, 0, time.UTC)
|
||||
futureDob := time.Now().AddDate(1, 0, 0)
|
||||
studentNum := "2024-001"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
student models.Student
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "valid student",
|
||||
student: models.Student{
|
||||
FirstName: "Max",
|
||||
LastName: "Mustermann",
|
||||
DateOfBirth: &dob,
|
||||
SchoolID: uuid.New(),
|
||||
ClassID: uuid.New(),
|
||||
StudentNumber: &studentNum,
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "empty first name",
|
||||
student: models.Student{
|
||||
FirstName: "",
|
||||
LastName: "Mustermann",
|
||||
DateOfBirth: &dob,
|
||||
SchoolID: uuid.New(),
|
||||
ClassID: uuid.New(),
|
||||
StudentNumber: &studentNum,
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "future birth date",
|
||||
student: models.Student{
|
||||
FirstName: "Max",
|
||||
LastName: "Mustermann",
|
||||
DateOfBirth: &futureDob,
|
||||
SchoolID: uuid.New(),
|
||||
ClassID: uuid.New(),
|
||||
StudentNumber: &studentNum,
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "missing class ID",
|
||||
student: models.Student{
|
||||
FirstName: "Max",
|
||||
LastName: "Mustermann",
|
||||
DateOfBirth: &dob,
|
||||
SchoolID: uuid.New(),
|
||||
ClassID: uuid.Nil,
|
||||
StudentNumber: &studentNum,
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := validateStudent(tt.student)
|
||||
if isValid != tt.expectValid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.expectValid, isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// validateStudent is a helper function for student validation
|
||||
func validateStudent(student models.Student) bool {
|
||||
if student.FirstName == "" || student.LastName == "" {
|
||||
return false
|
||||
}
|
||||
if student.DateOfBirth != nil && student.DateOfBirth.After(time.Now()) {
|
||||
return false
|
||||
}
|
||||
if student.ClassID == uuid.Nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// TestValidateTeacherData tests teacher data validation
|
||||
func TestValidateTeacherData(t *testing.T) {
|
||||
code := "SCH"
|
||||
codeLong := "SCHMI"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
teacher models.Teacher
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "valid teacher",
|
||||
teacher: models.Teacher{
|
||||
FirstName: "Anna",
|
||||
LastName: "Schmidt",
|
||||
UserID: uuid.New(),
|
||||
TeacherCode: &code,
|
||||
SchoolID: uuid.New(),
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "empty first name",
|
||||
teacher: models.Teacher{
|
||||
FirstName: "",
|
||||
LastName: "Schmidt",
|
||||
UserID: uuid.New(),
|
||||
TeacherCode: &code,
|
||||
SchoolID: uuid.New(),
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "teacher code too long",
|
||||
teacher: models.Teacher{
|
||||
FirstName: "Anna",
|
||||
LastName: "Schmidt",
|
||||
UserID: uuid.New(),
|
||||
TeacherCode: &codeLong,
|
||||
SchoolID: uuid.New(),
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := validateTeacher(tt.teacher)
|
||||
if isValid != tt.expectValid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.expectValid, isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// validateTeacher is a helper function for teacher validation
|
||||
func validateTeacher(teacher models.Teacher) bool {
|
||||
if teacher.FirstName == "" || teacher.LastName == "" {
|
||||
return false
|
||||
}
|
||||
if teacher.TeacherCode != nil && len(*teacher.TeacherCode) > 4 {
|
||||
return false
|
||||
}
|
||||
if teacher.SchoolID == uuid.Nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// TestSchoolYearValidation tests school year date validation
|
||||
func TestSchoolYearValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
year models.SchoolYear
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "valid school year 2024/2025",
|
||||
year: models.SchoolYear{
|
||||
Name: "2024/2025",
|
||||
StartDate: time.Date(2024, 8, 1, 0, 0, 0, 0, time.UTC),
|
||||
EndDate: time.Date(2025, 7, 31, 0, 0, 0, 0, time.UTC),
|
||||
SchoolID: uuid.New(),
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "end before start",
|
||||
year: models.SchoolYear{
|
||||
Name: "2024/2025",
|
||||
StartDate: time.Date(2025, 8, 1, 0, 0, 0, 0, time.UTC),
|
||||
EndDate: time.Date(2024, 7, 31, 0, 0, 0, 0, time.UTC),
|
||||
SchoolID: uuid.New(),
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "same start and end",
|
||||
year: models.SchoolYear{
|
||||
Name: "2024/2025",
|
||||
StartDate: time.Date(2024, 8, 1, 0, 0, 0, 0, time.UTC),
|
||||
EndDate: time.Date(2024, 8, 1, 0, 0, 0, 0, time.UTC),
|
||||
SchoolID: uuid.New(),
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := validateSchoolYear(tt.year)
|
||||
if isValid != tt.expectValid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.expectValid, isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// validateSchoolYear is a helper function for school year validation
|
||||
func validateSchoolYear(year models.SchoolYear) bool {
|
||||
if year.Name == "" {
|
||||
return false
|
||||
}
|
||||
if year.SchoolID == uuid.Nil {
|
||||
return false
|
||||
}
|
||||
if !year.EndDate.After(year.StartDate) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
15
consent-service/internal/services/test_helpers.go
Normal file
15
consent-service/internal/services/test_helpers.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package services
|
||||
|
||||
// ValidationError represents a validation error in tests
|
||||
// This is a shared test helper type used across multiple test files
|
||||
type ValidationError struct {
|
||||
Field string
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *ValidationError) Error() string {
|
||||
if e.Field != "" {
|
||||
return e.Field + ": " + e.Message
|
||||
}
|
||||
return e.Message
|
||||
}
|
||||
485
consent-service/internal/services/totp_service.go
Normal file
485
consent-service/internal/services/totp_service.go
Normal file
@@ -0,0 +1,485 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"encoding/base32"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"image/png"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
qrcode "github.com/skip2/go-qrcode"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTOTPNotEnabled = errors.New("2FA is not enabled for this user")
|
||||
ErrTOTPAlreadyEnabled = errors.New("2FA is already enabled for this user")
|
||||
ErrTOTPInvalidCode = errors.New("invalid 2FA code")
|
||||
ErrTOTPChallengeExpired = errors.New("2FA challenge expired")
|
||||
ErrRecoveryCodeInvalid = errors.New("invalid recovery code")
|
||||
ErrRecoveryCodeUsed = errors.New("recovery code already used")
|
||||
)
|
||||
|
||||
const (
|
||||
TOTPPeriod = 30 // TOTP period in seconds
|
||||
TOTPDigits = 6 // Number of digits in TOTP code
|
||||
TOTPSecretLen = 20 // Length of TOTP secret in bytes
|
||||
RecoveryCodeCount = 10 // Number of recovery codes to generate
|
||||
RecoveryCodeLen = 8 // Length of each recovery code
|
||||
ChallengeExpiry = 5 * time.Minute // 2FA challenge expiry
|
||||
)
|
||||
|
||||
// TOTPService handles Two-Factor Authentication using TOTP
|
||||
type TOTPService struct {
|
||||
db *pgxpool.Pool
|
||||
issuer string
|
||||
}
|
||||
|
||||
// NewTOTPService creates a new TOTPService
|
||||
func NewTOTPService(db *pgxpool.Pool, issuer string) *TOTPService {
|
||||
return &TOTPService{
|
||||
db: db,
|
||||
issuer: issuer,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateSecret generates a new TOTP secret
|
||||
func (s *TOTPService) GenerateSecret() (string, error) {
|
||||
secret := make([]byte, TOTPSecretLen)
|
||||
if _, err := rand.Read(secret); err != nil {
|
||||
return "", fmt.Errorf("failed to generate secret: %w", err)
|
||||
}
|
||||
return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(secret), nil
|
||||
}
|
||||
|
||||
// GenerateRecoveryCodes generates a set of recovery codes
|
||||
func (s *TOTPService) GenerateRecoveryCodes() ([]string, error) {
|
||||
codes := make([]string, RecoveryCodeCount)
|
||||
for i := 0; i < RecoveryCodeCount; i++ {
|
||||
codeBytes := make([]byte, RecoveryCodeLen/2)
|
||||
if _, err := rand.Read(codeBytes); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate recovery code: %w", err)
|
||||
}
|
||||
codes[i] = strings.ToUpper(hex.EncodeToString(codeBytes))
|
||||
}
|
||||
return codes, nil
|
||||
}
|
||||
|
||||
// GenerateQRCode generates a QR code for TOTP setup
|
||||
func (s *TOTPService) GenerateQRCode(secret, email string) (string, error) {
|
||||
// Create otpauth URL
|
||||
otpauthURL := fmt.Sprintf("otpauth://totp/%s:%s?secret=%s&issuer=%s&algorithm=SHA1&digits=%d&period=%d",
|
||||
s.issuer, email, secret, s.issuer, TOTPDigits, TOTPPeriod)
|
||||
|
||||
// Generate QR code
|
||||
qr, err := qrcode.New(otpauthURL, qrcode.Medium)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate QR code: %w", err)
|
||||
}
|
||||
|
||||
// Convert to PNG
|
||||
var buf bytes.Buffer
|
||||
if err := png.Encode(&buf, qr.Image(256)); err != nil {
|
||||
return "", fmt.Errorf("failed to encode QR code: %w", err)
|
||||
}
|
||||
|
||||
// Convert to data URL
|
||||
dataURL := fmt.Sprintf("data:image/png;base64,%s", base64.StdEncoding.EncodeToString(buf.Bytes()))
|
||||
return dataURL, nil
|
||||
}
|
||||
|
||||
// GenerateTOTP generates the current TOTP code for a secret
|
||||
func (s *TOTPService) GenerateTOTP(secret string) (string, error) {
|
||||
return s.GenerateTOTPAt(secret, time.Now())
|
||||
}
|
||||
|
||||
// GenerateTOTPAt generates a TOTP code for a specific time
|
||||
func (s *TOTPService) GenerateTOTPAt(secret string, t time.Time) (string, error) {
|
||||
// Decode secret
|
||||
secretBytes, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(strings.ToUpper(secret))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid secret: %w", err)
|
||||
}
|
||||
|
||||
// Calculate counter
|
||||
counter := uint64(t.Unix()) / TOTPPeriod
|
||||
|
||||
// Generate HOTP
|
||||
buf := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(buf, counter)
|
||||
|
||||
mac := hmac.New(sha1.New, secretBytes)
|
||||
mac.Write(buf)
|
||||
hash := mac.Sum(nil)
|
||||
|
||||
// Dynamic truncation
|
||||
offset := hash[len(hash)-1] & 0x0f
|
||||
code := binary.BigEndian.Uint32(hash[offset:offset+4]) & 0x7fffffff
|
||||
|
||||
// Format code
|
||||
codeStr := fmt.Sprintf("%0*d", TOTPDigits, code%1000000)
|
||||
return codeStr, nil
|
||||
}
|
||||
|
||||
// ValidateTOTP validates a TOTP code (allows 1 period drift)
|
||||
func (s *TOTPService) ValidateTOTP(secret, code string) bool {
|
||||
now := time.Now()
|
||||
|
||||
// Check current, previous, and next period
|
||||
for _, offset := range []int{0, -1, 1} {
|
||||
t := now.Add(time.Duration(offset*TOTPPeriod) * time.Second)
|
||||
expected, err := s.GenerateTOTPAt(secret, t)
|
||||
if err == nil && expected == code {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Setup2FA initiates 2FA setup for a user
|
||||
func (s *TOTPService) Setup2FA(ctx context.Context, userID uuid.UUID, email string) (*models.Setup2FAResponse, error) {
|
||||
// Check if 2FA is already enabled
|
||||
var exists bool
|
||||
err := s.db.QueryRow(ctx, `SELECT EXISTS(SELECT 1 FROM user_totp WHERE user_id = $1 AND verified = true)`, userID).Scan(&exists)
|
||||
if err == nil && exists {
|
||||
return nil, ErrTOTPAlreadyEnabled
|
||||
}
|
||||
|
||||
// Generate secret
|
||||
secret, err := s.GenerateSecret()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Generate recovery codes
|
||||
recoveryCodes, err := s.GenerateRecoveryCodes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Generate QR code
|
||||
qrCode, err := s.GenerateQRCode(secret, email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Hash recovery codes for storage
|
||||
hashedCodes := make([]string, len(recoveryCodes))
|
||||
for i, code := range recoveryCodes {
|
||||
hash := sha256.Sum256([]byte(code))
|
||||
hashedCodes[i] = hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
recoveryCodesJSON, _ := json.Marshal(hashedCodes)
|
||||
|
||||
// Store or update TOTP record (unverified)
|
||||
_, err = s.db.Exec(ctx, `
|
||||
INSERT INTO user_totp (user_id, secret, verified, recovery_codes, created_at, updated_at)
|
||||
VALUES ($1, $2, false, $3, NOW(), NOW())
|
||||
ON CONFLICT (user_id) DO UPDATE SET
|
||||
secret = $2,
|
||||
verified = false,
|
||||
recovery_codes = $3,
|
||||
updated_at = NOW()
|
||||
`, userID, secret, recoveryCodesJSON)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to store TOTP: %w", err)
|
||||
}
|
||||
|
||||
return &models.Setup2FAResponse{
|
||||
Secret: secret,
|
||||
QRCodeDataURL: qrCode,
|
||||
RecoveryCodes: recoveryCodes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Verify2FASetup verifies the 2FA setup with a code
|
||||
func (s *TOTPService) Verify2FASetup(ctx context.Context, userID uuid.UUID, code string) error {
|
||||
// Get TOTP record
|
||||
var secret string
|
||||
var verified bool
|
||||
err := s.db.QueryRow(ctx, `SELECT secret, verified FROM user_totp WHERE user_id = $1`, userID).Scan(&secret, &verified)
|
||||
if err != nil {
|
||||
return ErrTOTPNotEnabled
|
||||
}
|
||||
|
||||
if verified {
|
||||
return ErrTOTPAlreadyEnabled
|
||||
}
|
||||
|
||||
// Validate code
|
||||
if !s.ValidateTOTP(secret, code) {
|
||||
return ErrTOTPInvalidCode
|
||||
}
|
||||
|
||||
// Mark as verified and enable 2FA
|
||||
_, err = s.db.Exec(ctx, `
|
||||
UPDATE user_totp SET verified = true, enabled_at = NOW(), updated_at = NOW() WHERE user_id = $1
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to verify TOTP: %w", err)
|
||||
}
|
||||
|
||||
// Update user record
|
||||
_, err = s.db.Exec(ctx, `
|
||||
UPDATE users SET two_factor_enabled = true, two_factor_verified_at = NOW(), updated_at = NOW() WHERE id = $1
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateChallenge creates a 2FA challenge for login
|
||||
func (s *TOTPService) CreateChallenge(ctx context.Context, userID uuid.UUID, ipAddress, userAgent string) (string, error) {
|
||||
// Generate challenge ID
|
||||
challengeBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(challengeBytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate challenge: %w", err)
|
||||
}
|
||||
challengeID := base64.URLEncoding.EncodeToString(challengeBytes)
|
||||
|
||||
// Store challenge
|
||||
_, err := s.db.Exec(ctx, `
|
||||
INSERT INTO two_factor_challenges (user_id, challenge_id, ip_address, user_agent, expires_at, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, NOW())
|
||||
`, userID, challengeID, ipAddress, userAgent, time.Now().Add(ChallengeExpiry))
|
||||
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create challenge: %w", err)
|
||||
}
|
||||
|
||||
return challengeID, nil
|
||||
}
|
||||
|
||||
// VerifyChallenge verifies a 2FA challenge with a TOTP code
|
||||
func (s *TOTPService) VerifyChallenge(ctx context.Context, challengeID, code string) (*uuid.UUID, error) {
|
||||
var challenge models.TwoFactorChallenge
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id, user_id, expires_at, used_at FROM two_factor_challenges WHERE challenge_id = $1
|
||||
`, challengeID).Scan(&challenge.ID, &challenge.UserID, &challenge.ExpiresAt, &challenge.UsedAt)
|
||||
|
||||
if err != nil {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
if challenge.UsedAt != nil {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
if time.Now().After(challenge.ExpiresAt) {
|
||||
return nil, ErrTOTPChallengeExpired
|
||||
}
|
||||
|
||||
// Get TOTP secret
|
||||
var secret string
|
||||
err = s.db.QueryRow(ctx, `SELECT secret FROM user_totp WHERE user_id = $1 AND verified = true`, challenge.UserID).Scan(&secret)
|
||||
if err != nil {
|
||||
return nil, ErrTOTPNotEnabled
|
||||
}
|
||||
|
||||
// Validate TOTP code
|
||||
if !s.ValidateTOTP(secret, code) {
|
||||
return nil, ErrTOTPInvalidCode
|
||||
}
|
||||
|
||||
// Mark challenge as used
|
||||
_, err = s.db.Exec(ctx, `UPDATE two_factor_challenges SET used_at = NOW() WHERE id = $1`, challenge.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to mark challenge as used: %w", err)
|
||||
}
|
||||
|
||||
// Update last used time
|
||||
_, _ = s.db.Exec(ctx, `UPDATE user_totp SET last_used_at = NOW() WHERE user_id = $1`, challenge.UserID)
|
||||
|
||||
return &challenge.UserID, nil
|
||||
}
|
||||
|
||||
// VerifyChallengeWithRecoveryCode verifies a 2FA challenge with a recovery code
|
||||
func (s *TOTPService) VerifyChallengeWithRecoveryCode(ctx context.Context, challengeID, recoveryCode string) (*uuid.UUID, error) {
|
||||
var challenge models.TwoFactorChallenge
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id, user_id, expires_at, used_at FROM two_factor_challenges WHERE challenge_id = $1
|
||||
`, challengeID).Scan(&challenge.ID, &challenge.UserID, &challenge.ExpiresAt, &challenge.UsedAt)
|
||||
|
||||
if err != nil {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
if challenge.UsedAt != nil {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
if time.Now().After(challenge.ExpiresAt) {
|
||||
return nil, ErrTOTPChallengeExpired
|
||||
}
|
||||
|
||||
// Get recovery codes
|
||||
var recoveryCodesJSON []byte
|
||||
err = s.db.QueryRow(ctx, `SELECT recovery_codes FROM user_totp WHERE user_id = $1 AND verified = true`, challenge.UserID).Scan(&recoveryCodesJSON)
|
||||
if err != nil {
|
||||
return nil, ErrTOTPNotEnabled
|
||||
}
|
||||
|
||||
var hashedCodes []string
|
||||
json.Unmarshal(recoveryCodesJSON, &hashedCodes)
|
||||
|
||||
// Hash the provided recovery code
|
||||
codeHash := sha256.Sum256([]byte(strings.ToUpper(recoveryCode)))
|
||||
codeHashStr := hex.EncodeToString(codeHash[:])
|
||||
|
||||
// Find and remove the recovery code
|
||||
found := false
|
||||
newCodes := make([]string, 0, len(hashedCodes)-1)
|
||||
for _, hc := range hashedCodes {
|
||||
if hc == codeHashStr && !found {
|
||||
found = true
|
||||
continue
|
||||
}
|
||||
newCodes = append(newCodes, hc)
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil, ErrRecoveryCodeInvalid
|
||||
}
|
||||
|
||||
// Update recovery codes
|
||||
newCodesJSON, _ := json.Marshal(newCodes)
|
||||
_, err = s.db.Exec(ctx, `UPDATE user_totp SET recovery_codes = $1, updated_at = NOW() WHERE user_id = $2`, newCodesJSON, challenge.UserID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update recovery codes: %w", err)
|
||||
}
|
||||
|
||||
// Mark challenge as used
|
||||
_, err = s.db.Exec(ctx, `UPDATE two_factor_challenges SET used_at = NOW() WHERE id = $1`, challenge.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to mark challenge as used: %w", err)
|
||||
}
|
||||
|
||||
return &challenge.UserID, nil
|
||||
}
|
||||
|
||||
// Disable2FA disables 2FA for a user
|
||||
func (s *TOTPService) Disable2FA(ctx context.Context, userID uuid.UUID, code string) error {
|
||||
// Get TOTP secret
|
||||
var secret string
|
||||
err := s.db.QueryRow(ctx, `SELECT secret FROM user_totp WHERE user_id = $1 AND verified = true`, userID).Scan(&secret)
|
||||
if err != nil {
|
||||
return ErrTOTPNotEnabled
|
||||
}
|
||||
|
||||
// Validate code
|
||||
if !s.ValidateTOTP(secret, code) {
|
||||
return ErrTOTPInvalidCode
|
||||
}
|
||||
|
||||
// Delete TOTP record
|
||||
_, err = s.db.Exec(ctx, `DELETE FROM user_totp WHERE user_id = $1`, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete TOTP: %w", err)
|
||||
}
|
||||
|
||||
// Update user record
|
||||
_, err = s.db.Exec(ctx, `
|
||||
UPDATE users SET two_factor_enabled = false, two_factor_verified_at = NULL, updated_at = NOW() WHERE id = $1
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStatus returns the 2FA status for a user
|
||||
func (s *TOTPService) GetStatus(ctx context.Context, userID uuid.UUID) (*models.TwoFactorStatusResponse, error) {
|
||||
var totp models.UserTOTP
|
||||
var recoveryCodesJSON []byte
|
||||
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id, verified, enabled_at, recovery_codes FROM user_totp WHERE user_id = $1
|
||||
`, userID).Scan(&totp.ID, &totp.Verified, &totp.EnabledAt, &recoveryCodesJSON)
|
||||
|
||||
if err != nil {
|
||||
// 2FA not set up
|
||||
return &models.TwoFactorStatusResponse{
|
||||
Enabled: false,
|
||||
Verified: false,
|
||||
RecoveryCodesCount: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var hashedCodes []string
|
||||
json.Unmarshal(recoveryCodesJSON, &hashedCodes)
|
||||
|
||||
return &models.TwoFactorStatusResponse{
|
||||
Enabled: totp.Verified,
|
||||
Verified: totp.Verified,
|
||||
EnabledAt: totp.EnabledAt,
|
||||
RecoveryCodesCount: len(hashedCodes),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RegenerateRecoveryCodes generates new recovery codes (requires current TOTP code)
|
||||
func (s *TOTPService) RegenerateRecoveryCodes(ctx context.Context, userID uuid.UUID, code string) ([]string, error) {
|
||||
// Get TOTP secret
|
||||
var secret string
|
||||
err := s.db.QueryRow(ctx, `SELECT secret FROM user_totp WHERE user_id = $1 AND verified = true`, userID).Scan(&secret)
|
||||
if err != nil {
|
||||
return nil, ErrTOTPNotEnabled
|
||||
}
|
||||
|
||||
// Validate code
|
||||
if !s.ValidateTOTP(secret, code) {
|
||||
return nil, ErrTOTPInvalidCode
|
||||
}
|
||||
|
||||
// Generate new recovery codes
|
||||
recoveryCodes, err := s.GenerateRecoveryCodes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Hash recovery codes for storage
|
||||
hashedCodes := make([]string, len(recoveryCodes))
|
||||
for i, rc := range recoveryCodes {
|
||||
hash := sha256.Sum256([]byte(rc))
|
||||
hashedCodes[i] = hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
recoveryCodesJSON, _ := json.Marshal(hashedCodes)
|
||||
|
||||
// Update recovery codes
|
||||
_, err = s.db.Exec(ctx, `UPDATE user_totp SET recovery_codes = $1, updated_at = NOW() WHERE user_id = $2`, recoveryCodesJSON, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update recovery codes: %w", err)
|
||||
}
|
||||
|
||||
return recoveryCodes, nil
|
||||
}
|
||||
|
||||
// IsTwoFactorEnabled checks if 2FA is enabled for a user
|
||||
func (s *TOTPService) IsTwoFactorEnabled(ctx context.Context, userID uuid.UUID) (bool, error) {
|
||||
var enabled bool
|
||||
err := s.db.QueryRow(ctx, `SELECT two_factor_enabled FROM users WHERE id = $1`, userID).Scan(&enabled)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return enabled, nil
|
||||
}
|
||||
378
consent-service/internal/services/totp_service_test.go
Normal file
378
consent-service/internal/services/totp_service_test.go
Normal file
@@ -0,0 +1,378 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"encoding/base32"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestTOTPGeneration tests TOTP code generation
|
||||
func TestTOTPGeneration_ValidSecret(t *testing.T) {
|
||||
// Test secret (Base32 encoded)
|
||||
secret := "JBSWY3DPEHPK3PXP" // This is "Hello!" in Base32
|
||||
|
||||
// Decode secret
|
||||
secretBytes, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(secret)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decode secret: %v", err)
|
||||
}
|
||||
|
||||
// Generate TOTP for current time
|
||||
now := time.Now()
|
||||
counter := uint64(now.Unix()) / 30
|
||||
|
||||
buf := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(buf, counter)
|
||||
|
||||
mac := hmac.New(sha1.New, secretBytes)
|
||||
mac.Write(buf)
|
||||
hash := mac.Sum(nil)
|
||||
|
||||
// Dynamic truncation
|
||||
offset := hash[len(hash)-1] & 0x0f
|
||||
code := binary.BigEndian.Uint32(hash[offset:offset+4]) & 0x7fffffff
|
||||
totpCode := code % 1000000
|
||||
|
||||
// Check that code is 6 digits
|
||||
if totpCode < 0 || totpCode > 999999 {
|
||||
t.Errorf("TOTP code should be 6 digits, got %d", totpCode)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTOTPGeneration_SameTimeProducesSameCode tests deterministic generation
|
||||
func TestTOTPGeneration_Deterministic(t *testing.T) {
|
||||
secret := "JBSWY3DPEHPK3PXP"
|
||||
secretBytes, _ := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(secret)
|
||||
|
||||
fixedTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
code1 := generateTOTPAt(secretBytes, fixedTime)
|
||||
code2 := generateTOTPAt(secretBytes, fixedTime)
|
||||
|
||||
if code1 != code2 {
|
||||
t.Errorf("Same time should produce same code: got %s and %s", code1, code2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTOTPGeneration_DifferentTimesProduceDifferentCodes tests time sensitivity
|
||||
func TestTOTPGeneration_TimeSensitive(t *testing.T) {
|
||||
secret := "JBSWY3DPEHPK3PXP"
|
||||
secretBytes, _ := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(secret)
|
||||
|
||||
time1 := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
time2 := time1.Add(30 * time.Second) // Next TOTP period
|
||||
|
||||
code1 := generateTOTPAt(secretBytes, time1)
|
||||
code2 := generateTOTPAt(secretBytes, time2)
|
||||
|
||||
if code1 == code2 {
|
||||
t.Error("Different TOTP periods should produce different codes")
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function for TOTP generation at specific time
|
||||
func generateTOTPAt(secretBytes []byte, t time.Time) string {
|
||||
counter := uint64(t.Unix()) / 30
|
||||
|
||||
buf := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(buf, counter)
|
||||
|
||||
mac := hmac.New(sha1.New, secretBytes)
|
||||
mac.Write(buf)
|
||||
hash := mac.Sum(nil)
|
||||
|
||||
offset := hash[len(hash)-1] & 0x0f
|
||||
code := binary.BigEndian.Uint32(hash[offset:offset+4]) & 0x7fffffff
|
||||
|
||||
return padCode(code % 1000000)
|
||||
}
|
||||
|
||||
func padCode(code uint32) string {
|
||||
s := ""
|
||||
for i := 0; i < 6; i++ {
|
||||
s = string(rune('0'+code%10)) + s
|
||||
code /= 10
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// TestTOTPValidation_WithDrift tests validation with clock drift allowance
|
||||
func TestTOTPValidation_WithDrift(t *testing.T) {
|
||||
secret := "JBSWY3DPEHPK3PXP"
|
||||
secretBytes, _ := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(secret)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Generate current code
|
||||
currentCode := generateTOTPAt(secretBytes, now)
|
||||
|
||||
// Generate previous period code
|
||||
previousCode := generateTOTPAt(secretBytes, now.Add(-30*time.Second))
|
||||
|
||||
// Generate next period code
|
||||
nextCode := generateTOTPAt(secretBytes, now.Add(30*time.Second))
|
||||
|
||||
// All three should be valid for current validation (allowing 1 period drift)
|
||||
validCodes := []string{currentCode, previousCode, nextCode}
|
||||
|
||||
for _, code := range validCodes {
|
||||
isValid := validateTOTPWithDrift(secretBytes, code, now)
|
||||
if !isValid {
|
||||
t.Errorf("Code %s should be valid with drift allowance", code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validateTOTPWithDrift validates a TOTP code allowing for clock drift
|
||||
func validateTOTPWithDrift(secretBytes []byte, code string, now time.Time) bool {
|
||||
for _, offset := range []int{0, -1, 1} {
|
||||
t := now.Add(time.Duration(offset*30) * time.Second)
|
||||
expected := generateTOTPAt(secretBytes, t)
|
||||
if expected == code {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// TestRecoveryCodeGeneration tests recovery code format
|
||||
func TestRecoveryCodeGeneration_Format(t *testing.T) {
|
||||
// Simulate recovery code generation
|
||||
codeBytes := make([]byte, 4) // 8 hex chars = 4 bytes
|
||||
for i := range codeBytes {
|
||||
codeBytes[i] = byte(i + 1) // Deterministic for testing
|
||||
}
|
||||
code := strings.ToUpper(hex.EncodeToString(codeBytes))
|
||||
|
||||
// Check format
|
||||
if len(code) != 8 {
|
||||
t.Errorf("Recovery code should be 8 characters, got %d", len(code))
|
||||
}
|
||||
|
||||
// Check uppercase
|
||||
if code != strings.ToUpper(code) {
|
||||
t.Error("Recovery code should be uppercase")
|
||||
}
|
||||
|
||||
// Check alphanumeric (hex only contains 0-9 and A-F)
|
||||
for _, c := range code {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'A' && c <= 'F')) {
|
||||
t.Errorf("Recovery code should only contain hex characters, found '%c'", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecoveryCodeHashing tests that recovery codes are hashed for storage
|
||||
func TestRecoveryCodeHashing_Consistency(t *testing.T) {
|
||||
code := "ABCD1234"
|
||||
|
||||
hash1 := sha256.Sum256([]byte(code))
|
||||
hash2 := sha256.Sum256([]byte(code))
|
||||
|
||||
if hash1 != hash2 {
|
||||
t.Error("Recovery code hashing should be consistent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecoveryCodeHashing_CaseInsensitive(t *testing.T) {
|
||||
code1 := "ABCD1234"
|
||||
code2 := "abcd1234"
|
||||
|
||||
hash1 := sha256.Sum256([]byte(strings.ToUpper(code1)))
|
||||
hash2 := sha256.Sum256([]byte(strings.ToUpper(code2)))
|
||||
|
||||
if hash1 != hash2 {
|
||||
t.Error("Recovery codes should be case-insensitive when normalized to uppercase")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSecretGeneration tests that secrets are valid Base32
|
||||
func TestSecretGeneration_ValidBase32(t *testing.T) {
|
||||
// Simulate secret generation (20 bytes -> Base32 without padding)
|
||||
secretBytes := make([]byte, 20)
|
||||
for i := range secretBytes {
|
||||
secretBytes[i] = byte(i * 13) // Deterministic for testing
|
||||
}
|
||||
|
||||
secret := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(secretBytes)
|
||||
|
||||
// Verify it can be decoded
|
||||
decoded, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(secret)
|
||||
if err != nil {
|
||||
t.Errorf("Generated secret should be valid Base32: %v", err)
|
||||
}
|
||||
|
||||
if len(decoded) != 20 {
|
||||
t.Errorf("Decoded secret should be 20 bytes, got %d", len(decoded))
|
||||
}
|
||||
}
|
||||
|
||||
// TestQRCodeOtpauthURL tests otpauth URL format
|
||||
func TestQRCodeOtpauthURL_Format(t *testing.T) {
|
||||
issuer := "BreakPilot"
|
||||
email := "test@example.com"
|
||||
secret := "JBSWY3DPEHPK3PXP"
|
||||
period := 30
|
||||
digits := 6
|
||||
|
||||
url := "otpauth://totp/" + issuer + ":" + email +
|
||||
"?secret=" + secret +
|
||||
"&issuer=" + issuer +
|
||||
"&algorithm=SHA1" +
|
||||
"&digits=" + string(rune('0'+digits)) +
|
||||
"&period=" + string(rune('0'+period/10)) + string(rune('0'+period%10))
|
||||
|
||||
// Check URL starts with otpauth://totp/
|
||||
if !strings.HasPrefix(url, "otpauth://totp/") {
|
||||
t.Error("OTP auth URL should start with otpauth://totp/")
|
||||
}
|
||||
|
||||
// Check contains required parameters
|
||||
if !strings.Contains(url, "secret=") {
|
||||
t.Error("OTP auth URL should contain secret parameter")
|
||||
}
|
||||
if !strings.Contains(url, "issuer=") {
|
||||
t.Error("OTP auth URL should contain issuer parameter")
|
||||
}
|
||||
}
|
||||
|
||||
// TestChallengeExpiry tests 2FA challenge expiration
|
||||
func TestChallengeExpiry_Logic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expiryMins int
|
||||
usedAfter int // minutes after creation
|
||||
shouldAllow bool
|
||||
}{
|
||||
{
|
||||
name: "challenge used within expiry",
|
||||
expiryMins: 5,
|
||||
usedAfter: 2,
|
||||
shouldAllow: true,
|
||||
},
|
||||
{
|
||||
name: "challenge used at expiry",
|
||||
expiryMins: 5,
|
||||
usedAfter: 5,
|
||||
shouldAllow: false, // Expired
|
||||
},
|
||||
{
|
||||
name: "challenge used after expiry",
|
||||
expiryMins: 5,
|
||||
usedAfter: 10,
|
||||
shouldAllow: false,
|
||||
},
|
||||
{
|
||||
name: "challenge used immediately",
|
||||
expiryMins: 5,
|
||||
usedAfter: 0,
|
||||
shouldAllow: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.usedAfter < tt.expiryMins
|
||||
|
||||
if isValid != tt.shouldAllow {
|
||||
t.Errorf("Expected allow=%v for challenge used after %d mins (expiry: %d mins)",
|
||||
tt.shouldAllow, tt.usedAfter, tt.expiryMins)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecoveryCodeOneTimeUse tests that recovery codes can only be used once
|
||||
func TestRecoveryCodeOneTimeUse(t *testing.T) {
|
||||
initialCodes := []string{
|
||||
sha256Hash("CODE0001"),
|
||||
sha256Hash("CODE0002"),
|
||||
sha256Hash("CODE0003"),
|
||||
}
|
||||
|
||||
// Use CODE0002
|
||||
usedCodeHash := sha256Hash("CODE0002")
|
||||
|
||||
// Remove used code from list
|
||||
var remainingCodes []string
|
||||
for _, code := range initialCodes {
|
||||
if code != usedCodeHash {
|
||||
remainingCodes = append(remainingCodes, code)
|
||||
}
|
||||
}
|
||||
|
||||
if len(remainingCodes) != 2 {
|
||||
t.Errorf("Should have 2 remaining codes after using one, got %d", len(remainingCodes))
|
||||
}
|
||||
|
||||
// Verify used code is not in remaining
|
||||
for _, code := range remainingCodes {
|
||||
if code == usedCodeHash {
|
||||
t.Error("Used recovery code should be removed from list")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sha256Hash(s string) string {
|
||||
h := sha256.Sum256([]byte(s))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
// TestTwoFactorEnableFlow tests the 2FA enable workflow
|
||||
func TestTwoFactorEnableFlow_States(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialState bool // verified
|
||||
action string
|
||||
expectedState bool
|
||||
}{
|
||||
{
|
||||
name: "fresh user - not verified",
|
||||
initialState: false,
|
||||
action: "none",
|
||||
expectedState: false,
|
||||
},
|
||||
{
|
||||
name: "user verifies 2FA",
|
||||
initialState: false,
|
||||
action: "verify",
|
||||
expectedState: true,
|
||||
},
|
||||
{
|
||||
name: "already verified - stays verified",
|
||||
initialState: true,
|
||||
action: "verify",
|
||||
expectedState: true,
|
||||
},
|
||||
{
|
||||
name: "user disables 2FA",
|
||||
initialState: true,
|
||||
action: "disable",
|
||||
expectedState: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
state := tt.initialState
|
||||
|
||||
switch tt.action {
|
||||
case "verify":
|
||||
state = true
|
||||
case "disable":
|
||||
state = false
|
||||
}
|
||||
|
||||
if state != tt.expectedState {
|
||||
t.Errorf("Expected state=%v after action '%s', got state=%v",
|
||||
tt.expectedState, tt.action, state)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user