feat(training+controls): interactive video pipeline, training blocks, control generator, CE libraries
Some checks failed
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Failing after 37s
CI/CD / test-python-backend-compliance (push) Successful in 39s
CI/CD / test-python-document-crawler (push) Successful in 26s
CI/CD / test-python-dsms-gateway (push) Successful in 23s
CI/CD / validate-canonical-controls (push) Successful in 12s
CI/CD / Deploy (push) Has been skipped

Interactive Training Videos (CP-TRAIN):
- DB migration 022: training_checkpoints + checkpoint_progress tables
- NarratorScript generation via Anthropic (AI Teacher persona, German)
- TTS batch synthesis + interactive video pipeline (slides + checkpoint slides + FFmpeg)
- 4 new API endpoints: generate-interactive, interactive-manifest, checkpoint submit, checkpoint progress
- InteractiveVideoPlayer component (HTML5 Video, quiz overlay, seek protection, progress tracking)
- Learner portal integration with automatic completion on all checkpoints passed
- 30 new tests (handler validation + grading logic + manifest/progress + seek protection)

Training Blocks:
- Block generator, block store, block config CRUD + preview/generate endpoints
- Migration 021: training_blocks schema

Control Generator + Canonical Library:
- Control generator routes + service enhancements
- Canonical control library helpers, sidebar entry
- Citation backfill service + tests
- CE libraries data (hazard, protection, evidence, lifecycle, components)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-03-16 21:41:48 +01:00
parent d2133dbfa2
commit 4f6bc8f6f6
50 changed files with 17299 additions and 198 deletions

View File

@@ -25,7 +25,6 @@ func NewRAGHandlers(corpusVersionStore *ucca.CorpusVersionStore) *RAGHandlers {
// AllowedCollections is the whitelist of Qdrant collections that can be queried.
var AllowedCollections = map[string]bool{
"bp_compliance_ce": true,
"bp_compliance_recht": true,
"bp_compliance_gesetze": true,
"bp_compliance_datenschutz": true,
"bp_compliance_gdpr": true,

View File

@@ -3,7 +3,9 @@ package handlers
import (
"net/http"
"strconv"
"time"
"github.com/breakpilot/ai-compliance-sdk/internal/academy"
"github.com/breakpilot/ai-compliance-sdk/internal/rbac"
"github.com/breakpilot/ai-compliance-sdk/internal/training"
"github.com/gin-gonic/gin"
@@ -14,13 +16,17 @@ import (
type TrainingHandlers struct {
store *training.Store
contentGenerator *training.ContentGenerator
blockGenerator *training.BlockGenerator
ttsClient *training.TTSClient
}
// NewTrainingHandlers creates new training handlers
func NewTrainingHandlers(store *training.Store, contentGenerator *training.ContentGenerator) *TrainingHandlers {
func NewTrainingHandlers(store *training.Store, contentGenerator *training.ContentGenerator, blockGenerator *training.BlockGenerator, ttsClient *training.TTSClient) *TrainingHandlers {
return &TrainingHandlers{
store: store,
contentGenerator: contentGenerator,
blockGenerator: blockGenerator,
ttsClient: ttsClient,
}
}
@@ -212,6 +218,33 @@ func (h *TrainingHandlers) UpdateModule(c *gin.Context) {
c.JSON(http.StatusOK, module)
}
// DeleteModule deletes a training module
// DELETE /sdk/v1/training/modules/:id
func (h *TrainingHandlers) DeleteModule(c *gin.Context) {
id, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid module ID"})
return
}
module, err := h.store.GetModule(c.Request.Context(), id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if module == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "module not found"})
return
}
if err := h.store.DeleteModule(c.Request.Context(), id); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
}
// ============================================================================
// Matrix Endpoints
// ============================================================================
@@ -459,6 +492,48 @@ func (h *TrainingHandlers) UpdateAssignmentProgress(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": string(status), "progress": req.Progress})
}
// UpdateAssignment updates assignment fields (e.g. deadline)
// PUT /sdk/v1/training/assignments/:id
func (h *TrainingHandlers) UpdateAssignment(c *gin.Context) {
id, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid assignment ID"})
return
}
var req struct {
Deadline *string `json:"deadline"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
return
}
if req.Deadline != nil {
deadline, err := time.Parse(time.RFC3339, *req.Deadline)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid deadline format (use RFC3339)"})
return
}
if err := h.store.UpdateAssignmentDeadline(c.Request.Context(), id, deadline); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
assignment, err := h.store.GetAssignment(c.Request.Context(), id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if assignment == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "assignment not found"})
return
}
c.JSON(http.StatusOK, assignment)
}
// CompleteAssignment marks an assignment as completed
// POST /sdk/v1/training/assignments/:id/complete
func (h *TrainingHandlers) CompleteAssignment(c *gin.Context) {
@@ -1111,3 +1186,679 @@ func (h *TrainingHandlers) PreviewVideoScript(c *gin.Context) {
c.JSON(http.StatusOK, script)
}
// ============================================================================
// Training Block Endpoints (Controls → Schulungsmodule)
// ============================================================================
// ListBlockConfigs returns all block configs for the tenant
// GET /sdk/v1/training/blocks
func (h *TrainingHandlers) ListBlockConfigs(c *gin.Context) {
tenantID := rbac.GetTenantID(c)
configs, err := h.store.ListBlockConfigs(c.Request.Context(), tenantID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"blocks": configs,
"total": len(configs),
})
}
// CreateBlockConfig creates a new block configuration
// POST /sdk/v1/training/blocks
func (h *TrainingHandlers) CreateBlockConfig(c *gin.Context) {
tenantID := rbac.GetTenantID(c)
var req training.CreateBlockConfigRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
config := &training.TrainingBlockConfig{
TenantID: tenantID,
Name: req.Name,
Description: req.Description,
DomainFilter: req.DomainFilter,
CategoryFilter: req.CategoryFilter,
SeverityFilter: req.SeverityFilter,
TargetAudienceFilter: req.TargetAudienceFilter,
RegulationArea: req.RegulationArea,
ModuleCodePrefix: req.ModuleCodePrefix,
FrequencyType: req.FrequencyType,
DurationMinutes: req.DurationMinutes,
PassThreshold: req.PassThreshold,
MaxControlsPerModule: req.MaxControlsPerModule,
}
if config.FrequencyType == "" {
config.FrequencyType = training.FrequencyAnnual
}
if config.DurationMinutes == 0 {
config.DurationMinutes = 45
}
if config.PassThreshold == 0 {
config.PassThreshold = 70
}
if config.MaxControlsPerModule == 0 {
config.MaxControlsPerModule = 20
}
if err := h.store.CreateBlockConfig(c.Request.Context(), config); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, config)
}
// GetBlockConfig returns a single block config
// GET /sdk/v1/training/blocks/:id
func (h *TrainingHandlers) GetBlockConfig(c *gin.Context) {
id, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid block config ID"})
return
}
config, err := h.store.GetBlockConfig(c.Request.Context(), id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if config == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "block config not found"})
return
}
c.JSON(http.StatusOK, config)
}
// UpdateBlockConfig updates a block config
// PUT /sdk/v1/training/blocks/:id
func (h *TrainingHandlers) UpdateBlockConfig(c *gin.Context) {
id, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid block config ID"})
return
}
config, err := h.store.GetBlockConfig(c.Request.Context(), id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if config == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "block config not found"})
return
}
var req training.UpdateBlockConfigRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Name != nil {
config.Name = *req.Name
}
if req.Description != nil {
config.Description = *req.Description
}
if req.DomainFilter != nil {
config.DomainFilter = *req.DomainFilter
}
if req.CategoryFilter != nil {
config.CategoryFilter = *req.CategoryFilter
}
if req.SeverityFilter != nil {
config.SeverityFilter = *req.SeverityFilter
}
if req.TargetAudienceFilter != nil {
config.TargetAudienceFilter = *req.TargetAudienceFilter
}
if req.MaxControlsPerModule != nil {
config.MaxControlsPerModule = *req.MaxControlsPerModule
}
if req.DurationMinutes != nil {
config.DurationMinutes = *req.DurationMinutes
}
if req.PassThreshold != nil {
config.PassThreshold = *req.PassThreshold
}
if req.IsActive != nil {
config.IsActive = *req.IsActive
}
if err := h.store.UpdateBlockConfig(c.Request.Context(), config); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, config)
}
// DeleteBlockConfig deletes a block config
// DELETE /sdk/v1/training/blocks/:id
func (h *TrainingHandlers) DeleteBlockConfig(c *gin.Context) {
id, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid block config ID"})
return
}
if err := h.store.DeleteBlockConfig(c.Request.Context(), id); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
}
// PreviewBlock performs a dry run showing matching controls and proposed roles
// POST /sdk/v1/training/blocks/:id/preview
func (h *TrainingHandlers) PreviewBlock(c *gin.Context) {
id, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid block config ID"})
return
}
preview, err := h.blockGenerator.Preview(c.Request.Context(), id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, preview)
}
// GenerateBlock runs the full generation pipeline
// POST /sdk/v1/training/blocks/:id/generate
func (h *TrainingHandlers) GenerateBlock(c *gin.Context) {
id, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid block config ID"})
return
}
var req training.GenerateBlockRequest
if err := c.ShouldBindJSON(&req); err != nil {
// Defaults are fine
req.Language = "de"
req.AutoMatrix = true
}
result, err := h.blockGenerator.Generate(c.Request.Context(), id, req)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, result)
}
// GetBlockControls returns control links for a block config
// GET /sdk/v1/training/blocks/:id/controls
func (h *TrainingHandlers) GetBlockControls(c *gin.Context) {
id, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid block config ID"})
return
}
links, err := h.store.GetControlLinksForBlock(c.Request.Context(), id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"controls": links,
"total": len(links),
})
}
// ListCanonicalControls returns filtered canonical controls for browsing
// GET /sdk/v1/training/canonical/controls
func (h *TrainingHandlers) ListCanonicalControls(c *gin.Context) {
domain := c.Query("domain")
category := c.Query("category")
severity := c.Query("severity")
targetAudience := c.Query("target_audience")
controls, err := h.store.QueryCanonicalControls(c.Request.Context(),
domain, category, severity, targetAudience,
)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"controls": controls,
"total": len(controls),
})
}
// GetCanonicalMeta returns aggregated metadata about canonical controls
// GET /sdk/v1/training/canonical/meta
func (h *TrainingHandlers) GetCanonicalMeta(c *gin.Context) {
meta, err := h.store.GetCanonicalControlMeta(c.Request.Context())
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, meta)
}
// ============================================================================
// Media Streaming Endpoint
// ============================================================================
// StreamMedia returns a redirect to a presigned URL for a media file
// GET /sdk/v1/training/media/:mediaId/stream
func (h *TrainingHandlers) StreamMedia(c *gin.Context) {
id, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid media ID"})
return
}
media, err := h.store.GetMedia(c.Request.Context(), id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if media == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "media not found"})
return
}
if h.ttsClient == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "media streaming not available"})
return
}
url, err := h.ttsClient.GetPresignedURL(c.Request.Context(), media.Bucket, media.ObjectKey)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate streaming URL: " + err.Error()})
return
}
c.Redirect(http.StatusTemporaryRedirect, url)
}
// ============================================================================
// Certificate Endpoints
// ============================================================================
// GenerateCertificate generates a certificate for a completed assignment
// POST /sdk/v1/training/certificates/generate/:assignmentId
func (h *TrainingHandlers) GenerateCertificate(c *gin.Context) {
assignmentID, err := uuid.Parse(c.Param("assignmentId"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid assignment ID"})
return
}
tenantID := rbac.GetTenantID(c)
assignment, err := h.store.GetAssignment(c.Request.Context(), assignmentID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if assignment == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "assignment not found"})
return
}
if assignment.Status != training.AssignmentStatusCompleted {
c.JSON(http.StatusBadRequest, gin.H{"error": "assignment is not completed"})
return
}
if assignment.QuizPassed == nil || !*assignment.QuizPassed {
c.JSON(http.StatusBadRequest, gin.H{"error": "quiz has not been passed"})
return
}
// Generate certificate ID
certID := uuid.New()
if err := h.store.SetCertificateID(c.Request.Context(), assignmentID, certID); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Audit log
userID := rbac.GetUserID(c)
h.store.LogAction(c.Request.Context(), &training.AuditLogEntry{
TenantID: tenantID,
UserID: &userID,
Action: training.AuditActionCertificateIssued,
EntityType: training.AuditEntityCertificate,
EntityID: &certID,
Details: map[string]interface{}{
"assignment_id": assignmentID.String(),
"user_name": assignment.UserName,
"module_title": assignment.ModuleTitle,
},
})
// Reload assignment with certificate_id
assignment, _ = h.store.GetAssignment(c.Request.Context(), assignmentID)
c.JSON(http.StatusOK, gin.H{
"certificate_id": certID,
"assignment": assignment,
})
}
// DownloadCertificatePDF generates and returns a PDF certificate
// GET /sdk/v1/training/certificates/:id/pdf
func (h *TrainingHandlers) DownloadCertificatePDF(c *gin.Context) {
certID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid certificate ID"})
return
}
assignment, err := h.store.GetAssignmentByCertificateID(c.Request.Context(), certID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if assignment == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "certificate not found"})
return
}
// Get module for title
module, _ := h.store.GetModule(c.Request.Context(), assignment.ModuleID)
courseName := assignment.ModuleTitle
if module != nil {
courseName = module.Title
}
score := 0
if assignment.QuizScore != nil {
score = int(*assignment.QuizScore)
}
issuedAt := assignment.UpdatedAt
if assignment.CompletedAt != nil {
issuedAt = *assignment.CompletedAt
}
// Use academy PDF generator
pdfBytes, err := academy.GenerateCertificatePDF(academy.CertificateData{
CertificateID: certID.String(),
UserName: assignment.UserName,
CourseName: courseName,
Score: score,
IssuedAt: issuedAt,
ValidUntil: issuedAt.AddDate(1, 0, 0),
})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "PDF generation failed: " + err.Error()})
return
}
c.Header("Content-Disposition", "attachment; filename=zertifikat-"+certID.String()[:8]+".pdf")
c.Data(http.StatusOK, "application/pdf", pdfBytes)
}
// ListCertificates returns all certificates for a tenant
// GET /sdk/v1/training/certificates
func (h *TrainingHandlers) ListCertificates(c *gin.Context) {
tenantID := rbac.GetTenantID(c)
certificates, err := h.store.ListCertificates(c.Request.Context(), tenantID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"certificates": certificates,
"total": len(certificates),
})
}
// ============================================================================
// Interactive Video Endpoints
// ============================================================================
// GenerateInteractiveVideo triggers the full interactive video pipeline
// POST /sdk/v1/training/content/:moduleId/generate-interactive
func (h *TrainingHandlers) GenerateInteractiveVideo(c *gin.Context) {
moduleID, err := uuid.Parse(c.Param("moduleId"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid module ID"})
return
}
module, err := h.store.GetModule(c.Request.Context(), moduleID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if module == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "module not found"})
return
}
media, err := h.contentGenerator.GenerateInteractiveVideo(c.Request.Context(), *module)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, media)
}
// GetInteractiveManifest returns the interactive video manifest with checkpoints and progress
// GET /sdk/v1/training/content/:moduleId/interactive-manifest
func (h *TrainingHandlers) GetInteractiveManifest(c *gin.Context) {
moduleID, err := uuid.Parse(c.Param("moduleId"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid module ID"})
return
}
// Get interactive video media
mediaList, err := h.store.GetMediaForModule(c.Request.Context(), moduleID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Find interactive video
var interactiveMedia *training.TrainingMedia
for i := range mediaList {
if mediaList[i].MediaType == training.MediaTypeInteractiveVideo && mediaList[i].Status == training.MediaStatusCompleted {
interactiveMedia = &mediaList[i]
break
}
}
if interactiveMedia == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "no interactive video found for this module"})
return
}
// Get checkpoints
checkpoints, err := h.store.ListCheckpoints(c.Request.Context(), moduleID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Optional: get assignment ID for progress
assignmentIDStr := c.Query("assignment_id")
// Build manifest entries
entries := make([]training.CheckpointManifestEntry, len(checkpoints))
for i, cp := range checkpoints {
// Get questions for this checkpoint
questions, _ := h.store.GetCheckpointQuestions(c.Request.Context(), cp.ID)
cpQuestions := make([]training.CheckpointQuestion, len(questions))
for j, q := range questions {
cpQuestions[j] = training.CheckpointQuestion{
Question: q.Question,
Options: q.Options,
CorrectIndex: q.CorrectIndex,
Explanation: q.Explanation,
}
}
entry := training.CheckpointManifestEntry{
CheckpointID: cp.ID,
Index: cp.CheckpointIndex,
Title: cp.Title,
TimestampSeconds: cp.TimestampSeconds,
Questions: cpQuestions,
}
// Get progress if assignment_id provided
if assignmentIDStr != "" {
if assignmentID, err := uuid.Parse(assignmentIDStr); err == nil {
progress, _ := h.store.GetCheckpointProgress(c.Request.Context(), assignmentID, cp.ID)
entry.Progress = progress
}
}
entries[i] = entry
}
// Get stream URL
streamURL := ""
if h.ttsClient != nil {
url, err := h.ttsClient.GetPresignedURL(c.Request.Context(), interactiveMedia.Bucket, interactiveMedia.ObjectKey)
if err == nil {
streamURL = url
}
}
manifest := training.InteractiveVideoManifest{
MediaID: interactiveMedia.ID,
StreamURL: streamURL,
Checkpoints: entries,
}
c.JSON(http.StatusOK, manifest)
}
// SubmitCheckpointQuiz handles checkpoint quiz submission
// POST /sdk/v1/training/checkpoints/:checkpointId/submit
func (h *TrainingHandlers) SubmitCheckpointQuiz(c *gin.Context) {
checkpointID, err := uuid.Parse(c.Param("checkpointId"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid checkpoint ID"})
return
}
var req training.SubmitCheckpointQuizRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
assignmentID, err := uuid.Parse(req.AssignmentID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid assignment ID"})
return
}
// Get checkpoint questions
questions, err := h.store.GetCheckpointQuestions(c.Request.Context(), checkpointID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if len(questions) == 0 {
c.JSON(http.StatusNotFound, gin.H{"error": "no questions found for this checkpoint"})
return
}
// Grade answers
correctCount := 0
feedback := make([]training.CheckpointQuizFeedback, len(questions))
for i, q := range questions {
isCorrect := false
if i < len(req.Answers) && req.Answers[i] == q.CorrectIndex {
isCorrect = true
correctCount++
}
feedback[i] = training.CheckpointQuizFeedback{
Question: q.Question,
Correct: isCorrect,
Explanation: q.Explanation,
}
}
score := float64(correctCount) / float64(len(questions)) * 100
passed := score >= 70 // 70% threshold for checkpoint
// Update progress
progress := &training.CheckpointProgress{
AssignmentID: assignmentID,
CheckpointID: checkpointID,
Passed: passed,
Attempts: 1,
}
if err := h.store.UpsertCheckpointProgress(c.Request.Context(), progress); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Audit log
userID := rbac.GetUserID(c)
h.store.LogAction(c.Request.Context(), &training.AuditLogEntry{
TenantID: rbac.GetTenantID(c),
UserID: &userID,
Action: training.AuditAction("checkpoint_submitted"),
EntityType: training.AuditEntityType("checkpoint"),
EntityID: &checkpointID,
Details: map[string]interface{}{
"assignment_id": assignmentID.String(),
"score": score,
"passed": passed,
"correct": correctCount,
"total": len(questions),
},
})
c.JSON(http.StatusOK, training.SubmitCheckpointQuizResponse{
Passed: passed,
Score: score,
Feedback: feedback,
})
}
// GetCheckpointProgress returns all checkpoint progress for an assignment
// GET /sdk/v1/training/checkpoints/progress/:assignmentId
func (h *TrainingHandlers) GetCheckpointProgress(c *gin.Context) {
assignmentID, err := uuid.Parse(c.Param("assignmentId"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid assignment ID"})
return
}
progress, err := h.store.ListCheckpointProgress(c.Request.Context(), assignmentID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"progress": progress,
"total": len(progress),
})
}

View File

@@ -0,0 +1,691 @@
package handlers
import (
"net/http"
"testing"
"github.com/gin-gonic/gin"
)
// newTestContext, parseResponse, and gin.SetMode are defined in iace_handler_test.go
// ============================================================================
// Module Endpoint Tests
// ============================================================================
func TestGetModule_InvalidID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("GET", "/modules/not-a-uuid", nil, nil, gin.Params{{Key: "id", Value: "not-a-uuid"}})
h.GetModule(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestGetModule_EmptyID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("GET", "/modules/", nil, nil, gin.Params{{Key: "id", Value: ""}})
h.GetModule(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestCreateModule_EmptyBody_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/modules", nil, nil, nil)
h.CreateModule(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestCreateModule_MissingTitle_Returns400(t *testing.T) {
h := &TrainingHandlers{}
body := map[string]interface{}{"module_code": "T01", "regulation_area": "dsgvo", "frequency_type": "annual"}
w, c := newTestContext("POST", "/modules", body, nil, nil)
h.CreateModule(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestCreateModule_MissingModuleCode_Returns400(t *testing.T) {
h := &TrainingHandlers{}
body := map[string]interface{}{"title": "Test", "regulation_area": "dsgvo", "frequency_type": "annual"}
w, c := newTestContext("POST", "/modules", body, nil, nil)
h.CreateModule(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestCreateModule_MissingRegulationArea_Returns400(t *testing.T) {
h := &TrainingHandlers{}
body := map[string]interface{}{"module_code": "T01", "title": "Test", "frequency_type": "annual"}
w, c := newTestContext("POST", "/modules", body, nil, nil)
h.CreateModule(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestUpdateModule_InvalidID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("PUT", "/modules/bad", map[string]interface{}{"title": "x"}, nil, gin.Params{{Key: "id", Value: "bad"}})
h.UpdateModule(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestDeleteModule_InvalidID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("DELETE", "/modules/bad", nil, nil, gin.Params{{Key: "id", Value: "bad"}})
h.DeleteModule(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestDeleteModule_EmptyID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("DELETE", "/modules/", nil, nil, gin.Params{{Key: "id", Value: ""}})
h.DeleteModule(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
// ============================================================================
// Matrix Endpoint Tests
// ============================================================================
func TestSetMatrixEntry_EmptyBody_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/matrix", nil, nil, nil)
h.SetMatrixEntry(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestSetMatrixEntry_MissingRoleCode_Returns400(t *testing.T) {
h := &TrainingHandlers{}
body := map[string]interface{}{"module_id": "00000000-0000-0000-0000-000000000001"}
w, c := newTestContext("POST", "/matrix", body, nil, nil)
h.SetMatrixEntry(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestDeleteMatrixEntry_InvalidModuleID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("DELETE", "/matrix/R1/bad", nil, nil, gin.Params{
{Key: "role", Value: "R1"},
{Key: "moduleId", Value: "bad"},
})
h.DeleteMatrixEntry(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
// ============================================================================
// Assignment Endpoint Tests
// ============================================================================
func TestGetAssignment_InvalidID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("GET", "/assignments/bad", nil, nil, gin.Params{{Key: "id", Value: "bad"}})
h.GetAssignment(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestStartAssignment_InvalidID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/assignments/bad/start", nil, nil, gin.Params{{Key: "id", Value: "bad"}})
h.StartAssignment(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestUpdateAssignmentProgress_InvalidID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/assignments/bad/progress", map[string]interface{}{"progress": 50}, nil, gin.Params{{Key: "id", Value: "bad"}})
h.UpdateAssignmentProgress(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestUpdateAssignmentProgress_EmptyBody_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/assignments/00000000-0000-0000-0000-000000000001/progress", nil, nil,
gin.Params{{Key: "id", Value: "00000000-0000-0000-0000-000000000001"}})
h.UpdateAssignmentProgress(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestCompleteAssignment_InvalidID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/assignments/bad/complete", nil, nil, gin.Params{{Key: "id", Value: "bad"}})
h.CompleteAssignment(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestComputeAssignments_EmptyBody_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/assignments/compute", nil, nil, nil)
h.ComputeAssignments(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestComputeAssignments_MissingUserID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
body := map[string]interface{}{"user_name": "Test", "user_email": "test@test.de", "roles": []string{"R1"}}
w, c := newTestContext("POST", "/assignments/compute", body, nil, nil)
h.ComputeAssignments(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
// ============================================================================
// Quiz Endpoint Tests
// ============================================================================
func TestGetQuiz_InvalidModuleID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("GET", "/quiz/bad", nil, nil, gin.Params{{Key: "moduleId", Value: "bad"}})
h.GetQuiz(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestSubmitQuiz_InvalidModuleID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/quiz/bad/submit", map[string]interface{}{}, nil, gin.Params{{Key: "moduleId", Value: "bad"}})
h.SubmitQuiz(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestSubmitQuiz_EmptyBody_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/quiz/00000000-0000-0000-0000-000000000001/submit", nil, nil,
gin.Params{{Key: "moduleId", Value: "00000000-0000-0000-0000-000000000001"}})
h.SubmitQuiz(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestGetQuizAttempts_InvalidAssignmentID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("GET", "/quiz/attempts/bad", nil, nil, gin.Params{{Key: "assignmentId", Value: "bad"}})
h.GetQuizAttempts(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
// ============================================================================
// Content Endpoint Tests
// ============================================================================
func TestGetContent_InvalidModuleID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("GET", "/content/bad", nil, nil, gin.Params{{Key: "moduleId", Value: "bad"}})
h.GetContent(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestPublishContent_InvalidID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/content/bad/publish", nil, nil, gin.Params{{Key: "id", Value: "bad"}})
h.PublishContent(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestGenerateContent_EmptyBody_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/content/generate", nil, nil, nil)
h.GenerateContent(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestGenerateQuiz_EmptyBody_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/content/generate-quiz", nil, nil, nil)
h.GenerateQuiz(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
// ============================================================================
// Media Endpoint Tests
// ============================================================================
func TestGetModuleMedia_InvalidModuleID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("GET", "/media/module/bad", nil, nil, gin.Params{{Key: "moduleId", Value: "bad"}})
h.GetModuleMedia(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestGetMediaURL_InvalidID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("GET", "/media/bad/url", nil, nil, gin.Params{{Key: "id", Value: "bad"}})
h.GetMediaURL(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestPublishMedia_InvalidID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/media/bad/publish", nil, nil, gin.Params{{Key: "id", Value: "bad"}})
h.PublishMedia(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestStreamMedia_InvalidID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("GET", "/media/bad/stream", nil, nil, gin.Params{{Key: "id", Value: "bad"}})
h.StreamMedia(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestGenerateAudio_InvalidModuleID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/content/bad/generate-audio", nil, nil, gin.Params{{Key: "moduleId", Value: "bad"}})
h.GenerateAudio(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestGenerateVideo_InvalidModuleID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/content/bad/generate-video", nil, nil, gin.Params{{Key: "moduleId", Value: "bad"}})
h.GenerateVideo(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestPreviewVideoScript_InvalidModuleID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/content/bad/preview-script", nil, nil, gin.Params{{Key: "moduleId", Value: "bad"}})
h.PreviewVideoScript(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
// ============================================================================
// Certificate Endpoint Tests
// ============================================================================
func TestGenerateCertificate_InvalidAssignmentID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/certificates/generate/bad", nil, nil, gin.Params{{Key: "assignmentId", Value: "bad"}})
h.GenerateCertificate(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestDownloadCertificatePDF_InvalidID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("GET", "/certificates/bad/pdf", nil, nil, gin.Params{{Key: "id", Value: "bad"}})
h.DownloadCertificatePDF(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestVerifyCertificate_InvalidID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("GET", "/certificates/bad/verify", nil, nil, gin.Params{{Key: "id", Value: "bad"}})
h.VerifyCertificate(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestListCertificates_NilStore_Panics(t *testing.T) {
// This tests that a nil store doesn't silently succeed
defer func() {
if r := recover(); r == nil {
t.Error("Expected panic with nil store")
}
}()
h := &TrainingHandlers{}
_, c := newTestContext("GET", "/certificates", nil, nil, nil)
h.ListCertificates(c)
}
// ============================================================================
// Interactive Video Endpoint Tests (User Journey: Admin generates video)
// ============================================================================
func TestGenerateInteractiveVideo_InvalidModuleID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/content/bad/generate-interactive", nil, nil, gin.Params{{Key: "moduleId", Value: "bad"}})
h.GenerateInteractiveVideo(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
resp := parseResponse(w)
if resp["error"] == nil {
t.Error("Response should contain 'error' key")
}
}
func TestGenerateInteractiveVideo_EmptyModuleID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/content//generate-interactive", nil, nil, gin.Params{{Key: "moduleId", Value: ""}})
h.GenerateInteractiveVideo(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestGetInteractiveManifest_InvalidModuleID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("GET", "/content/bad/interactive-manifest", nil, nil, gin.Params{{Key: "moduleId", Value: "bad"}})
h.GetInteractiveManifest(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
resp := parseResponse(w)
if resp["error"] == nil {
t.Error("Response should contain 'error' key")
}
}
func TestGetInteractiveManifest_EmptyModuleID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("GET", "/content//interactive-manifest", nil, nil, gin.Params{{Key: "moduleId", Value: ""}})
h.GetInteractiveManifest(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
// ============================================================================
// Checkpoint Quiz Endpoint Tests (User Journey: Learner takes quiz)
// ============================================================================
func TestSubmitCheckpointQuiz_InvalidCheckpointID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
body := map[string]interface{}{
"assignment_id": "00000000-0000-0000-0000-000000000001",
"answers": []int{0, 1, 2},
}
w, c := newTestContext("POST", "/checkpoints/bad/submit", body, nil, gin.Params{{Key: "checkpointId", Value: "bad"}})
h.SubmitCheckpointQuiz(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
resp := parseResponse(w)
if resp["error"] == nil {
t.Error("Response should contain 'error' key")
}
}
func TestSubmitCheckpointQuiz_EmptyCheckpointID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
body := map[string]interface{}{
"assignment_id": "00000000-0000-0000-0000-000000000001",
"answers": []int{0},
}
w, c := newTestContext("POST", "/checkpoints//submit", body, nil, gin.Params{{Key: "checkpointId", Value: ""}})
h.SubmitCheckpointQuiz(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestSubmitCheckpointQuiz_EmptyBody_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/checkpoints/00000000-0000-0000-0000-000000000001/submit", nil, nil,
gin.Params{{Key: "checkpointId", Value: "00000000-0000-0000-0000-000000000001"}})
h.SubmitCheckpointQuiz(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestSubmitCheckpointQuiz_InvalidAssignmentID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
body := map[string]interface{}{
"assignment_id": "not-a-uuid",
"answers": []int{0},
}
w, c := newTestContext("POST", "/checkpoints/00000000-0000-0000-0000-000000000001/submit", body, nil,
gin.Params{{Key: "checkpointId", Value: "00000000-0000-0000-0000-000000000001"}})
h.SubmitCheckpointQuiz(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestSubmitCheckpointQuiz_ValidIDs_NilStore_Panics(t *testing.T) {
// When both IDs are valid, handler reaches store → panic with nil store
defer func() {
if r := recover(); r == nil {
t.Error("Expected panic with nil store")
}
}()
h := &TrainingHandlers{}
body := map[string]interface{}{
"assignment_id": "00000000-0000-0000-0000-000000000001",
"answers": []int{0},
}
_, c := newTestContext("POST", "/checkpoints/00000000-0000-0000-0000-000000000001/submit", body, nil,
gin.Params{{Key: "checkpointId", Value: "00000000-0000-0000-0000-000000000001"}})
h.SubmitCheckpointQuiz(c)
}
// ============================================================================
// Checkpoint Progress Endpoint Tests (User Journey: Learner views progress)
// ============================================================================
func TestGetCheckpointProgress_InvalidAssignmentID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("GET", "/checkpoints/progress/bad", nil, nil, gin.Params{{Key: "assignmentId", Value: "bad"}})
h.GetCheckpointProgress(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
resp := parseResponse(w)
if resp["error"] == nil {
t.Error("Response should contain 'error' key")
}
}
func TestGetCheckpointProgress_EmptyAssignmentID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("GET", "/checkpoints/progress/", nil, nil, gin.Params{{Key: "assignmentId", Value: ""}})
h.GetCheckpointProgress(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
// ============================================================================
// Interactive Video Error Format Tests (table-driven)
// ============================================================================
func TestInteractiveEndpoints_InvalidID_ResponseContainsErrorKey(t *testing.T) {
tests := []struct {
name string
method string
handler func(h *TrainingHandlers, c *gin.Context)
params gin.Params
}{
{"GenerateInteractiveVideo", "POST",
func(h *TrainingHandlers, c *gin.Context) { h.GenerateInteractiveVideo(c) },
gin.Params{{Key: "moduleId", Value: "x"}}},
{"GetInteractiveManifest", "GET",
func(h *TrainingHandlers, c *gin.Context) { h.GetInteractiveManifest(c) },
gin.Params{{Key: "moduleId", Value: "x"}}},
{"SubmitCheckpointQuiz", "POST",
func(h *TrainingHandlers, c *gin.Context) { h.SubmitCheckpointQuiz(c) },
gin.Params{{Key: "checkpointId", Value: "x"}}},
{"GetCheckpointProgress", "GET",
func(h *TrainingHandlers, c *gin.Context) { h.GetCheckpointProgress(c) },
gin.Params{{Key: "assignmentId", Value: "x"}}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext(tt.method, "/test", nil, nil, tt.params)
tt.handler(h, c)
if w.Code != http.StatusBadRequest {
t.Errorf("%s: Expected 400, got %d", tt.name, w.Code)
}
resp := parseResponse(w)
if resp["error"] == nil {
t.Errorf("%s: response should contain 'error' key", tt.name)
}
})
}
}
// ============================================================================
// Block Endpoint Tests
// ============================================================================
func TestGetBlockConfig_InvalidID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("GET", "/blocks/bad", nil, nil, gin.Params{{Key: "id", Value: "bad"}})
h.GetBlockConfig(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestCreateBlockConfig_EmptyBody_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/blocks", nil, nil, nil)
h.CreateBlockConfig(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestCreateBlockConfig_MissingName_Returns400(t *testing.T) {
h := &TrainingHandlers{}
body := map[string]interface{}{"regulation_area": "dsgvo", "module_code_prefix": "BLK"}
w, c := newTestContext("POST", "/blocks", body, nil, nil)
h.CreateBlockConfig(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestUpdateBlockConfig_InvalidID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("PUT", "/blocks/bad", map[string]interface{}{"name": "x"}, nil, gin.Params{{Key: "id", Value: "bad"}})
h.UpdateBlockConfig(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestDeleteBlockConfig_InvalidID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("DELETE", "/blocks/bad", nil, nil, gin.Params{{Key: "id", Value: "bad"}})
h.DeleteBlockConfig(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestPreviewBlock_InvalidID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/blocks/bad/preview", nil, nil, gin.Params{{Key: "id", Value: "bad"}})
h.PreviewBlock(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestGenerateBlock_InvalidID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("POST", "/blocks/bad/generate", nil, nil, gin.Params{{Key: "id", Value: "bad"}})
h.GenerateBlock(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
func TestGetBlockControls_InvalidID_Returns400(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext("GET", "/blocks/bad/controls", nil, nil, gin.Params{{Key: "id", Value: "bad"}})
h.GetBlockControls(c)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400, got %d", w.Code)
}
}
// ============================================================================
// Response Error Format Tests
// ============================================================================
func TestInvalidID_ResponseContainsErrorKey(t *testing.T) {
tests := []struct {
name string
method string
handler func(h *TrainingHandlers, c *gin.Context)
params gin.Params
}{
{"GetModule", "GET", func(h *TrainingHandlers, c *gin.Context) { h.GetModule(c) }, gin.Params{{Key: "id", Value: "x"}}},
{"DeleteModule", "DELETE", func(h *TrainingHandlers, c *gin.Context) { h.DeleteModule(c) }, gin.Params{{Key: "id", Value: "x"}}},
{"StreamMedia", "GET", func(h *TrainingHandlers, c *gin.Context) { h.StreamMedia(c) }, gin.Params{{Key: "id", Value: "x"}}},
{"GenerateCertificate", "POST", func(h *TrainingHandlers, c *gin.Context) { h.GenerateCertificate(c) }, gin.Params{{Key: "assignmentId", Value: "x"}}},
{"DownloadCertificatePDF", "GET", func(h *TrainingHandlers, c *gin.Context) { h.DownloadCertificatePDF(c) }, gin.Params{{Key: "id", Value: "x"}}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := &TrainingHandlers{}
w, c := newTestContext(tt.method, "/test", nil, nil, tt.params)
tt.handler(h, c)
resp := parseResponse(w)
if resp["error"] == nil {
t.Errorf("%s: response should contain 'error' key", tt.name)
}
})
}
}

View File

@@ -0,0 +1,282 @@
package training
import (
"context"
"fmt"
"math"
"github.com/google/uuid"
)
// BlockGenerator orchestrates the Controls → Training Modules pipeline
type BlockGenerator struct {
store *Store
contentGenerator *ContentGenerator
}
// NewBlockGenerator creates a new block generator
func NewBlockGenerator(store *Store, contentGenerator *ContentGenerator) *BlockGenerator {
return &BlockGenerator{
store: store,
contentGenerator: contentGenerator,
}
}
// Preview performs a dry run: loads matching controls, computes module split and roles
func (bg *BlockGenerator) Preview(ctx context.Context, configID uuid.UUID) (*PreviewBlockResponse, error) {
config, err := bg.store.GetBlockConfig(ctx, configID)
if err != nil {
return nil, fmt.Errorf("load block config: %w", err)
}
if config == nil {
return nil, fmt.Errorf("block config not found")
}
controls, err := bg.store.QueryCanonicalControls(ctx,
config.DomainFilter, config.CategoryFilter,
config.SeverityFilter, config.TargetAudienceFilter,
)
if err != nil {
return nil, fmt.Errorf("query controls: %w", err)
}
maxPerModule := config.MaxControlsPerModule
if maxPerModule <= 0 {
maxPerModule = 20
}
moduleCount := int(math.Ceil(float64(len(controls)) / float64(maxPerModule)))
if moduleCount == 0 && len(controls) > 0 {
moduleCount = 1
}
roles := bg.deriveRoles(controls, config.TargetAudienceFilter)
return &PreviewBlockResponse{
ControlCount: len(controls),
ModuleCount: moduleCount,
Controls: controls,
ProposedRoles: roles,
}, nil
}
// Generate executes the full pipeline: Controls → Modules → Links → CTM → Content
func (bg *BlockGenerator) Generate(ctx context.Context, configID uuid.UUID, req GenerateBlockRequest) (*GenerateBlockResponse, error) {
config, err := bg.store.GetBlockConfig(ctx, configID)
if err != nil {
return nil, fmt.Errorf("load block config: %w", err)
}
if config == nil {
return nil, fmt.Errorf("block config not found")
}
// 1. Load matching controls
controls, err := bg.store.QueryCanonicalControls(ctx,
config.DomainFilter, config.CategoryFilter,
config.SeverityFilter, config.TargetAudienceFilter,
)
if err != nil {
return nil, fmt.Errorf("query controls: %w", err)
}
if len(controls) == 0 {
return &GenerateBlockResponse{}, nil
}
// 2. Chunk controls into module-sized groups
maxPerModule := config.MaxControlsPerModule
if maxPerModule <= 0 {
maxPerModule = 20
}
chunks := chunkControls(controls, maxPerModule)
// 3. Derive target roles for CTM
roles := bg.deriveRoles(controls, config.TargetAudienceFilter)
// 4. Count existing modules with this prefix for auto-numbering
existingCount, err := bg.store.CountModulesWithPrefix(ctx, config.TenantID, config.ModuleCodePrefix)
if err != nil {
existingCount = 0
}
language := req.Language
if language == "" {
language = "de"
}
resp := &GenerateBlockResponse{}
for i, chunk := range chunks {
moduleNum := existingCount + i + 1
moduleCode := fmt.Sprintf("%s-%02d", config.ModuleCodePrefix, moduleNum)
// Build a descriptive title from the first few controls
title := bg.buildModuleTitle(config, chunk, i+1, len(chunks))
// a. Create TrainingModule
module := &TrainingModule{
TenantID: config.TenantID,
ModuleCode: moduleCode,
Title: title,
Description: config.Description,
RegulationArea: config.RegulationArea,
NIS2Relevant: config.RegulationArea == RegulationNIS2,
ISOControls: bg.extractControlIDs(chunk),
FrequencyType: config.FrequencyType,
ValidityDays: 365,
RiskWeight: 2.0,
ContentType: "text",
DurationMinutes: config.DurationMinutes,
PassThreshold: config.PassThreshold,
IsActive: true,
SortOrder: moduleNum,
}
if err := bg.store.CreateModule(ctx, module); err != nil {
resp.Errors = append(resp.Errors, fmt.Sprintf("create module %s: %v", moduleCode, err))
continue
}
resp.ModulesCreated++
// b. Create control links (traceability)
for j, ctrl := range chunk {
link := &TrainingBlockControlLink{
BlockConfigID: config.ID,
ModuleID: module.ID,
ControlID: ctrl.ControlID,
ControlTitle: ctrl.Title,
ControlObjective: ctrl.Objective,
ControlRequirements: ctrl.Requirements,
SortOrder: j,
}
if err := bg.store.CreateBlockControlLink(ctx, link); err != nil {
resp.Errors = append(resp.Errors, fmt.Sprintf("link %s→%s: %v", moduleCode, ctrl.ControlID, err))
continue
}
resp.ControlsLinked++
}
// c. Create CTM entries (target_audience → roles)
if req.AutoMatrix {
for _, role := range roles {
entry := &TrainingMatrixEntry{
TenantID: config.TenantID,
RoleCode: role,
ModuleID: module.ID,
IsMandatory: true,
Priority: 1,
}
if err := bg.store.SetMatrixEntry(ctx, entry); err != nil {
resp.Errors = append(resp.Errors, fmt.Sprintf("matrix %s→%s: %v", role, moduleCode, err))
continue
}
resp.MatrixEntriesCreated++
}
}
// d. Generate LLM content
_, err := bg.contentGenerator.GenerateBlockContent(ctx, *module, chunk, language)
if err != nil {
resp.Errors = append(resp.Errors, fmt.Sprintf("content %s: %v", moduleCode, err))
continue
}
resp.ContentGenerated++
}
// 5. Update last_generated_at
bg.store.UpdateBlockConfigLastGenerated(ctx, config.ID)
// 6. Audit log
bg.store.LogAction(ctx, &AuditLogEntry{
TenantID: config.TenantID,
Action: AuditAction("block_generated"),
EntityType: AuditEntityModule,
Details: map[string]interface{}{
"block_config_id": config.ID.String(),
"block_name": config.Name,
"modules_created": resp.ModulesCreated,
"controls_linked": resp.ControlsLinked,
"content_generated": resp.ContentGenerated,
},
})
return resp, nil
}
// deriveRoles computes which CTM roles should receive the generated modules
func (bg *BlockGenerator) deriveRoles(controls []CanonicalControlSummary, audienceFilter string) []string {
roleSet := map[string]bool{}
// If a specific audience filter is set, use the mapping
if audienceFilter != "" {
if roles, ok := TargetAudienceRoleMapping[audienceFilter]; ok {
for _, r := range roles {
roleSet[r] = true
}
}
}
// Additionally derive roles from control categories
for _, ctrl := range controls {
if ctrl.Category != "" {
if roles, ok := CategoryRoleMapping[ctrl.Category]; ok {
for _, r := range roles {
roleSet[r] = true
}
}
}
// Also check per-control target_audience
if ctrl.TargetAudience != "" && audienceFilter == "" {
if roles, ok := TargetAudienceRoleMapping[ctrl.TargetAudience]; ok {
for _, r := range roles {
roleSet[r] = true
}
}
}
}
// If nothing derived, default to R9 (Alle Mitarbeiter)
if len(roleSet) == 0 {
roleSet[RoleR9] = true
}
roles := make([]string, 0, len(roleSet))
for r := range roleSet {
roles = append(roles, r)
}
return roles
}
// buildModuleTitle creates a descriptive module title
func (bg *BlockGenerator) buildModuleTitle(config *TrainingBlockConfig, controls []CanonicalControlSummary, partNum, totalParts int) string {
base := config.Name
if totalParts > 1 {
base = fmt.Sprintf("%s (Teil %d/%d)", config.Name, partNum, totalParts)
}
return base
}
// extractControlIDs returns the control IDs from a slice of controls
func (bg *BlockGenerator) extractControlIDs(controls []CanonicalControlSummary) []string {
ids := make([]string, len(controls))
for i, c := range controls {
ids[i] = c.ControlID
}
return ids
}
// chunkControls splits controls into groups of maxSize
func chunkControls(controls []CanonicalControlSummary, maxSize int) [][]CanonicalControlSummary {
if maxSize <= 0 {
maxSize = 20
}
var chunks [][]CanonicalControlSummary
for i := 0; i < len(controls); i += maxSize {
end := i + maxSize
if end > len(controls) {
end = len(controls)
}
chunks = append(chunks, controls[i:end])
}
return chunks
}

View File

@@ -0,0 +1,224 @@
package training
import (
"testing"
)
func TestChunkControls_EmptySlice(t *testing.T) {
chunks := chunkControls(nil, 20)
if len(chunks) != 0 {
t.Errorf("expected 0 chunks, got %d", len(chunks))
}
}
func TestChunkControls_SingleChunk(t *testing.T) {
controls := make([]CanonicalControlSummary, 5)
for i := range controls {
controls[i].ControlID = "CTRL-" + string(rune('A'+i))
}
chunks := chunkControls(controls, 20)
if len(chunks) != 1 {
t.Errorf("expected 1 chunk, got %d", len(chunks))
}
if len(chunks[0]) != 5 {
t.Errorf("expected 5 controls in chunk, got %d", len(chunks[0]))
}
}
func TestChunkControls_MultipleChunks(t *testing.T) {
controls := make([]CanonicalControlSummary, 25)
for i := range controls {
controls[i].ControlID = "CTRL-" + string(rune('A'+i%26))
}
chunks := chunkControls(controls, 10)
if len(chunks) != 3 {
t.Errorf("expected 3 chunks, got %d", len(chunks))
}
if len(chunks[0]) != 10 {
t.Errorf("expected 10 in first chunk, got %d", len(chunks[0]))
}
if len(chunks[2]) != 5 {
t.Errorf("expected 5 in last chunk, got %d", len(chunks[2]))
}
}
func TestChunkControls_ExactMultiple(t *testing.T) {
controls := make([]CanonicalControlSummary, 20)
chunks := chunkControls(controls, 10)
if len(chunks) != 2 {
t.Errorf("expected 2 chunks, got %d", len(chunks))
}
}
func TestBlockGenerator_DeriveRoles_EnterpriseAudience(t *testing.T) {
bg := &BlockGenerator{}
controls := []CanonicalControlSummary{
{ControlID: "AUTH-001", Category: "authentication", TargetAudience: "enterprise"},
}
roles := bg.deriveRoles(controls, "enterprise")
// Should include enterprise mapping roles
roleSet := map[string]bool{}
for _, r := range roles {
roleSet[r] = true
}
// Enterprise maps to R1, R4, R5, R6, R7, R9
if !roleSet[RoleR1] {
t.Error("expected R1 for enterprise audience")
}
if !roleSet[RoleR9] {
t.Error("expected R9 for enterprise audience")
}
}
func TestBlockGenerator_DeriveRoles_AuthorityAudience(t *testing.T) {
bg := &BlockGenerator{}
controls := []CanonicalControlSummary{
{ControlID: "GOV-001", Category: "governance", TargetAudience: "authority"},
}
roles := bg.deriveRoles(controls, "authority")
roleSet := map[string]bool{}
for _, r := range roles {
roleSet[r] = true
}
// Authority maps to R10
if !roleSet[RoleR10] {
t.Error("expected R10 for authority audience")
}
}
func TestBlockGenerator_DeriveRoles_NoFilter(t *testing.T) {
bg := &BlockGenerator{}
controls := []CanonicalControlSummary{
{ControlID: "ENC-001", Category: "encryption", TargetAudience: "provider"},
}
roles := bg.deriveRoles(controls, "")
roleSet := map[string]bool{}
for _, r := range roles {
roleSet[r] = true
}
// Without audience filter, should use per-control audience + category
// encryption → R2, R8
// provider → R2, R8
if !roleSet[RoleR2] {
t.Error("expected R2 from encryption category + provider audience")
}
if !roleSet[RoleR8] {
t.Error("expected R8 from encryption category + provider audience")
}
}
func TestBlockGenerator_DeriveRoles_DefaultToR9(t *testing.T) {
bg := &BlockGenerator{}
controls := []CanonicalControlSummary{
{ControlID: "UNK-001", Category: "", TargetAudience: ""},
}
roles := bg.deriveRoles(controls, "")
if len(roles) != 1 || roles[0] != RoleR9 {
t.Errorf("expected [R9] default, got %v", roles)
}
}
func TestBlockGenerator_ExtractControlIDs(t *testing.T) {
bg := &BlockGenerator{}
controls := []CanonicalControlSummary{
{ControlID: "AUTH-001"},
{ControlID: "AUTH-002"},
{ControlID: "ENC-010"},
}
ids := bg.extractControlIDs(controls)
if len(ids) != 3 {
t.Errorf("expected 3 IDs, got %d", len(ids))
}
if ids[0] != "AUTH-001" || ids[1] != "AUTH-002" || ids[2] != "ENC-010" {
t.Errorf("unexpected IDs: %v", ids)
}
}
func TestBlockGenerator_BuildModuleTitle_SinglePart(t *testing.T) {
bg := &BlockGenerator{}
config := &TrainingBlockConfig{Name: "Authentifizierung"}
controls := []CanonicalControlSummary{{ControlID: "AUTH-001"}}
title := bg.buildModuleTitle(config, controls, 1, 1)
if title != "Authentifizierung" {
t.Errorf("expected 'Authentifizierung', got '%s'", title)
}
}
func TestBlockGenerator_BuildModuleTitle_MultiPart(t *testing.T) {
bg := &BlockGenerator{}
config := &TrainingBlockConfig{Name: "Authentifizierung"}
controls := []CanonicalControlSummary{{ControlID: "AUTH-001"}}
title := bg.buildModuleTitle(config, controls, 2, 3)
expected := "Authentifizierung (Teil 2/3)"
if title != expected {
t.Errorf("expected '%s', got '%s'", expected, title)
}
}
func TestNilIfEmpty(t *testing.T) {
tests := []struct {
input string
expected bool // true = nil result
}{
{"", true},
{" ", true},
{"value", false},
{" value ", false},
}
for _, tt := range tests {
result := nilIfEmpty(tt.input)
if tt.expected && result != nil {
t.Errorf("nilIfEmpty(%q) = %v, expected nil", tt.input, *result)
}
if !tt.expected && result == nil {
t.Errorf("nilIfEmpty(%q) = nil, expected non-nil", tt.input)
}
}
}
func TestTargetAudienceRoleMapping_AllKeys(t *testing.T) {
expectedKeys := []string{"enterprise", "authority", "provider", "all"}
for _, key := range expectedKeys {
roles, ok := TargetAudienceRoleMapping[key]
if !ok {
t.Errorf("missing key '%s' in TargetAudienceRoleMapping", key)
}
if len(roles) == 0 {
t.Errorf("empty roles for key '%s'", key)
}
}
}
func TestCategoryRoleMapping_HasEntries(t *testing.T) {
if len(CategoryRoleMapping) == 0 {
t.Error("CategoryRoleMapping is empty")
}
// Verify some expected entries
if _, ok := CategoryRoleMapping["encryption"]; !ok {
t.Error("missing 'encryption' in CategoryRoleMapping")
}
if _, ok := CategoryRoleMapping["authentication"]; !ok {
t.Error("missing 'authentication' in CategoryRoleMapping")
}
if _, ok := CategoryRoleMapping["data_protection"]; !ok {
t.Error("missing 'data_protection' in CategoryRoleMapping")
}
}

View File

@@ -0,0 +1,484 @@
package training
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
)
// ============================================================================
// Block Config CRUD
// ============================================================================
// CreateBlockConfig creates a new training block configuration
func (s *Store) CreateBlockConfig(ctx context.Context, config *TrainingBlockConfig) error {
config.ID = uuid.New()
config.CreatedAt = time.Now().UTC()
config.UpdatedAt = config.CreatedAt
if !config.IsActive {
config.IsActive = true
}
_, err := s.pool.Exec(ctx, `
INSERT INTO training_block_configs (
id, tenant_id, name, description,
domain_filter, category_filter, severity_filter, target_audience_filter,
regulation_area, module_code_prefix, frequency_type,
duration_minutes, pass_threshold, max_controls_per_module,
is_active, created_at, updated_at
) VALUES (
$1, $2, $3, $4,
$5, $6, $7, $8,
$9, $10, $11,
$12, $13, $14,
$15, $16, $17
)
`,
config.ID, config.TenantID, config.Name, config.Description,
nilIfEmpty(config.DomainFilter), nilIfEmpty(config.CategoryFilter),
nilIfEmpty(config.SeverityFilter), nilIfEmpty(config.TargetAudienceFilter),
string(config.RegulationArea), config.ModuleCodePrefix, string(config.FrequencyType),
config.DurationMinutes, config.PassThreshold, config.MaxControlsPerModule,
config.IsActive, config.CreatedAt, config.UpdatedAt,
)
return err
}
// GetBlockConfig retrieves a block config by ID
func (s *Store) GetBlockConfig(ctx context.Context, id uuid.UUID) (*TrainingBlockConfig, error) {
var config TrainingBlockConfig
var regulationArea, frequencyType string
var domainFilter, categoryFilter, severityFilter, targetAudienceFilter *string
err := s.pool.QueryRow(ctx, `
SELECT
id, tenant_id, name, description,
domain_filter, category_filter, severity_filter, target_audience_filter,
regulation_area, module_code_prefix, frequency_type,
duration_minutes, pass_threshold, max_controls_per_module,
is_active, last_generated_at, created_at, updated_at
FROM training_block_configs WHERE id = $1
`, id).Scan(
&config.ID, &config.TenantID, &config.Name, &config.Description,
&domainFilter, &categoryFilter, &severityFilter, &targetAudienceFilter,
&regulationArea, &config.ModuleCodePrefix, &frequencyType,
&config.DurationMinutes, &config.PassThreshold, &config.MaxControlsPerModule,
&config.IsActive, &config.LastGeneratedAt, &config.CreatedAt, &config.UpdatedAt,
)
if err == pgx.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
config.RegulationArea = RegulationArea(regulationArea)
config.FrequencyType = FrequencyType(frequencyType)
if domainFilter != nil {
config.DomainFilter = *domainFilter
}
if categoryFilter != nil {
config.CategoryFilter = *categoryFilter
}
if severityFilter != nil {
config.SeverityFilter = *severityFilter
}
if targetAudienceFilter != nil {
config.TargetAudienceFilter = *targetAudienceFilter
}
return &config, nil
}
// ListBlockConfigs returns all block configs for a tenant
func (s *Store) ListBlockConfigs(ctx context.Context, tenantID uuid.UUID) ([]TrainingBlockConfig, error) {
rows, err := s.pool.Query(ctx, `
SELECT
id, tenant_id, name, description,
domain_filter, category_filter, severity_filter, target_audience_filter,
regulation_area, module_code_prefix, frequency_type,
duration_minutes, pass_threshold, max_controls_per_module,
is_active, last_generated_at, created_at, updated_at
FROM training_block_configs
WHERE tenant_id = $1
ORDER BY created_at DESC
`, tenantID)
if err != nil {
return nil, err
}
defer rows.Close()
var configs []TrainingBlockConfig
for rows.Next() {
var config TrainingBlockConfig
var regulationArea, frequencyType string
var domainFilter, categoryFilter, severityFilter, targetAudienceFilter *string
if err := rows.Scan(
&config.ID, &config.TenantID, &config.Name, &config.Description,
&domainFilter, &categoryFilter, &severityFilter, &targetAudienceFilter,
&regulationArea, &config.ModuleCodePrefix, &frequencyType,
&config.DurationMinutes, &config.PassThreshold, &config.MaxControlsPerModule,
&config.IsActive, &config.LastGeneratedAt, &config.CreatedAt, &config.UpdatedAt,
); err != nil {
return nil, err
}
config.RegulationArea = RegulationArea(regulationArea)
config.FrequencyType = FrequencyType(frequencyType)
if domainFilter != nil {
config.DomainFilter = *domainFilter
}
if categoryFilter != nil {
config.CategoryFilter = *categoryFilter
}
if severityFilter != nil {
config.SeverityFilter = *severityFilter
}
if targetAudienceFilter != nil {
config.TargetAudienceFilter = *targetAudienceFilter
}
configs = append(configs, config)
}
if configs == nil {
configs = []TrainingBlockConfig{}
}
return configs, nil
}
// UpdateBlockConfig updates a block config
func (s *Store) UpdateBlockConfig(ctx context.Context, config *TrainingBlockConfig) error {
config.UpdatedAt = time.Now().UTC()
_, err := s.pool.Exec(ctx, `
UPDATE training_block_configs SET
name = $2, description = $3,
domain_filter = $4, category_filter = $5,
severity_filter = $6, target_audience_filter = $7,
max_controls_per_module = $8, duration_minutes = $9,
pass_threshold = $10, is_active = $11, updated_at = $12
WHERE id = $1
`,
config.ID, config.Name, config.Description,
nilIfEmpty(config.DomainFilter), nilIfEmpty(config.CategoryFilter),
nilIfEmpty(config.SeverityFilter), nilIfEmpty(config.TargetAudienceFilter),
config.MaxControlsPerModule, config.DurationMinutes,
config.PassThreshold, config.IsActive, config.UpdatedAt,
)
return err
}
// DeleteBlockConfig deletes a block config (cascades to control links)
func (s *Store) DeleteBlockConfig(ctx context.Context, id uuid.UUID) error {
_, err := s.pool.Exec(ctx, `DELETE FROM training_block_configs WHERE id = $1`, id)
return err
}
// UpdateBlockConfigLastGenerated updates the last_generated_at timestamp
func (s *Store) UpdateBlockConfigLastGenerated(ctx context.Context, id uuid.UUID) error {
now := time.Now().UTC()
_, err := s.pool.Exec(ctx, `
UPDATE training_block_configs SET last_generated_at = $2, updated_at = $2 WHERE id = $1
`, id, now)
return err
}
// ============================================================================
// Block Control Links
// ============================================================================
// CreateBlockControlLink creates a link between a block config, a module, and a control
func (s *Store) CreateBlockControlLink(ctx context.Context, link *TrainingBlockControlLink) error {
link.ID = uuid.New()
link.CreatedAt = time.Now().UTC()
requirements, _ := json.Marshal(link.ControlRequirements)
_, err := s.pool.Exec(ctx, `
INSERT INTO training_block_control_links (
id, block_config_id, module_id, control_id,
control_title, control_objective, control_requirements,
sort_order, created_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`,
link.ID, link.BlockConfigID, link.ModuleID, link.ControlID,
link.ControlTitle, link.ControlObjective, requirements,
link.SortOrder, link.CreatedAt,
)
return err
}
// GetControlLinksForBlock returns all control links for a block config
func (s *Store) GetControlLinksForBlock(ctx context.Context, blockConfigID uuid.UUID) ([]TrainingBlockControlLink, error) {
rows, err := s.pool.Query(ctx, `
SELECT id, block_config_id, module_id, control_id,
control_title, control_objective, control_requirements,
sort_order, created_at
FROM training_block_control_links
WHERE block_config_id = $1
ORDER BY sort_order
`, blockConfigID)
if err != nil {
return nil, err
}
defer rows.Close()
var links []TrainingBlockControlLink
for rows.Next() {
var link TrainingBlockControlLink
var requirements []byte
if err := rows.Scan(
&link.ID, &link.BlockConfigID, &link.ModuleID, &link.ControlID,
&link.ControlTitle, &link.ControlObjective, &requirements,
&link.SortOrder, &link.CreatedAt,
); err != nil {
return nil, err
}
json.Unmarshal(requirements, &link.ControlRequirements)
if link.ControlRequirements == nil {
link.ControlRequirements = []string{}
}
links = append(links, link)
}
if links == nil {
links = []TrainingBlockControlLink{}
}
return links, nil
}
// GetControlLinksForModule returns all control links for a specific module
func (s *Store) GetControlLinksForModule(ctx context.Context, moduleID uuid.UUID) ([]TrainingBlockControlLink, error) {
rows, err := s.pool.Query(ctx, `
SELECT id, block_config_id, module_id, control_id,
control_title, control_objective, control_requirements,
sort_order, created_at
FROM training_block_control_links
WHERE module_id = $1
ORDER BY sort_order
`, moduleID)
if err != nil {
return nil, err
}
defer rows.Close()
var links []TrainingBlockControlLink
for rows.Next() {
var link TrainingBlockControlLink
var requirements []byte
if err := rows.Scan(
&link.ID, &link.BlockConfigID, &link.ModuleID, &link.ControlID,
&link.ControlTitle, &link.ControlObjective, &requirements,
&link.SortOrder, &link.CreatedAt,
); err != nil {
return nil, err
}
json.Unmarshal(requirements, &link.ControlRequirements)
if link.ControlRequirements == nil {
link.ControlRequirements = []string{}
}
links = append(links, link)
}
if links == nil {
links = []TrainingBlockControlLink{}
}
return links, nil
}
// ============================================================================
// Canonical Controls Query (reads from shared DB table)
// ============================================================================
// QueryCanonicalControls queries canonical_controls with dynamic filters.
// Domain is derived from the control_id prefix (e.g. "AUTH" from "AUTH-042").
func (s *Store) QueryCanonicalControls(ctx context.Context,
domain, category, severity, targetAudience string,
) ([]CanonicalControlSummary, error) {
query := `SELECT control_id, title, objective, rationale,
requirements, severity, COALESCE(category, ''), COALESCE(target_audience, ''), COALESCE(tags, '[]')
FROM canonical_controls
WHERE release_state NOT IN ('deprecated', 'draft')
AND customer_visible = true`
args := []interface{}{}
argIdx := 1
if domain != "" {
query += fmt.Sprintf(` AND LEFT(control_id, %d) = $%d`, len(domain), argIdx)
args = append(args, domain)
argIdx++
}
if category != "" {
query += fmt.Sprintf(` AND category = $%d`, argIdx)
args = append(args, category)
argIdx++
}
if severity != "" {
query += fmt.Sprintf(` AND severity = $%d`, argIdx)
args = append(args, severity)
argIdx++
}
if targetAudience != "" {
query += fmt.Sprintf(` AND (target_audience = $%d OR target_audience = 'all')`, argIdx)
args = append(args, targetAudience)
argIdx++
}
query += ` ORDER BY control_id`
rows, err := s.pool.Query(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("query canonical controls: %w", err)
}
defer rows.Close()
var controls []CanonicalControlSummary
for rows.Next() {
var c CanonicalControlSummary
var requirementsJSON, tagsJSON []byte
if err := rows.Scan(
&c.ControlID, &c.Title, &c.Objective, &c.Rationale,
&requirementsJSON, &c.Severity, &c.Category, &c.TargetAudience, &tagsJSON,
); err != nil {
return nil, err
}
json.Unmarshal(requirementsJSON, &c.Requirements)
if c.Requirements == nil {
c.Requirements = []string{}
}
json.Unmarshal(tagsJSON, &c.Tags)
if c.Tags == nil {
c.Tags = []string{}
}
controls = append(controls, c)
}
if controls == nil {
controls = []CanonicalControlSummary{}
}
return controls, nil
}
// GetCanonicalControlMeta returns aggregated metadata about canonical controls
func (s *Store) GetCanonicalControlMeta(ctx context.Context) (*CanonicalControlMeta, error) {
meta := &CanonicalControlMeta{}
// Total count
err := s.pool.QueryRow(ctx, `
SELECT COUNT(*) FROM canonical_controls
WHERE release_state NOT IN ('deprecated', 'draft') AND customer_visible = true
`).Scan(&meta.Total)
if err != nil {
return nil, fmt.Errorf("count canonical controls: %w", err)
}
// Domains (derived from control_id prefix)
rows, err := s.pool.Query(ctx, `
SELECT LEFT(control_id, POSITION('-' IN control_id) - 1) AS domain, COUNT(*) AS cnt
FROM canonical_controls
WHERE release_state NOT IN ('deprecated', 'draft') AND customer_visible = true
AND POSITION('-' IN control_id) > 0
GROUP BY domain ORDER BY cnt DESC
`)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var d DomainCount
if err := rows.Scan(&d.Domain, &d.Count); err != nil {
return nil, err
}
meta.Domains = append(meta.Domains, d)
}
if meta.Domains == nil {
meta.Domains = []DomainCount{}
}
// Categories
catRows, err := s.pool.Query(ctx, `
SELECT COALESCE(category, 'uncategorized') AS cat, COUNT(*) AS cnt
FROM canonical_controls
WHERE release_state NOT IN ('deprecated', 'draft') AND customer_visible = true
GROUP BY cat ORDER BY cnt DESC
`)
if err != nil {
return nil, err
}
defer catRows.Close()
for catRows.Next() {
var c CategoryCount
if err := catRows.Scan(&c.Category, &c.Count); err != nil {
return nil, err
}
meta.Categories = append(meta.Categories, c)
}
if meta.Categories == nil {
meta.Categories = []CategoryCount{}
}
// Target audiences
audRows, err := s.pool.Query(ctx, `
SELECT COALESCE(target_audience, 'unset') AS aud, COUNT(*) AS cnt
FROM canonical_controls
WHERE release_state NOT IN ('deprecated', 'draft') AND customer_visible = true
GROUP BY aud ORDER BY cnt DESC
`)
if err != nil {
return nil, err
}
defer audRows.Close()
for audRows.Next() {
var a AudienceCount
if err := audRows.Scan(&a.Audience, &a.Count); err != nil {
return nil, err
}
meta.Audiences = append(meta.Audiences, a)
}
if meta.Audiences == nil {
meta.Audiences = []AudienceCount{}
}
return meta, nil
}
// ============================================================================
// Helpers
// ============================================================================
// CountModulesWithPrefix counts existing modules with a given code prefix for auto-numbering
func (s *Store) CountModulesWithPrefix(ctx context.Context, tenantID uuid.UUID, prefix string) (int, error) {
var count int
err := s.pool.QueryRow(ctx, `
SELECT COUNT(*) FROM training_modules
WHERE tenant_id = $1 AND module_code LIKE $2
`, tenantID, prefix+"-%").Scan(&count)
return count, err
}
func nilIfEmpty(s string) *string {
s = strings.TrimSpace(s)
if s == "" {
return nil
}
return &s
}

View File

@@ -294,6 +294,133 @@ func parseQuizResponse(response string, moduleID uuid.UUID) ([]QuizQuestion, err
return questions, nil
}
// GenerateBlockContent generates training content for a module based on linked canonical controls
func (g *ContentGenerator) GenerateBlockContent(
ctx context.Context,
module TrainingModule,
controls []CanonicalControlSummary,
language string,
) (*ModuleContent, error) {
if language == "" {
language = "de"
}
prompt := buildBlockContentPrompt(module, controls, language)
resp, err := g.registry.Chat(ctx, &llm.ChatRequest{
Messages: []llm.Message{
{Role: "system", Content: getContentSystemPrompt(language)},
{Role: "user", Content: prompt},
},
Temperature: 0.15,
MaxTokens: 8192,
})
if err != nil {
return nil, fmt.Errorf("LLM block content generation failed: %w", err)
}
contentBody := resp.Message.Content
// PII check
if g.piiDetector != nil && g.piiDetector.ContainsPII(contentBody) {
findings := g.piiDetector.FindPII(contentBody)
for _, f := range findings {
contentBody = strings.ReplaceAll(contentBody, f.Match, "[REDACTED]")
}
}
summary := contentBody
if len(summary) > 200 {
summary = summary[:200] + "..."
}
content := &ModuleContent{
ModuleID: module.ID,
ContentFormat: ContentFormatMarkdown,
ContentBody: contentBody,
Summary: summary,
GeneratedBy: "llm_block_" + resp.Provider,
LLMModel: resp.Model,
IsPublished: false,
}
if err := g.store.CreateModuleContent(ctx, content); err != nil {
return nil, fmt.Errorf("failed to save block content: %w", err)
}
// Audit log
g.store.LogAction(ctx, &AuditLogEntry{
TenantID: module.TenantID,
Action: AuditActionContentGenerated,
EntityType: AuditEntityModule,
EntityID: &module.ID,
Details: map[string]interface{}{
"module_code": module.ModuleCode,
"provider": resp.Provider,
"model": resp.Model,
"content_id": content.ID.String(),
"version": content.Version,
"tokens_used": resp.Usage.TotalTokens,
"controls_count": len(controls),
"source": "block_generator",
},
})
return content, nil
}
// buildBlockContentPrompt creates a prompt that incorporates canonical controls
func buildBlockContentPrompt(module TrainingModule, controls []CanonicalControlSummary, language string) string {
var sb strings.Builder
if language == "en" {
sb.WriteString(fmt.Sprintf("Create training material for the following compliance module:\n\n"))
sb.WriteString(fmt.Sprintf("**Module Code:** %s\n", module.ModuleCode))
sb.WriteString(fmt.Sprintf("**Title:** %s\n", module.Title))
sb.WriteString(fmt.Sprintf("**Duration:** %d minutes\n\n", module.DurationMinutes))
sb.WriteString(fmt.Sprintf("This module is based on %d security controls:\n\n", len(controls)))
} else {
sb.WriteString(fmt.Sprintf("Erstelle Schulungsmaterial fuer folgendes Compliance-Modul:\n\n"))
sb.WriteString(fmt.Sprintf("**Modulcode:** %s\n", module.ModuleCode))
sb.WriteString(fmt.Sprintf("**Titel:** %s\n", module.Title))
sb.WriteString(fmt.Sprintf("**Dauer:** %d Minuten\n\n", module.DurationMinutes))
sb.WriteString(fmt.Sprintf("Dieses Modul basiert auf %d Sicherheits-Controls:\n\n", len(controls)))
}
for i, ctrl := range controls {
sb.WriteString(fmt.Sprintf("### Control %d: %s — %s\n", i+1, ctrl.ControlID, ctrl.Title))
sb.WriteString(fmt.Sprintf("**Ziel:** %s\n", ctrl.Objective))
if len(ctrl.Requirements) > 0 {
sb.WriteString("**Anforderungen:**\n")
for _, req := range ctrl.Requirements {
sb.WriteString(fmt.Sprintf("- %s\n", req))
}
}
sb.WriteString("\n")
}
if language == "en" {
sb.WriteString(`Create the material as Markdown:
1. Introduction: Why are these controls important?
2. Per control: Explanation, practical tips, examples
3. Summary + action items
4. Checklist for daily work
Use clear, understandable language. Target audience: employees in companies (50-1,500 employees).`)
} else {
sb.WriteString(`Erstelle das Material als Markdown:
1. Einfuehrung: Warum sind diese Controls wichtig?
2. Pro Control: Erklaerung, praktische Hinweise, Beispiele
3. Zusammenfassung + Handlungsanweisungen
4. Checkliste fuer den Alltag
Verwende klare, verstaendliche Sprache. Zielgruppe sind Mitarbeiter in Unternehmen (50-1.500 MA).
Formatiere den Inhalt als Markdown mit Ueberschriften, Aufzaehlungen und Hervorhebungen.`)
}
return sb.String()
}
// GenerateAllModuleContent generates text content for all modules that don't have published content yet
func (g *ContentGenerator) GenerateAllModuleContent(ctx context.Context, tenantID uuid.UUID, language string) (*BulkResult, error) {
if language == "" {
@@ -600,3 +727,252 @@ func truncateText(text string, maxLen int) string {
}
return text[:maxLen] + "..."
}
// ============================================================================
// Interactive Video Pipeline
// ============================================================================
const narratorSystemPrompt = `Du bist ein professioneller AI Teacher fuer Compliance-Schulungen.
Dein Stil ist foermlich aber freundlich, klar und paedagogisch wertvoll.
Du sprichst die Lernenden direkt an ("Sie") und fuehrst sie durch die Schulung.
Du erzeugst IMMER deutschsprachige Inhalte.
Dein Output ist ein JSON-Objekt im Format NarratorScript.
Jede Section sollte etwa 3 Minuten Sprechzeit haben (~450 Woerter Narrator-Text).
Nach jeder Section kommt ein Checkpoint mit 3-5 Quiz-Fragen.
Die Fragen testen das Verstaendnis des gerade Gelernten.
Jede Frage hat genau 4 Antwortmoeglichkeiten, wobei correct_index (0-basiert) die richtige Antwort angibt.
Antworte NUR mit dem JSON-Objekt, ohne Markdown-Codeblock-Wrapper.`
// GenerateNarratorScript generates a narrator-style video script with checkpoints via LLM
func (g *ContentGenerator) GenerateNarratorScript(ctx context.Context, module TrainingModule) (*NarratorScript, error) {
content, err := g.store.GetPublishedContent(ctx, module.ID)
if err != nil {
return nil, fmt.Errorf("failed to get content: %w", err)
}
contentContext := ""
if content != nil {
contentContext = fmt.Sprintf("\n\n**Vorhandener Schulungsinhalt (als Basis):**\n%s", truncateText(content.ContentBody, 4000))
}
prompt := fmt.Sprintf(`Erstelle ein interaktives Schulungsvideo-Skript mit Erzaehlerpersona und Checkpoints.
**Modul:** %s — %s
**Verordnung:** %s
**Beschreibung:** %s
**Dauer:** ca. %d Minuten
%s
Erstelle ein NarratorScript-JSON mit:
- "title": Titel der Schulung
- "intro": Begruessungstext ("Hallo, ich bin Ihr AI Teacher. Heute lernen Sie...")
- "sections": Array mit 3-4 Abschnitten, jeder mit:
- "heading": Abschnittsueberschrift
- "narrator_text": Fliesstext im Erzaehlstil (~450 Woerter, ~3 Min Sprechzeit)
- "bullet_points": 3-5 Kernpunkte fuer die Folie
- "transition": Ueberleitung zum naechsten Abschnitt oder Checkpoint
- "checkpoint": Quiz-Block mit:
- "title": Checkpoint-Titel
- "questions": Array mit 3-5 Fragen, je:
- "question": Fragetext
- "options": Array mit 4 Antworten
- "correct_index": Index der richtigen Antwort (0-basiert)
- "explanation": Erklaerung der richtigen Antwort
- "outro": Abschlussworte
- "total_duration_estimate": geschaetzte Gesamtdauer in Sekunden
Antworte NUR mit dem JSON-Objekt.`,
module.ModuleCode, module.Title,
string(module.RegulationArea),
module.Description,
module.DurationMinutes,
contentContext,
)
resp, err := g.registry.Chat(ctx, &llm.ChatRequest{
Messages: []llm.Message{
{Role: "system", Content: narratorSystemPrompt},
{Role: "user", Content: prompt},
},
Temperature: 0.2,
MaxTokens: 8192,
})
if err != nil {
return nil, fmt.Errorf("LLM narrator script generation failed: %w", err)
}
return parseNarratorScript(resp.Message.Content)
}
// parseNarratorScript extracts a NarratorScript from LLM output
func parseNarratorScript(content string) (*NarratorScript, error) {
// Find JSON object in response
start := strings.Index(content, "{")
end := strings.LastIndex(content, "}")
if start < 0 || end <= start {
return nil, fmt.Errorf("no JSON object found in LLM response")
}
jsonStr := content[start : end+1]
var script NarratorScript
if err := json.Unmarshal([]byte(jsonStr), &script); err != nil {
return nil, fmt.Errorf("failed to parse narrator script JSON: %w", err)
}
if len(script.Sections) == 0 {
return nil, fmt.Errorf("narrator script has no sections")
}
return &script, nil
}
// GenerateInteractiveVideo orchestrates the full interactive video pipeline:
// NarratorScript → TTS Audio → Slides+Video → DB Checkpoints + Quiz Questions
func (g *ContentGenerator) GenerateInteractiveVideo(ctx context.Context, module TrainingModule) (*TrainingMedia, error) {
if g.ttsClient == nil {
return nil, fmt.Errorf("TTS client not configured")
}
// 1. Generate NarratorScript via LLM
script, err := g.GenerateNarratorScript(ctx, module)
if err != nil {
return nil, fmt.Errorf("narrator script generation failed: %w", err)
}
// 2. Synthesize audio per section via TTS service
sections := make([]SectionAudio, len(script.Sections))
for i, s := range script.Sections {
// Combine narrator text with intro/outro for first/last section
text := s.NarratorText
if i == 0 && script.Intro != "" {
text = script.Intro + "\n\n" + text
}
if i == len(script.Sections)-1 && script.Outro != "" {
text = text + "\n\n" + script.Outro
}
sections[i] = SectionAudio{
Text: text,
Heading: s.Heading,
}
}
audioResp, err := g.ttsClient.SynthesizeSections(ctx, &SynthesizeSectionsRequest{
Sections: sections,
Voice: "de_DE-thorsten-high",
ModuleID: module.ID.String(),
})
if err != nil {
return nil, fmt.Errorf("section audio synthesis failed: %w", err)
}
// 3. Generate interactive video via TTS service
videoResp, err := g.ttsClient.GenerateInteractiveVideo(ctx, &GenerateInteractiveVideoRequest{
Script: script,
Audio: audioResp,
ModuleID: module.ID.String(),
})
if err != nil {
return nil, fmt.Errorf("interactive video generation failed: %w", err)
}
// 4. Save TrainingMedia record
scriptJSON, _ := json.Marshal(script)
media := &TrainingMedia{
ModuleID: module.ID,
MediaType: MediaTypeInteractiveVideo,
Status: MediaStatusProcessing,
Bucket: "compliance-training-video",
ObjectKey: fmt.Sprintf("video/%s/interactive.mp4", module.ID.String()),
MimeType: "video/mp4",
Language: "de",
GeneratedBy: "tts_ffmpeg_interactive",
Metadata: scriptJSON,
}
if err := g.store.CreateMedia(ctx, media); err != nil {
return nil, fmt.Errorf("failed to create media record: %w", err)
}
// Update media with video result
media.Status = MediaStatusCompleted
media.FileSizeBytes = videoResp.SizeBytes
media.DurationSeconds = videoResp.DurationSeconds
media.ObjectKey = videoResp.ObjectKey
media.Bucket = videoResp.Bucket
g.store.UpdateMediaStatus(ctx, media.ID, MediaStatusCompleted, videoResp.SizeBytes, videoResp.DurationSeconds, "")
// Auto-publish
g.store.PublishMedia(ctx, media.ID, true)
// 5. Create Checkpoints + Quiz Questions in DB
// Clear old checkpoints first
g.store.DeleteCheckpointsForModule(ctx, module.ID)
for i, section := range script.Sections {
if section.Checkpoint == nil {
continue
}
// Calculate timestamp from cumulative audio durations
var timestamp float64
if i < len(audioResp.Sections) {
// Checkpoint timestamp = end of this section's audio
timestamp = audioResp.Sections[i].StartTimestamp + audioResp.Sections[i].Duration
}
cp := &Checkpoint{
ModuleID: module.ID,
CheckpointIndex: i,
Title: section.Checkpoint.Title,
TimestampSeconds: timestamp,
}
if err := g.store.CreateCheckpoint(ctx, cp); err != nil {
return nil, fmt.Errorf("failed to create checkpoint %d: %w", i, err)
}
// Save quiz questions for this checkpoint
for j, q := range section.Checkpoint.Questions {
question := &QuizQuestion{
ModuleID: module.ID,
Question: q.Question,
Options: q.Options,
CorrectIndex: q.CorrectIndex,
Explanation: q.Explanation,
Difficulty: DifficultyMedium,
SortOrder: j,
}
if err := g.store.CreateCheckpointQuizQuestion(ctx, question, cp.ID); err != nil {
return nil, fmt.Errorf("failed to create checkpoint question: %w", err)
}
}
}
// 6. Audit log
g.store.LogAction(ctx, &AuditLogEntry{
TenantID: module.TenantID,
Action: AuditAction("interactive_video_generated"),
EntityType: AuditEntityModule,
EntityID: &module.ID,
Details: map[string]interface{}{
"module_code": module.ModuleCode,
"media_id": media.ID.String(),
"duration_seconds": videoResp.DurationSeconds,
"sections": len(script.Sections),
"checkpoints": countCheckpoints(script),
},
})
return media, nil
}
func countCheckpoints(script *NarratorScript) int {
count := 0
for _, s := range script.Sections {
if s.Checkpoint != nil {
count++
}
}
return count
}

View File

@@ -0,0 +1,552 @@
package training
import (
"testing"
"github.com/google/uuid"
)
// =============================================================================
// buildContentPrompt Tests
// =============================================================================
func TestBuildContentPrompt_ContainsModuleCode(t *testing.T) {
module := TrainingModule{
ModuleCode: "CP-TRAIN-001",
Title: "DSGVO Grundlagen",
Description: "Basis-Schulung",
RegulationArea: RegulationDSGVO,
DurationMinutes: 30,
}
prompt := buildContentPrompt(module, "de")
if !containsSubstring(prompt, "CP-TRAIN-001") {
t.Error("Prompt should contain module code")
}
}
func TestBuildContentPrompt_ContainsTitle(t *testing.T) {
module := TrainingModule{
ModuleCode: "CP-001",
Title: "DSGVO Grundlagen",
RegulationArea: RegulationDSGVO,
DurationMinutes: 30,
}
prompt := buildContentPrompt(module, "de")
if !containsSubstring(prompt, "DSGVO Grundlagen") {
t.Error("Prompt should contain module title")
}
}
func TestBuildContentPrompt_ContainsRegulationLabel(t *testing.T) {
tests := []struct {
name string
area RegulationArea
expected string
}{
{"DSGVO", RegulationDSGVO, "Datenschutz-Grundverordnung"},
{"NIS2", RegulationNIS2, "NIS-2-Richtlinie"},
{"ISO27001", RegulationISO27001, "ISO 27001"},
{"AIAct", RegulationAIAct, "AI Act"},
{"GeschGehG", RegulationGeschGehG, "Geschaeftsgeheimnisgesetz"},
{"HinSchG", RegulationHinSchG, "Hinweisgeberschutzgesetz"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
module := TrainingModule{
ModuleCode: "CP-001",
Title: "Test Module",
RegulationArea: tt.area,
DurationMinutes: 30,
}
prompt := buildContentPrompt(module, "de")
if !containsSubstring(prompt, tt.expected) {
t.Errorf("Prompt should contain regulation label '%s' for area '%s'", tt.expected, tt.area)
}
})
}
}
func TestBuildContentPrompt_ContainsDuration(t *testing.T) {
module := TrainingModule{
ModuleCode: "CP-001",
Title: "Test",
RegulationArea: RegulationDSGVO,
DurationMinutes: 45,
}
prompt := buildContentPrompt(module, "de")
if !containsSubstring(prompt, "45 Minuten") {
t.Error("Prompt should contain duration in minutes")
}
}
func TestBuildContentPrompt_UnknownRegulationArea(t *testing.T) {
module := TrainingModule{
ModuleCode: "CP-001",
Title: "Test",
RegulationArea: RegulationArea("custom_regulation"),
DurationMinutes: 30,
}
prompt := buildContentPrompt(module, "de")
if !containsSubstring(prompt, "custom_regulation") {
t.Error("Unknown regulation area should fall back to raw string")
}
}
// =============================================================================
// buildQuizPrompt Tests
// =============================================================================
func TestBuildQuizPrompt_ContainsQuestionCount(t *testing.T) {
module := TrainingModule{
ModuleCode: "CP-001",
Title: "Test Module",
RegulationArea: RegulationDSGVO,
}
prompt := buildQuizPrompt(module, "", 10)
if !containsSubstring(prompt, "10") {
t.Error("Quiz prompt should contain question count")
}
}
func TestBuildQuizPrompt_ContainsContentContext(t *testing.T) {
module := TrainingModule{
ModuleCode: "CP-001",
Title: "Test",
RegulationArea: RegulationDSGVO,
}
prompt := buildQuizPrompt(module, "This is the module content about DSGVO.", 5)
if !containsSubstring(prompt, "This is the module content about DSGVO.") {
t.Error("Quiz prompt should include content context")
}
}
func TestBuildQuizPrompt_TruncatesLongContent(t *testing.T) {
module := TrainingModule{
ModuleCode: "CP-001",
Title: "Test",
RegulationArea: RegulationDSGVO,
}
// Create content longer than 3000 chars
longContent := ""
for i := 0; i < 400; i++ {
longContent += "ABCDEFGHIJ" // 10 chars * 400 = 4000 chars
}
prompt := buildQuizPrompt(module, longContent, 5)
if containsSubstring(prompt, longContent) {
t.Error("Quiz prompt should truncate content longer than 3000 chars")
}
if !containsSubstring(prompt, "...") {
t.Error("Truncated content should end with '...'")
}
}
func TestBuildQuizPrompt_EmptyContent(t *testing.T) {
module := TrainingModule{
ModuleCode: "CP-001",
Title: "Test",
RegulationArea: RegulationDSGVO,
}
prompt := buildQuizPrompt(module, "", 5)
if containsSubstring(prompt, "Schulungsinhalt als Kontext") {
t.Error("Empty content should not add context section")
}
}
// =============================================================================
// parseQuizResponse Tests
// =============================================================================
func TestParseQuizResponse_ValidJSON(t *testing.T) {
moduleID := uuid.New()
response := `[
{
"question": "Was ist die DSGVO?",
"options": ["EU-Verordnung", "Bundesgesetz", "Landesgesetz", "Internationale Konvention"],
"correct_index": 0,
"explanation": "Die DSGVO ist eine EU-Verordnung.",
"difficulty": "easy"
}
]`
questions, err := parseQuizResponse(response, moduleID)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if len(questions) != 1 {
t.Fatalf("Expected 1 question, got %d", len(questions))
}
if questions[0].Question != "Was ist die DSGVO?" {
t.Errorf("Expected question text, got '%s'", questions[0].Question)
}
if questions[0].CorrectIndex != 0 {
t.Errorf("Expected correct_index 0, got %d", questions[0].CorrectIndex)
}
if questions[0].Difficulty != DifficultyEasy {
t.Errorf("Expected difficulty 'easy', got '%s'", questions[0].Difficulty)
}
if questions[0].ModuleID != moduleID {
t.Error("Module ID should be set on parsed question")
}
if !questions[0].IsActive {
t.Error("Parsed questions should be active by default")
}
}
func TestParseQuizResponse_InvalidJSON(t *testing.T) {
moduleID := uuid.New()
_, err := parseQuizResponse("not valid json at all", moduleID)
if err == nil {
t.Error("Expected error for invalid JSON")
}
}
func TestParseQuizResponse_JSONWithSurroundingText(t *testing.T) {
moduleID := uuid.New()
response := `Here are the questions:
[
{
"question": "Test?",
"options": ["A", "B", "C", "D"],
"correct_index": 1,
"explanation": "B is correct.",
"difficulty": "medium"
}
]
I hope these are helpful!`
questions, err := parseQuizResponse(response, moduleID)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if len(questions) != 1 {
t.Fatalf("Expected 1 question, got %d", len(questions))
}
}
func TestParseQuizResponse_SkipsMalformedOptions(t *testing.T) {
moduleID := uuid.New()
response := `[
{
"question": "Good question?",
"options": ["A", "B", "C", "D"],
"correct_index": 0,
"explanation": "A is correct.",
"difficulty": "easy"
},
{
"question": "Bad question?",
"options": ["A", "B"],
"correct_index": 0,
"explanation": "Only 2 options.",
"difficulty": "easy"
}
]`
questions, err := parseQuizResponse(response, moduleID)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if len(questions) != 1 {
t.Errorf("Expected 1 valid question (malformed should be skipped), got %d", len(questions))
}
}
func TestParseQuizResponse_SkipsInvalidCorrectIndex(t *testing.T) {
moduleID := uuid.New()
response := `[
{
"question": "Bad index?",
"options": ["A", "B", "C", "D"],
"correct_index": 5,
"explanation": "Index out of range.",
"difficulty": "medium"
}
]`
questions, err := parseQuizResponse(response, moduleID)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if len(questions) != 0 {
t.Errorf("Expected 0 questions (invalid index should be skipped), got %d", len(questions))
}
}
func TestParseQuizResponse_NegativeCorrectIndex(t *testing.T) {
moduleID := uuid.New()
response := `[
{
"question": "Negative index?",
"options": ["A", "B", "C", "D"],
"correct_index": -1,
"explanation": "Negative index.",
"difficulty": "easy"
}
]`
questions, err := parseQuizResponse(response, moduleID)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if len(questions) != 0 {
t.Errorf("Expected 0 questions (negative index should be skipped), got %d", len(questions))
}
}
func TestParseQuizResponse_DefaultsDifficultyToMedium(t *testing.T) {
moduleID := uuid.New()
response := `[
{
"question": "Test?",
"options": ["A", "B", "C", "D"],
"correct_index": 0,
"explanation": "A is correct.",
"difficulty": "unknown_difficulty"
}
]`
questions, err := parseQuizResponse(response, moduleID)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if len(questions) != 1 {
t.Fatalf("Expected 1 question, got %d", len(questions))
}
if questions[0].Difficulty != DifficultyMedium {
t.Errorf("Expected difficulty to default to 'medium', got '%s'", questions[0].Difficulty)
}
}
func TestParseQuizResponse_MultipleQuestions(t *testing.T) {
moduleID := uuid.New()
response := `[
{"question":"Q1?","options":["A","B","C","D"],"correct_index":0,"explanation":"","difficulty":"easy"},
{"question":"Q2?","options":["A","B","C","D"],"correct_index":1,"explanation":"","difficulty":"medium"},
{"question":"Q3?","options":["A","B","C","D"],"correct_index":2,"explanation":"","difficulty":"hard"}
]`
questions, err := parseQuizResponse(response, moduleID)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if len(questions) != 3 {
t.Errorf("Expected 3 questions, got %d", len(questions))
}
}
func TestParseQuizResponse_EmptyArray(t *testing.T) {
moduleID := uuid.New()
questions, err := parseQuizResponse("[]", moduleID)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if len(questions) != 0 {
t.Errorf("Expected 0 questions, got %d", len(questions))
}
}
// =============================================================================
// truncateText Tests
// =============================================================================
func TestTruncateText_ShortText(t *testing.T) {
result := truncateText("hello", 100)
if result != "hello" {
t.Errorf("Short text should not be truncated, got '%s'", result)
}
}
func TestTruncateText_ExactLength(t *testing.T) {
result := truncateText("12345", 5)
if result != "12345" {
t.Errorf("Text at exact max length should not be truncated, got '%s'", result)
}
}
func TestTruncateText_LongText(t *testing.T) {
result := truncateText("1234567890", 5)
if result != "12345..." {
t.Errorf("Expected '12345...', got '%s'", result)
}
}
func TestTruncateText_EmptyString(t *testing.T) {
result := truncateText("", 10)
if result != "" {
t.Errorf("Empty string should remain empty, got '%s'", result)
}
}
// =============================================================================
// System Prompt Tests
// =============================================================================
func TestGetContentSystemPrompt_German(t *testing.T) {
prompt := getContentSystemPrompt("de")
if !containsSubstring(prompt, "Compliance-Schulungsinhalte") {
t.Error("German system prompt should mention Compliance-Schulungsinhalte")
}
if !containsSubstring(prompt, "Markdown") {
t.Error("System prompt should mention Markdown format")
}
}
func TestGetContentSystemPrompt_English(t *testing.T) {
prompt := getContentSystemPrompt("en")
if !containsSubstring(prompt, "compliance training content") {
t.Error("English system prompt should mention compliance training content")
}
}
func TestGetQuizSystemPrompt_ContainsJSONFormat(t *testing.T) {
prompt := getQuizSystemPrompt()
if !containsSubstring(prompt, "JSON") {
t.Error("Quiz system prompt should mention JSON format")
}
if !containsSubstring(prompt, "correct_index") {
t.Error("Quiz system prompt should show correct_index field")
}
}
// =============================================================================
// buildBlockContentPrompt Tests
// =============================================================================
func TestBuildBlockContentPrompt_ContainsModuleInfo(t *testing.T) {
module := TrainingModule{
ModuleCode: "BLK-AUTH-001",
Title: "Authentication Controls",
DurationMinutes: 45,
}
controls := []CanonicalControlSummary{
{
ControlID: "AUTH-001",
Title: "Multi-Factor Authentication",
Objective: "Ensure MFA is enabled",
Requirements: []string{"Enable MFA for all users"},
},
}
prompt := buildBlockContentPrompt(module, controls, "de")
if !containsSubstring(prompt, "BLK-AUTH-001") {
t.Error("Block prompt should contain module code")
}
if !containsSubstring(prompt, "Authentication Controls") {
t.Error("Block prompt should contain module title")
}
if !containsSubstring(prompt, "45 Minuten") {
t.Error("Block prompt should contain duration")
}
}
func TestBuildBlockContentPrompt_ContainsControlDetails(t *testing.T) {
module := TrainingModule{
ModuleCode: "BLK-001",
Title: "Test",
DurationMinutes: 30,
}
controls := []CanonicalControlSummary{
{
ControlID: "CTRL-001",
Title: "Test Control",
Objective: "Test objective",
Requirements: []string{"Req 1", "Req 2"},
},
}
prompt := buildBlockContentPrompt(module, controls, "de")
if !containsSubstring(prompt, "CTRL-001") {
t.Error("Prompt should contain control ID")
}
if !containsSubstring(prompt, "Test Control") {
t.Error("Prompt should contain control title")
}
if !containsSubstring(prompt, "Test objective") {
t.Error("Prompt should contain control objective")
}
if !containsSubstring(prompt, "Req 1") {
t.Error("Prompt should contain control requirements")
}
}
func TestBuildBlockContentPrompt_EnglishVersion(t *testing.T) {
module := TrainingModule{
ModuleCode: "BLK-001",
Title: "Test",
DurationMinutes: 30,
}
controls := []CanonicalControlSummary{}
prompt := buildBlockContentPrompt(module, controls, "en")
if !containsSubstring(prompt, "Create training material") {
t.Error("English prompt should use English text")
}
}
func TestBuildBlockContentPrompt_MultipleControls(t *testing.T) {
module := TrainingModule{
ModuleCode: "BLK-001",
Title: "Test",
DurationMinutes: 30,
}
controls := []CanonicalControlSummary{
{ControlID: "CTRL-001", Title: "First Control", Objective: "Obj 1"},
{ControlID: "CTRL-002", Title: "Second Control", Objective: "Obj 2"},
{ControlID: "CTRL-003", Title: "Third Control", Objective: "Obj 3"},
}
prompt := buildBlockContentPrompt(module, controls, "de")
if !containsSubstring(prompt, "3 Sicherheits-Controls") {
t.Error("Prompt should mention the count of controls")
}
if !containsSubstring(prompt, "Control 1") {
t.Error("Prompt should number controls")
}
if !containsSubstring(prompt, "Control 3") {
t.Error("Prompt should include all controls")
}
}
// =============================================================================
// Helpers
// =============================================================================
func containsSubstring(s, substr string) bool {
return len(s) >= len(substr) && searchSubstring(s, substr)
}
func searchSubstring(s, substr string) bool {
if len(substr) == 0 {
return true
}
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}

View File

@@ -0,0 +1,159 @@
package training
import (
"testing"
)
// =============================================================================
// Escalation Threshold Tests
// =============================================================================
func TestEscalationThresholds_Values(t *testing.T) {
tests := []struct {
name string
threshold int
expected int
}{
{"L1 is 7 days", EscalationThresholdL1, 7},
{"L2 is 14 days", EscalationThresholdL2, 14},
{"L3 is 30 days", EscalationThresholdL3, 30},
{"L4 is 45 days", EscalationThresholdL4, 45},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.threshold != tt.expected {
t.Errorf("Expected %d, got %d", tt.expected, tt.threshold)
}
})
}
}
func TestEscalationThresholds_Ascending(t *testing.T) {
if EscalationThresholdL1 >= EscalationThresholdL2 {
t.Errorf("L1 (%d) should be < L2 (%d)", EscalationThresholdL1, EscalationThresholdL2)
}
if EscalationThresholdL2 >= EscalationThresholdL3 {
t.Errorf("L2 (%d) should be < L3 (%d)", EscalationThresholdL2, EscalationThresholdL3)
}
if EscalationThresholdL3 >= EscalationThresholdL4 {
t.Errorf("L3 (%d) should be < L4 (%d)", EscalationThresholdL3, EscalationThresholdL4)
}
}
// =============================================================================
// Escalation Label Tests
// =============================================================================
func TestEscalationLabels_AllLevelsPresent(t *testing.T) {
expectedLevels := []int{0, 1, 2, 3, 4}
for _, level := range expectedLevels {
label, ok := EscalationLabels[level]
if !ok {
t.Errorf("Missing label for escalation level %d", level)
}
if label == "" {
t.Errorf("Empty label for escalation level %d", level)
}
}
}
func TestEscalationLabels_Level0_NoEscalation(t *testing.T) {
label := EscalationLabels[0]
if label != "Keine Eskalation" {
t.Errorf("Expected 'Keine Eskalation', got '%s'", label)
}
}
func TestEscalationLabels_Level4_ComplianceOfficer(t *testing.T) {
label := EscalationLabels[4]
if label != "Benachrichtigung Compliance Officer" {
t.Errorf("Expected 'Benachrichtigung Compliance Officer', got '%s'", label)
}
}
func TestEscalationLabels_NoExtraLevels(t *testing.T) {
if len(EscalationLabels) != 5 {
t.Errorf("Expected exactly 5 escalation levels (0-4), got %d", len(EscalationLabels))
}
}
func TestEscalationLabels_LevelContent(t *testing.T) {
tests := []struct {
level int
contains string
}{
{1, "Mitarbeiter"},
{2, "Teamleitung"},
{3, "Management"},
{4, "Compliance Officer"},
}
for _, tt := range tests {
t.Run(EscalationLabels[tt.level], func(t *testing.T) {
label := EscalationLabels[tt.level]
if label == "" {
t.Fatalf("Label for level %d is empty", tt.level)
}
found := false
for i := 0; i <= len(label)-len(tt.contains); i++ {
if label[i:i+len(tt.contains)] == tt.contains {
found = true
break
}
}
if !found {
t.Errorf("Label '%s' should contain '%s'", label, tt.contains)
}
})
}
}
// =============================================================================
// Role Constants and Labels Tests
// =============================================================================
func TestRoleLabels_AllRolesHaveLabels(t *testing.T) {
roles := []string{RoleR1, RoleR2, RoleR3, RoleR4, RoleR5, RoleR6, RoleR7, RoleR8, RoleR9, RoleR10}
for _, role := range roles {
label, ok := RoleLabels[role]
if !ok {
t.Errorf("Missing label for role %s", role)
}
if label == "" {
t.Errorf("Empty label for role %s", role)
}
}
}
func TestNIS2RoleMapping_AllRolesMapped(t *testing.T) {
roles := []string{RoleR1, RoleR2, RoleR3, RoleR4, RoleR5, RoleR6, RoleR7, RoleR8, RoleR9, RoleR10}
for _, role := range roles {
nis2Level, ok := NIS2RoleMapping[role]
if !ok {
t.Errorf("Missing NIS2 mapping for role %s", role)
}
if nis2Level == "" {
t.Errorf("Empty NIS2 level for role %s", role)
}
}
}
func TestTargetAudienceRoleMapping_AllAudiencesPresent(t *testing.T) {
audiences := []string{"enterprise", "authority", "provider", "all"}
for _, aud := range audiences {
roles, ok := TargetAudienceRoleMapping[aud]
if !ok {
t.Errorf("Missing audience mapping for '%s'", aud)
}
if len(roles) == 0 {
t.Errorf("Empty roles for audience '%s'", aud)
}
}
}
func TestTargetAudienceRoleMapping_AllContainsAllRoles(t *testing.T) {
allRoles := TargetAudienceRoleMapping["all"]
if len(allRoles) != 10 {
t.Errorf("Expected 'all' audience to map to 10 roles, got %d", len(allRoles))
}
}

View File

@@ -0,0 +1,801 @@
package training
import (
"encoding/json"
"testing"
)
// =============================================================================
// parseNarratorScript Tests
// =============================================================================
func TestParseNarratorScript_ValidJSON(t *testing.T) {
input := `{
"title": "DSGVO Grundlagen",
"intro": "Hallo, ich bin Ihr AI Teacher.",
"sections": [
{
"heading": "Einfuehrung",
"narrator_text": "Willkommen zur Schulung ueber die DSGVO.",
"bullet_points": ["Punkt 1", "Punkt 2"],
"transition": "Bevor wir fortfahren...",
"checkpoint": {
"title": "Checkpoint 1",
"questions": [
{
"question": "Was ist die DSGVO?",
"options": ["EU-Verordnung", "Bundesgesetz", "Landesgesetz", "Internationale Konvention"],
"correct_index": 0,
"explanation": "Die DSGVO ist eine EU-Verordnung."
}
]
}
}
],
"outro": "Vielen Dank fuer Ihre Aufmerksamkeit.",
"total_duration_estimate": 600
}`
script, err := parseNarratorScript(input)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if script.Title != "DSGVO Grundlagen" {
t.Errorf("Expected title 'DSGVO Grundlagen', got '%s'", script.Title)
}
if script.Intro != "Hallo, ich bin Ihr AI Teacher." {
t.Errorf("Expected intro text, got '%s'", script.Intro)
}
if len(script.Sections) != 1 {
t.Fatalf("Expected 1 section, got %d", len(script.Sections))
}
if script.Sections[0].Heading != "Einfuehrung" {
t.Errorf("Expected heading 'Einfuehrung', got '%s'", script.Sections[0].Heading)
}
if script.Sections[0].Checkpoint == nil {
t.Fatal("Expected checkpoint, got nil")
}
if len(script.Sections[0].Checkpoint.Questions) != 1 {
t.Fatalf("Expected 1 question, got %d", len(script.Sections[0].Checkpoint.Questions))
}
if script.Sections[0].Checkpoint.Questions[0].CorrectIndex != 0 {
t.Errorf("Expected correct_index 0, got %d", script.Sections[0].Checkpoint.Questions[0].CorrectIndex)
}
if script.Outro != "Vielen Dank fuer Ihre Aufmerksamkeit." {
t.Errorf("Expected outro text, got '%s'", script.Outro)
}
if script.TotalDurationEstimate != 600 {
t.Errorf("Expected 600 seconds estimate, got %d", script.TotalDurationEstimate)
}
}
func TestParseNarratorScript_WithSurroundingText(t *testing.T) {
input := `Here is the narrator script:
{
"title": "NIS-2 Schulung",
"intro": "Willkommen",
"sections": [
{
"heading": "Abschnitt 1",
"narrator_text": "Text hier.",
"bullet_points": ["BP1"],
"transition": "Weiter"
}
],
"outro": "Ende",
"total_duration_estimate": 300
}
I hope this helps!`
script, err := parseNarratorScript(input)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if script.Title != "NIS-2 Schulung" {
t.Errorf("Expected title 'NIS-2 Schulung', got '%s'", script.Title)
}
}
func TestParseNarratorScript_InvalidJSON(t *testing.T) {
_, err := parseNarratorScript("not valid json")
if err == nil {
t.Error("Expected error for invalid JSON")
}
}
func TestParseNarratorScript_NoSections(t *testing.T) {
input := `{"title": "Test", "intro": "Hi", "sections": [], "outro": "Bye", "total_duration_estimate": 0}`
_, err := parseNarratorScript(input)
if err == nil {
t.Error("Expected error for empty sections")
}
}
func TestParseNarratorScript_NoJSON(t *testing.T) {
_, err := parseNarratorScript("Just plain text without any JSON")
if err == nil {
t.Error("Expected error when no JSON object found")
}
}
func TestParseNarratorScript_SectionWithoutCheckpoint(t *testing.T) {
input := `{
"title": "Test",
"intro": "Hi",
"sections": [
{
"heading": "Section 1",
"narrator_text": "Some text",
"bullet_points": ["P1"],
"transition": "Next"
}
],
"outro": "Bye",
"total_duration_estimate": 180
}`
script, err := parseNarratorScript(input)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if script.Sections[0].Checkpoint != nil {
t.Error("Section without checkpoint definition should have nil Checkpoint")
}
}
func TestParseNarratorScript_MultipleSectionsWithCheckpoints(t *testing.T) {
input := `{
"title": "Multi-Section",
"intro": "Start",
"sections": [
{
"heading": "S1",
"narrator_text": "Text 1",
"bullet_points": [],
"transition": "T1",
"checkpoint": {
"title": "CP1",
"questions": [
{"question": "Q1?", "options": ["A", "B", "C", "D"], "correct_index": 0, "explanation": "E1"},
{"question": "Q2?", "options": ["A", "B", "C", "D"], "correct_index": 1, "explanation": "E2"}
]
}
},
{
"heading": "S2",
"narrator_text": "Text 2",
"bullet_points": ["BP"],
"transition": "T2",
"checkpoint": {
"title": "CP2",
"questions": [
{"question": "Q3?", "options": ["A", "B", "C", "D"], "correct_index": 2, "explanation": "E3"}
]
}
},
{
"heading": "S3",
"narrator_text": "Text 3",
"bullet_points": [],
"transition": "T3"
}
],
"outro": "End",
"total_duration_estimate": 900
}`
script, err := parseNarratorScript(input)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if len(script.Sections) != 3 {
t.Fatalf("Expected 3 sections, got %d", len(script.Sections))
}
if script.Sections[0].Checkpoint == nil {
t.Error("Section 0 should have a checkpoint")
}
if len(script.Sections[0].Checkpoint.Questions) != 2 {
t.Errorf("Section 0 checkpoint should have 2 questions, got %d", len(script.Sections[0].Checkpoint.Questions))
}
if script.Sections[1].Checkpoint == nil {
t.Error("Section 1 should have a checkpoint")
}
if script.Sections[2].Checkpoint != nil {
t.Error("Section 2 should not have a checkpoint")
}
}
// =============================================================================
// countCheckpoints Tests
// =============================================================================
func TestCountCheckpoints_WithCheckpoints(t *testing.T) {
script := &NarratorScript{
Sections: []NarratorSection{
{Checkpoint: &CheckpointDefinition{Title: "CP1"}},
{Checkpoint: nil},
{Checkpoint: &CheckpointDefinition{Title: "CP3"}},
},
}
count := countCheckpoints(script)
if count != 2 {
t.Errorf("Expected 2 checkpoints, got %d", count)
}
}
func TestCountCheckpoints_NoCheckpoints(t *testing.T) {
script := &NarratorScript{
Sections: []NarratorSection{
{Heading: "S1"},
{Heading: "S2"},
},
}
count := countCheckpoints(script)
if count != 0 {
t.Errorf("Expected 0 checkpoints, got %d", count)
}
}
func TestCountCheckpoints_EmptySections(t *testing.T) {
script := &NarratorScript{}
count := countCheckpoints(script)
if count != 0 {
t.Errorf("Expected 0 checkpoints, got %d", count)
}
}
// =============================================================================
// NarratorScript JSON Serialization Tests
// =============================================================================
func TestNarratorScript_JSONRoundTrip(t *testing.T) {
original := NarratorScript{
Title: "Test",
Intro: "Hello",
Sections: []NarratorSection{
{
Heading: "H1",
NarratorText: "NT1",
BulletPoints: []string{"BP1"},
Transition: "T1",
Checkpoint: &CheckpointDefinition{
Title: "CP1",
Questions: []CheckpointQuestion{
{
Question: "Q?",
Options: []string{"A", "B", "C", "D"},
CorrectIndex: 2,
Explanation: "C is correct",
},
},
},
},
},
Outro: "Bye",
TotalDurationEstimate: 600,
}
data, err := json.Marshal(original)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
var decoded NarratorScript
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if decoded.Title != original.Title {
t.Errorf("Title mismatch: %s != %s", decoded.Title, original.Title)
}
if len(decoded.Sections) != 1 {
t.Fatalf("Expected 1 section, got %d", len(decoded.Sections))
}
if decoded.Sections[0].Checkpoint == nil {
t.Fatal("Checkpoint should not be nil after round-trip")
}
if decoded.Sections[0].Checkpoint.Questions[0].CorrectIndex != 2 {
t.Errorf("CorrectIndex mismatch: got %d", decoded.Sections[0].Checkpoint.Questions[0].CorrectIndex)
}
}
// =============================================================================
// InteractiveVideoManifest Tests
// =============================================================================
func TestInteractiveVideoManifest_JSON(t *testing.T) {
manifest := InteractiveVideoManifest{
StreamURL: "https://example.com/video.mp4",
Checkpoints: []CheckpointManifestEntry{
{
Index: 0,
Title: "CP1",
TimestampSeconds: 180.5,
Questions: []CheckpointQuestion{
{
Question: "Q?",
Options: []string{"A", "B", "C", "D"},
CorrectIndex: 1,
Explanation: "B",
},
},
},
},
}
data, err := json.Marshal(manifest)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
var decoded InteractiveVideoManifest
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if len(decoded.Checkpoints) != 1 {
t.Fatalf("Expected 1 checkpoint, got %d", len(decoded.Checkpoints))
}
if decoded.Checkpoints[0].TimestampSeconds != 180.5 {
t.Errorf("Timestamp mismatch: got %f", decoded.Checkpoints[0].TimestampSeconds)
}
}
// =============================================================================
// SubmitCheckpointQuizRequest/Response Tests
// =============================================================================
func TestSubmitCheckpointQuizResponse_JSON(t *testing.T) {
resp := SubmitCheckpointQuizResponse{
Passed: true,
Score: 80.0,
Feedback: []CheckpointQuizFeedback{
{Question: "Q1?", Correct: true, Explanation: "Correct!"},
{Question: "Q2?", Correct: false, Explanation: "Wrong answer."},
},
}
data, err := json.Marshal(resp)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
var decoded SubmitCheckpointQuizResponse
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if !decoded.Passed {
t.Error("Expected passed=true")
}
if decoded.Score != 80.0 {
t.Errorf("Expected score 80.0, got %f", decoded.Score)
}
if len(decoded.Feedback) != 2 {
t.Fatalf("Expected 2 feedback items, got %d", len(decoded.Feedback))
}
if decoded.Feedback[1].Correct {
t.Error("Second feedback should be incorrect")
}
}
// =============================================================================
// narratorSystemPrompt Tests
// =============================================================================
func TestNarratorSystemPrompt_ContainsKeyPhrases(t *testing.T) {
if !containsSubstring(narratorSystemPrompt, "AI Teacher") {
t.Error("System prompt should mention AI Teacher")
}
if !containsSubstring(narratorSystemPrompt, "Checkpoint") {
t.Error("System prompt should mention Checkpoint")
}
if !containsSubstring(narratorSystemPrompt, "JSON") {
t.Error("System prompt should mention JSON format")
}
if !containsSubstring(narratorSystemPrompt, "correct_index") {
t.Error("System prompt should mention correct_index")
}
}
// =============================================================================
// Checkpoint Grading Logic Tests (User Journey: Learner scores quiz)
// =============================================================================
func TestCheckpointGrading_AllCorrect_ScoreIs100(t *testing.T) {
questions := []CheckpointQuestion{
{Question: "Q1?", Options: []string{"A", "B", "C", "D"}, CorrectIndex: 0},
{Question: "Q2?", Options: []string{"A", "B", "C", "D"}, CorrectIndex: 1},
{Question: "Q3?", Options: []string{"A", "B", "C", "D"}, CorrectIndex: 2},
}
answers := []int{0, 1, 2}
correctCount := 0
for i, q := range questions {
if i < len(answers) && answers[i] == q.CorrectIndex {
correctCount++
}
}
score := float64(correctCount) / float64(len(questions)) * 100
passed := score >= 70
if score != 100.0 {
t.Errorf("Expected score 100, got %f", score)
}
if !passed {
t.Error("Expected passed=true with 100% score")
}
}
func TestCheckpointGrading_NoneCorrect_ScoreIs0(t *testing.T) {
questions := []CheckpointQuestion{
{Question: "Q1?", Options: []string{"A", "B", "C", "D"}, CorrectIndex: 0},
{Question: "Q2?", Options: []string{"A", "B", "C", "D"}, CorrectIndex: 1},
{Question: "Q3?", Options: []string{"A", "B", "C", "D"}, CorrectIndex: 2},
}
answers := []int{3, 3, 3}
correctCount := 0
for i, q := range questions {
if i < len(answers) && answers[i] == q.CorrectIndex {
correctCount++
}
}
score := float64(correctCount) / float64(len(questions)) * 100
passed := score >= 70
if score != 0.0 {
t.Errorf("Expected score 0, got %f", score)
}
if passed {
t.Error("Expected passed=false with 0% score")
}
}
func TestCheckpointGrading_ExactlyAt70Percent_Passes(t *testing.T) {
// 7 out of 10 correct = 70% — exactly at threshold
questions := make([]CheckpointQuestion, 10)
answers := make([]int, 10)
for i := 0; i < 10; i++ {
questions[i] = CheckpointQuestion{
Question: "Q?", Options: []string{"A", "B", "C", "D"}, CorrectIndex: 0,
}
if i < 7 {
answers[i] = 0 // correct
} else {
answers[i] = 1 // wrong
}
}
correctCount := 0
for i, q := range questions {
if i < len(answers) && answers[i] == q.CorrectIndex {
correctCount++
}
}
score := float64(correctCount) / float64(len(questions)) * 100
passed := score >= 70
if score != 70.0 {
t.Errorf("Expected score 70, got %f", score)
}
if !passed {
t.Error("Expected passed=true at exactly 70%")
}
}
func TestCheckpointGrading_JustBelow70Percent_Fails(t *testing.T) {
// 2 out of 3 correct = 66.67% — below threshold
questions := []CheckpointQuestion{
{Question: "Q1?", Options: []string{"A", "B", "C", "D"}, CorrectIndex: 0},
{Question: "Q2?", Options: []string{"A", "B", "C", "D"}, CorrectIndex: 1},
{Question: "Q3?", Options: []string{"A", "B", "C", "D"}, CorrectIndex: 2},
}
answers := []int{0, 1, 3} // 2 correct, 1 wrong
correctCount := 0
for i, q := range questions {
if i < len(answers) && answers[i] == q.CorrectIndex {
correctCount++
}
}
score := float64(correctCount) / float64(len(questions)) * 100
passed := score >= 70
if passed {
t.Errorf("Expected passed=false at %.2f%%", score)
}
}
func TestCheckpointGrading_FewerAnswersThanQuestions_MarksUnansweredWrong(t *testing.T) {
questions := []CheckpointQuestion{
{Question: "Q1?", Options: []string{"A", "B", "C", "D"}, CorrectIndex: 0},
{Question: "Q2?", Options: []string{"A", "B", "C", "D"}, CorrectIndex: 1},
{Question: "Q3?", Options: []string{"A", "B", "C", "D"}, CorrectIndex: 2},
}
answers := []int{0} // Only 1 answer for 3 questions
correctCount := 0
for i, q := range questions {
if i < len(answers) && answers[i] == q.CorrectIndex {
correctCount++
}
}
if correctCount != 1 {
t.Errorf("Expected 1 correct, got %d", correctCount)
}
score := float64(correctCount) / float64(len(questions)) * 100
if score > 34 {
t.Errorf("Expected score ~33.3%%, got %f", score)
}
}
func TestCheckpointGrading_EmptyAnswers_AllWrong(t *testing.T) {
questions := []CheckpointQuestion{
{Question: "Q1?", Options: []string{"A", "B"}, CorrectIndex: 0},
{Question: "Q2?", Options: []string{"A", "B"}, CorrectIndex: 1},
}
answers := []int{}
correctCount := 0
for i, q := range questions {
if i < len(answers) && answers[i] == q.CorrectIndex {
correctCount++
}
}
if correctCount != 0 {
t.Errorf("Expected 0 correct with empty answers, got %d", correctCount)
}
}
// =============================================================================
// Feedback Generation Tests (User Journey: Learner sees feedback)
// =============================================================================
func TestCheckpointFeedback_CorrectAnswerGetsCorrectFlag(t *testing.T) {
questions := []CheckpointQuestion{
{Question: "Was ist DSGVO?", Options: []string{"EU-Verordnung", "Bundesgesetz"}, CorrectIndex: 0, Explanation: "EU-Verordnung"},
{Question: "Wer ist DSB?", Options: []string{"IT-Leiter", "Datenschutzbeauftragter"}, CorrectIndex: 1, Explanation: "DSB Rolle"},
}
answers := []int{0, 0} // First correct, second wrong
feedback := make([]CheckpointQuizFeedback, len(questions))
for i, q := range questions {
isCorrect := false
if i < len(answers) && answers[i] == q.CorrectIndex {
isCorrect = true
}
feedback[i] = CheckpointQuizFeedback{
Question: q.Question,
Correct: isCorrect,
Explanation: q.Explanation,
}
}
if !feedback[0].Correct {
t.Error("First answer should be marked correct")
}
if feedback[1].Correct {
t.Error("Second answer should be marked incorrect")
}
if feedback[0].Question != "Was ist DSGVO?" {
t.Errorf("Unexpected question text: %s", feedback[0].Question)
}
if feedback[1].Explanation != "DSB Rolle" {
t.Errorf("Explanation should be preserved: got %s", feedback[1].Explanation)
}
}
// =============================================================================
// NarratorScript Pipeline Tests (User Journey: Admin generates video)
// =============================================================================
func TestNarratorScript_SectionCounting(t *testing.T) {
tests := []struct {
name string
sectionCount int
checkpointCount int
}{
{"3 sections, all with checkpoints", 3, 3},
{"4 sections, 2 with checkpoints", 4, 2},
{"1 section, no checkpoint", 1, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sections := make([]NarratorSection, tt.sectionCount)
cpAdded := 0
for i := 0; i < tt.sectionCount; i++ {
sections[i] = NarratorSection{
Heading: "Section",
NarratorText: "Text",
BulletPoints: []string{},
Transition: "Next",
}
if cpAdded < tt.checkpointCount {
sections[i].Checkpoint = &CheckpointDefinition{
Title: "CP",
Questions: []CheckpointQuestion{{Question: "Q?", Options: []string{"A", "B"}, CorrectIndex: 0}},
}
cpAdded++
}
}
script := &NarratorScript{
Title: "Test",
Intro: "Hi",
Sections: sections,
Outro: "Bye",
}
if len(script.Sections) != tt.sectionCount {
t.Errorf("Expected %d sections, got %d", tt.sectionCount, len(script.Sections))
}
if countCheckpoints(script) != tt.checkpointCount {
t.Errorf("Expected %d checkpoints, got %d", tt.checkpointCount, countCheckpoints(script))
}
})
}
}
func TestNarratorScript_SectionAudioConversion(t *testing.T) {
// Verify NarratorSection can be converted to SectionAudio for TTS
sections := []NarratorSection{
{Heading: "Einleitung", NarratorText: "Willkommen zur Schulung."},
{Heading: "Hauptteil", NarratorText: "Hier lernen Sie die Grundlagen."},
}
audioSections := make([]SectionAudio, len(sections))
for i, s := range sections {
audioSections[i] = SectionAudio{
Text: s.NarratorText,
Heading: s.Heading,
}
}
if len(audioSections) != 2 {
t.Fatalf("Expected 2 audio sections, got %d", len(audioSections))
}
if audioSections[0].Heading != "Einleitung" {
t.Errorf("Expected heading 'Einleitung', got '%s'", audioSections[0].Heading)
}
if audioSections[1].Text != "Hier lernen Sie die Grundlagen." {
t.Errorf("Unexpected text: '%s'", audioSections[1].Text)
}
}
// =============================================================================
// InteractiveVideoManifest Progress Tests (User Journey: Learner resumes)
// =============================================================================
func TestManifest_IdentifiesNextUnpassedCheckpoint(t *testing.T) {
manifest := InteractiveVideoManifest{
StreamURL: "https://example.com/video.mp4",
Checkpoints: []CheckpointManifestEntry{
{Index: 0, Title: "CP1", TimestampSeconds: 180, Progress: &CheckpointProgress{Passed: true}},
{Index: 1, Title: "CP2", TimestampSeconds: 360, Progress: &CheckpointProgress{Passed: false}},
{Index: 2, Title: "CP3", TimestampSeconds: 540, Progress: nil},
},
}
var nextUnpassed *CheckpointManifestEntry
for i := range manifest.Checkpoints {
cp := &manifest.Checkpoints[i]
if cp.Progress == nil || !cp.Progress.Passed {
nextUnpassed = cp
break
}
}
if nextUnpassed == nil {
t.Fatal("Expected to find an unpassed checkpoint")
}
if nextUnpassed.Index != 1 {
t.Errorf("Expected next unpassed at index 1, got %d", nextUnpassed.Index)
}
if nextUnpassed.Title != "CP2" {
t.Errorf("Expected CP2, got %s", nextUnpassed.Title)
}
}
func TestManifest_AllCheckpointsPassed(t *testing.T) {
manifest := InteractiveVideoManifest{
Checkpoints: []CheckpointManifestEntry{
{Index: 0, Progress: &CheckpointProgress{Passed: true}},
{Index: 1, Progress: &CheckpointProgress{Passed: true}},
},
}
allPassed := true
for _, cp := range manifest.Checkpoints {
if cp.Progress == nil || !cp.Progress.Passed {
allPassed = false
break
}
}
if !allPassed {
t.Error("Expected all checkpoints to be passed")
}
}
func TestManifest_NoCheckpoints_AllPassedIsTrue(t *testing.T) {
manifest := InteractiveVideoManifest{
Checkpoints: []CheckpointManifestEntry{},
}
allPassed := true
for _, cp := range manifest.Checkpoints {
if cp.Progress == nil || !cp.Progress.Passed {
allPassed = false
break
}
}
if !allPassed {
t.Error("Empty checkpoint list should be considered all-passed")
}
}
func TestManifest_SeekProtection_BlocksSkippingPastUnpassed(t *testing.T) {
// Simulates seek protection logic from InteractiveVideoPlayer
checkpoints := []CheckpointManifestEntry{
{Index: 0, TimestampSeconds: 180, Progress: &CheckpointProgress{Passed: true}},
{Index: 1, TimestampSeconds: 360, Progress: nil}, // Not yet attempted
{Index: 2, TimestampSeconds: 540, Progress: nil},
}
seekTarget := 500.0 // User tries to seek to 500s
// Find first unpassed checkpoint
var firstUnpassed *CheckpointManifestEntry
for i := range checkpoints {
if checkpoints[i].Progress == nil || !checkpoints[i].Progress.Passed {
firstUnpassed = &checkpoints[i]
break
}
}
blocked := false
if firstUnpassed != nil && seekTarget > firstUnpassed.TimestampSeconds {
blocked = true
}
if !blocked {
t.Error("Seek past unpassed checkpoint should be blocked")
}
if firstUnpassed.TimestampSeconds != 360 {
t.Errorf("Expected block at 360s, got %f", firstUnpassed.TimestampSeconds)
}
}
func TestManifest_SeekProtection_AllowsSeekBeforeFirstUnpassed(t *testing.T) {
checkpoints := []CheckpointManifestEntry{
{Index: 0, TimestampSeconds: 180, Progress: &CheckpointProgress{Passed: true}},
{Index: 1, TimestampSeconds: 360, Progress: nil},
}
seekTarget := 200.0 // User seeks to 200s — before unpassed checkpoint at 360s
var firstUnpassed *CheckpointManifestEntry
for i := range checkpoints {
if checkpoints[i].Progress == nil || !checkpoints[i].Progress.Passed {
firstUnpassed = &checkpoints[i]
break
}
}
blocked := false
if firstUnpassed != nil && seekTarget > firstUnpassed.TimestampSeconds {
blocked = true
}
if blocked {
t.Error("Seek before unpassed checkpoint should be allowed")
}
}

View File

@@ -16,8 +16,9 @@ import (
type MediaType string
const (
MediaTypeAudio MediaType = "audio"
MediaTypeVideo MediaType = "video"
MediaTypeAudio MediaType = "audio"
MediaTypeVideo MediaType = "video"
MediaTypeInteractiveVideo MediaType = "interactive_video"
)
// MediaStatus represents the processing status
@@ -169,6 +170,57 @@ func (c *TTSClient) GenerateVideo(ctx context.Context, req *TTSGenerateVideoRequ
return &result, nil
}
// PresignedURLRequest is the request to get a presigned URL
type PresignedURLRequest struct {
Bucket string `json:"bucket"`
ObjectKey string `json:"object_key"`
Expires int `json:"expires"`
}
// PresignedURLResponse is the response containing a presigned URL
type PresignedURLResponse struct {
URL string `json:"url"`
ExpiresIn int `json:"expires_in"`
}
// GetPresignedURL requests a presigned URL from the TTS service
func (c *TTSClient) GetPresignedURL(ctx context.Context, bucket, objectKey string) (string, error) {
reqBody := PresignedURLRequest{
Bucket: bucket,
ObjectKey: objectKey,
Expires: 3600,
}
body, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("marshal request: %w", err)
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+"/presigned-url", bytes.NewReader(body))
if err != nil {
return "", fmt.Errorf("create request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(httpReq)
if err != nil {
return "", fmt.Errorf("TTS presigned URL request failed: %w", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("TTS presigned URL error (%d): %s", resp.StatusCode, string(respBody))
}
var result PresignedURLResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return "", fmt.Errorf("parse presigned URL response: %w", err)
}
return result.URL, nil
}
// IsHealthy checks if the TTS service is responsive
func (c *TTSClient) IsHealthy(ctx context.Context) bool {
httpReq, err := http.NewRequestWithContext(ctx, "GET", c.baseURL+"/health", nil)
@@ -184,3 +236,115 @@ func (c *TTSClient) IsHealthy(ctx context.Context) bool {
return resp.StatusCode == http.StatusOK
}
// ============================================================================
// Interactive Video TTS Client Methods
// ============================================================================
// SynthesizeSectionsRequest is the request for batch section audio synthesis
type SynthesizeSectionsRequest struct {
Sections []SectionAudio `json:"sections"`
Voice string `json:"voice"`
ModuleID string `json:"module_id"`
}
// SectionAudio represents one section's text for audio synthesis
type SectionAudio struct {
Text string `json:"text"`
Heading string `json:"heading"`
}
// SynthesizeSectionsResponse is the response from batch section synthesis
type SynthesizeSectionsResponse struct {
Sections []SectionResult `json:"sections"`
TotalDuration float64 `json:"total_duration"`
}
// SectionResult is the result for one section's audio
type SectionResult struct {
Heading string `json:"heading"`
AudioPath string `json:"audio_path"`
AudioObjectKey string `json:"audio_object_key"`
Duration float64 `json:"duration"`
StartTimestamp float64 `json:"start_timestamp"`
}
// GenerateInteractiveVideoRequest is the request for interactive video generation
type GenerateInteractiveVideoRequest struct {
Script *NarratorScript `json:"script"`
Audio *SynthesizeSectionsResponse `json:"audio"`
ModuleID string `json:"module_id"`
}
// GenerateInteractiveVideoResponse is the response from interactive video generation
type GenerateInteractiveVideoResponse struct {
VideoID string `json:"video_id"`
Bucket string `json:"bucket"`
ObjectKey string `json:"object_key"`
DurationSeconds float64 `json:"duration_seconds"`
SizeBytes int64 `json:"size_bytes"`
}
// SynthesizeSections calls the TTS service to synthesize audio for multiple sections
func (c *TTSClient) SynthesizeSections(ctx context.Context, req *SynthesizeSectionsRequest) (*SynthesizeSectionsResponse, error) {
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+"/synthesize-sections", bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("TTS synthesize-sections request failed: %w", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("TTS synthesize-sections error (%d): %s", resp.StatusCode, string(respBody))
}
var result SynthesizeSectionsResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("parse TTS synthesize-sections response: %w", err)
}
return &result, nil
}
// GenerateInteractiveVideo calls the TTS service to create an interactive video with checkpoint slides
func (c *TTSClient) GenerateInteractiveVideo(ctx context.Context, req *GenerateInteractiveVideoRequest) (*GenerateInteractiveVideoResponse, error) {
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+"/generate-interactive-video", bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("TTS interactive video request failed: %w", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("TTS interactive video error (%d): %s", resp.StatusCode, string(respBody))
}
var result GenerateInteractiveVideoResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("parse TTS interactive video response: %w", err)
}
return &result, nil
}

View File

@@ -106,7 +106,8 @@ const (
RoleR6 = "R6" // Einkauf
RoleR7 = "R7" // Fachabteilung
RoleR8 = "R8" // IT-Admin
RoleR9 = "R9" // Alle Mitarbeiter
RoleR9 = "R9" // Alle Mitarbeiter
RoleR10 = "R10" // Behoerden / Oeffentlicher Dienst
)
// RoleLabels maps role codes to human-readable labels
@@ -118,8 +119,9 @@ var RoleLabels = map[string]string{
RoleR5: "HR / Personal",
RoleR6: "Einkauf / Beschaffung",
RoleR7: "Fachabteilung",
RoleR8: "IT-Administration",
RoleR9: "Alle Mitarbeiter",
RoleR8: "IT-Administration",
RoleR9: "Alle Mitarbeiter",
RoleR10: "Behoerden / Oeffentlicher Dienst",
}
// NIS2RoleMapping maps internal roles to NIS2 levels
@@ -131,8 +133,38 @@ var NIS2RoleMapping = map[string]string{
RoleR5: "N4", // HR
RoleR6: "N4", // Einkauf
RoleR7: "N5", // Fachabteilung
RoleR8: "N2", // IT-Admin
RoleR9: "N5", // Alle Mitarbeiter
RoleR8: "N2", // IT-Admin
RoleR9: "N5", // Alle Mitarbeiter
RoleR10: "N4", // Behoerden
}
// TargetAudienceRoleMapping maps canonical control target_audience values to CTM roles
var TargetAudienceRoleMapping = map[string][]string{
"enterprise": {RoleR1, RoleR4, RoleR5, RoleR6, RoleR7, RoleR9}, // Unternehmen
"authority": {RoleR10}, // Behoerden
"provider": {RoleR2, RoleR8}, // IT-Dienstleister
"all": {RoleR1, RoleR2, RoleR3, RoleR4, RoleR5, RoleR6, RoleR7, RoleR8, RoleR9, RoleR10},
}
// CategoryRoleMapping provides additional role hints based on control category
var CategoryRoleMapping = map[string][]string{
"encryption": {RoleR2, RoleR8},
"authentication": {RoleR2, RoleR8, RoleR9},
"network": {RoleR2, RoleR8},
"data_protection": {RoleR3, RoleR5, RoleR9},
"logging": {RoleR2, RoleR4, RoleR8},
"incident": {RoleR1, RoleR4},
"continuity": {RoleR1, RoleR2, RoleR4},
"compliance": {RoleR1, RoleR3, RoleR4},
"supply_chain": {RoleR6},
"physical": {RoleR7},
"personnel": {RoleR5, RoleR9},
"application": {RoleR8},
"system": {RoleR2, RoleR8},
"risk": {RoleR1, RoleR4},
"governance": {RoleR1, RoleR4},
"hardware": {RoleR2, RoleR8},
"identity": {RoleR2, RoleR3, RoleR8},
}
// ============================================================================
@@ -498,3 +530,228 @@ type BulkResult struct {
Skipped int `json:"skipped"`
Errors []string `json:"errors"`
}
// ============================================================================
// Training Block Types (Controls → Schulungsmodule Pipeline)
// ============================================================================
// TrainingBlockConfig defines how canonical controls are grouped into training modules
type TrainingBlockConfig struct {
ID uuid.UUID `json:"id"`
TenantID uuid.UUID `json:"tenant_id"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
DomainFilter string `json:"domain_filter,omitempty"` // "AUTH", "CRYP", etc.
CategoryFilter string `json:"category_filter,omitempty"` // "authentication", etc.
SeverityFilter string `json:"severity_filter,omitempty"` // "high", "critical"
TargetAudienceFilter string `json:"target_audience_filter,omitempty"` // "enterprise", "authority", "provider", "all"
RegulationArea RegulationArea `json:"regulation_area"`
ModuleCodePrefix string `json:"module_code_prefix"`
FrequencyType FrequencyType `json:"frequency_type"`
DurationMinutes int `json:"duration_minutes"`
PassThreshold int `json:"pass_threshold"`
MaxControlsPerModule int `json:"max_controls_per_module"`
IsActive bool `json:"is_active"`
LastGeneratedAt *time.Time `json:"last_generated_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// TrainingBlockControlLink tracks which canonical controls are linked to which module
type TrainingBlockControlLink struct {
ID uuid.UUID `json:"id"`
BlockConfigID uuid.UUID `json:"block_config_id"`
ModuleID uuid.UUID `json:"module_id"`
ControlID string `json:"control_id"`
ControlTitle string `json:"control_title"`
ControlObjective string `json:"control_objective"`
ControlRequirements []string `json:"control_requirements"`
SortOrder int `json:"sort_order"`
CreatedAt time.Time `json:"created_at"`
}
// CanonicalControlSummary is a lightweight view on canonical_controls for the training pipeline
type CanonicalControlSummary struct {
ControlID string `json:"control_id"`
Title string `json:"title"`
Objective string `json:"objective"`
Rationale string `json:"rationale"`
Requirements []string `json:"requirements"`
Severity string `json:"severity"`
Category string `json:"category"`
TargetAudience string `json:"target_audience"`
Tags []string `json:"tags"`
}
// CanonicalControlMeta provides aggregated metadata about canonical controls
type CanonicalControlMeta struct {
Domains []DomainCount `json:"domains"`
Categories []CategoryCount `json:"categories"`
Audiences []AudienceCount `json:"audiences"`
Total int `json:"total"`
}
// DomainCount is a domain with its control count
type DomainCount struct {
Domain string `json:"domain"`
Count int `json:"count"`
}
// CategoryCount is a category with its control count
type CategoryCount struct {
Category string `json:"category"`
Count int `json:"count"`
}
// AudienceCount is a target audience with its control count
type AudienceCount struct {
Audience string `json:"audience"`
Count int `json:"count"`
}
// CreateBlockConfigRequest is the API request for creating a block config
type CreateBlockConfigRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description,omitempty"`
DomainFilter string `json:"domain_filter,omitempty"`
CategoryFilter string `json:"category_filter,omitempty"`
SeverityFilter string `json:"severity_filter,omitempty"`
TargetAudienceFilter string `json:"target_audience_filter,omitempty"`
RegulationArea RegulationArea `json:"regulation_area" binding:"required"`
ModuleCodePrefix string `json:"module_code_prefix" binding:"required"`
FrequencyType FrequencyType `json:"frequency_type"`
DurationMinutes int `json:"duration_minutes"`
PassThreshold int `json:"pass_threshold"`
MaxControlsPerModule int `json:"max_controls_per_module"`
}
// UpdateBlockConfigRequest is the API request for updating a block config
type UpdateBlockConfigRequest struct {
Name *string `json:"name,omitempty"`
Description *string `json:"description,omitempty"`
DomainFilter *string `json:"domain_filter,omitempty"`
CategoryFilter *string `json:"category_filter,omitempty"`
SeverityFilter *string `json:"severity_filter,omitempty"`
TargetAudienceFilter *string `json:"target_audience_filter,omitempty"`
MaxControlsPerModule *int `json:"max_controls_per_module,omitempty"`
DurationMinutes *int `json:"duration_minutes,omitempty"`
PassThreshold *int `json:"pass_threshold,omitempty"`
IsActive *bool `json:"is_active,omitempty"`
}
// ============================================================================
// Interactive Video / Checkpoint Types
// ============================================================================
// NarratorScript is an extended VideoScript with narrator persona and checkpoints
type NarratorScript struct {
Title string `json:"title"`
Intro string `json:"intro"`
Sections []NarratorSection `json:"sections"`
Outro string `json:"outro"`
TotalDurationEstimate int `json:"total_duration_estimate"`
}
// NarratorSection is one narrative section with optional checkpoint
type NarratorSection struct {
Heading string `json:"heading"`
NarratorText string `json:"narrator_text"`
BulletPoints []string `json:"bullet_points"`
Transition string `json:"transition"`
Checkpoint *CheckpointDefinition `json:"checkpoint,omitempty"`
}
// CheckpointDefinition defines a quiz checkpoint within a video
type CheckpointDefinition struct {
Title string `json:"title"`
Questions []CheckpointQuestion `json:"questions"`
}
// CheckpointQuestion is a quiz question within a checkpoint
type CheckpointQuestion struct {
Question string `json:"question"`
Options []string `json:"options"`
CorrectIndex int `json:"correct_index"`
Explanation string `json:"explanation"`
}
// Checkpoint is a DB record for a video checkpoint
type Checkpoint struct {
ID uuid.UUID `json:"id"`
ModuleID uuid.UUID `json:"module_id"`
CheckpointIndex int `json:"checkpoint_index"`
Title string `json:"title"`
TimestampSeconds float64 `json:"timestamp_seconds"`
CreatedAt time.Time `json:"created_at"`
}
// CheckpointProgress tracks a user's progress on a checkpoint
type CheckpointProgress struct {
ID uuid.UUID `json:"id"`
AssignmentID uuid.UUID `json:"assignment_id"`
CheckpointID uuid.UUID `json:"checkpoint_id"`
Passed bool `json:"passed"`
Attempts int `json:"attempts"`
LastAttemptAt *time.Time `json:"last_attempt_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// InteractiveVideoManifest is returned to the frontend player
type InteractiveVideoManifest struct {
MediaID uuid.UUID `json:"media_id"`
StreamURL string `json:"stream_url"`
Checkpoints []CheckpointManifestEntry `json:"checkpoints"`
}
// CheckpointManifestEntry is one checkpoint in the manifest
type CheckpointManifestEntry struct {
CheckpointID uuid.UUID `json:"checkpoint_id"`
Index int `json:"index"`
Title string `json:"title"`
TimestampSeconds float64 `json:"timestamp_seconds"`
Questions []CheckpointQuestion `json:"questions"`
Progress *CheckpointProgress `json:"progress,omitempty"`
}
// SubmitCheckpointQuizRequest is the API request for submitting a checkpoint quiz
type SubmitCheckpointQuizRequest struct {
AssignmentID string `json:"assignment_id"`
Answers []int `json:"answers"`
}
// SubmitCheckpointQuizResponse is the API response for a checkpoint quiz submission
type SubmitCheckpointQuizResponse struct {
Passed bool `json:"passed"`
Score float64 `json:"score"`
Feedback []CheckpointQuizFeedback `json:"feedback"`
}
// CheckpointQuizFeedback is feedback for a single question
type CheckpointQuizFeedback struct {
Question string `json:"question"`
Correct bool `json:"correct"`
Explanation string `json:"explanation"`
}
// GenerateBlockRequest is the API request for generating modules from a block config
type GenerateBlockRequest struct {
Language string `json:"language"`
AutoMatrix bool `json:"auto_matrix"`
}
// PreviewBlockResponse shows what would be generated without writing to DB
type PreviewBlockResponse struct {
ControlCount int `json:"control_count"`
ModuleCount int `json:"module_count"`
Controls []CanonicalControlSummary `json:"controls"`
ProposedRoles []string `json:"proposed_roles"`
}
// GenerateBlockResponse shows the result of a block generation
type GenerateBlockResponse struct {
ModulesCreated int `json:"modules_created"`
ControlsLinked int `json:"controls_linked"`
MatrixEntriesCreated int `json:"matrix_entries_created"`
ContentGenerated int `json:"content_generated"`
Errors []string `json:"errors,omitempty"`
}

View File

@@ -235,6 +235,12 @@ func (s *Store) UpdateModule(ctx context.Context, module *TrainingModule) error
return err
}
// DeleteModule deletes a training module by ID
func (s *Store) DeleteModule(ctx context.Context, id uuid.UUID) error {
_, err := s.pool.Exec(ctx, `DELETE FROM training_modules WHERE id = $1`, id)
return err
}
// SetAcademyCourseID links a training module to an academy course
func (s *Store) SetAcademyCourseID(ctx context.Context, moduleID, courseID uuid.UUID) error {
_, err := s.pool.Exec(ctx, `
@@ -570,6 +576,18 @@ func (s *Store) UpdateAssignmentStatus(ctx context.Context, id uuid.UUID, status
return err
}
// UpdateAssignmentDeadline updates the deadline of an assignment
func (s *Store) UpdateAssignmentDeadline(ctx context.Context, id uuid.UUID, deadline time.Time) error {
now := time.Now().UTC()
_, err := s.pool.Exec(ctx, `
UPDATE training_assignments SET
deadline = $2,
updated_at = $3
WHERE id = $1
`, id, deadline, now)
return err
}
// UpdateAssignmentQuizResult updates quiz-related fields on an assignment
func (s *Store) UpdateAssignmentQuizResult(ctx context.Context, id uuid.UUID, score float64, passed bool, attempts int) error {
now := time.Now().UTC()
@@ -1252,6 +1270,80 @@ func (s *Store) GetPublishedAudio(ctx context.Context, moduleID uuid.UUID) (*Tra
return &media, nil
}
// SetCertificateID sets the certificate ID on an assignment
func (s *Store) SetCertificateID(ctx context.Context, assignmentID, certID uuid.UUID) error {
_, err := s.pool.Exec(ctx, `
UPDATE training_assignments SET certificate_id = $2, updated_at = NOW() WHERE id = $1
`, assignmentID, certID)
return err
}
// GetAssignmentByCertificateID finds an assignment by its certificate ID
func (s *Store) GetAssignmentByCertificateID(ctx context.Context, certID uuid.UUID) (*TrainingAssignment, error) {
var assignmentID uuid.UUID
err := s.pool.QueryRow(ctx,
"SELECT id FROM training_assignments WHERE certificate_id = $1",
certID).Scan(&assignmentID)
if err == pgx.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return s.GetAssignment(ctx, assignmentID)
}
// ListCertificates lists assignments that have certificates for a tenant
func (s *Store) ListCertificates(ctx context.Context, tenantID uuid.UUID) ([]TrainingAssignment, error) {
rows, err := s.pool.Query(ctx, `
SELECT
ta.id, ta.tenant_id, ta.module_id, ta.user_id, ta.user_name, ta.user_email,
ta.role_code, ta.trigger_type, ta.trigger_event, ta.status, ta.progress_percent,
ta.quiz_score, ta.quiz_passed, ta.quiz_attempts,
ta.started_at, ta.completed_at, ta.deadline, ta.certificate_id,
ta.escalation_level, ta.last_escalation_at, ta.enrollment_id,
ta.created_at, ta.updated_at,
m.module_code, m.title
FROM training_assignments ta
JOIN training_modules m ON m.id = ta.module_id
WHERE ta.tenant_id = $1 AND ta.certificate_id IS NOT NULL
ORDER BY ta.completed_at DESC
`, tenantID)
if err != nil {
return nil, err
}
defer rows.Close()
var assignments []TrainingAssignment
for rows.Next() {
var a TrainingAssignment
var status, triggerType string
err := rows.Scan(
&a.ID, &a.TenantID, &a.ModuleID, &a.UserID, &a.UserName, &a.UserEmail,
&a.RoleCode, &triggerType, &a.TriggerEvent, &status, &a.ProgressPercent,
&a.QuizScore, &a.QuizPassed, &a.QuizAttempts,
&a.StartedAt, &a.CompletedAt, &a.Deadline, &a.CertificateID,
&a.EscalationLevel, &a.LastEscalationAt, &a.EnrollmentID,
&a.CreatedAt, &a.UpdatedAt,
&a.ModuleCode, &a.ModuleTitle,
)
if err != nil {
return nil, err
}
a.Status = AssignmentStatus(status)
a.TriggerType = TriggerType(triggerType)
assignments = append(assignments, a)
}
if assignments == nil {
assignments = []TrainingAssignment{}
}
return assignments, nil
}
// GetPublishedVideo gets the published video for a module
func (s *Store) GetPublishedVideo(ctx context.Context, moduleID uuid.UUID) (*TrainingMedia, error) {
var media TrainingMedia
@@ -1283,3 +1375,195 @@ func (s *Store) GetPublishedVideo(ctx context.Context, moduleID uuid.UUID) (*Tra
media.Status = MediaStatus(status)
return &media, nil
}
// ============================================================================
// Checkpoint Operations
// ============================================================================
// CreateCheckpoint inserts a new checkpoint
func (s *Store) CreateCheckpoint(ctx context.Context, cp *Checkpoint) error {
cp.ID = uuid.New()
cp.CreatedAt = time.Now().UTC()
_, err := s.pool.Exec(ctx, `
INSERT INTO training_checkpoints (id, module_id, checkpoint_index, title, timestamp_seconds, created_at)
VALUES ($1, $2, $3, $4, $5, $6)
`, cp.ID, cp.ModuleID, cp.CheckpointIndex, cp.Title, cp.TimestampSeconds, cp.CreatedAt)
return err
}
// ListCheckpoints returns all checkpoints for a module ordered by index
func (s *Store) ListCheckpoints(ctx context.Context, moduleID uuid.UUID) ([]Checkpoint, error) {
rows, err := s.pool.Query(ctx, `
SELECT id, module_id, checkpoint_index, title, timestamp_seconds, created_at
FROM training_checkpoints
WHERE module_id = $1
ORDER BY checkpoint_index
`, moduleID)
if err != nil {
return nil, err
}
defer rows.Close()
var checkpoints []Checkpoint
for rows.Next() {
var cp Checkpoint
if err := rows.Scan(&cp.ID, &cp.ModuleID, &cp.CheckpointIndex, &cp.Title, &cp.TimestampSeconds, &cp.CreatedAt); err != nil {
return nil, err
}
checkpoints = append(checkpoints, cp)
}
if checkpoints == nil {
checkpoints = []Checkpoint{}
}
return checkpoints, nil
}
// DeleteCheckpointsForModule removes all checkpoints for a module (used before regenerating)
func (s *Store) DeleteCheckpointsForModule(ctx context.Context, moduleID uuid.UUID) error {
_, err := s.pool.Exec(ctx, `DELETE FROM training_checkpoints WHERE module_id = $1`, moduleID)
return err
}
// GetCheckpointProgress retrieves progress for a specific checkpoint+assignment
func (s *Store) GetCheckpointProgress(ctx context.Context, assignmentID, checkpointID uuid.UUID) (*CheckpointProgress, error) {
var cp CheckpointProgress
err := s.pool.QueryRow(ctx, `
SELECT id, assignment_id, checkpoint_id, passed, attempts, last_attempt_at, created_at
FROM training_checkpoint_progress
WHERE assignment_id = $1 AND checkpoint_id = $2
`, assignmentID, checkpointID).Scan(
&cp.ID, &cp.AssignmentID, &cp.CheckpointID, &cp.Passed, &cp.Attempts, &cp.LastAttemptAt, &cp.CreatedAt,
)
if err == pgx.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &cp, nil
}
// UpsertCheckpointProgress creates or updates checkpoint progress
func (s *Store) UpsertCheckpointProgress(ctx context.Context, progress *CheckpointProgress) error {
progress.ID = uuid.New()
now := time.Now().UTC()
progress.LastAttemptAt = &now
progress.CreatedAt = now
_, err := s.pool.Exec(ctx, `
INSERT INTO training_checkpoint_progress (id, assignment_id, checkpoint_id, passed, attempts, last_attempt_at, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (assignment_id, checkpoint_id) DO UPDATE SET
passed = EXCLUDED.passed,
attempts = training_checkpoint_progress.attempts + 1,
last_attempt_at = EXCLUDED.last_attempt_at
`, progress.ID, progress.AssignmentID, progress.CheckpointID, progress.Passed, progress.Attempts, progress.LastAttemptAt, progress.CreatedAt)
return err
}
// GetCheckpointQuestions retrieves quiz questions for a specific checkpoint
func (s *Store) GetCheckpointQuestions(ctx context.Context, checkpointID uuid.UUID) ([]QuizQuestion, error) {
rows, err := s.pool.Query(ctx, `
SELECT id, module_id, question, options, correct_index, explanation, difficulty, is_active, sort_order, created_at
FROM training_quiz_questions
WHERE checkpoint_id = $1 AND is_active = true
ORDER BY sort_order
`, checkpointID)
if err != nil {
return nil, err
}
defer rows.Close()
var questions []QuizQuestion
for rows.Next() {
var q QuizQuestion
var options []byte
var difficulty string
if err := rows.Scan(&q.ID, &q.ModuleID, &q.Question, &options, &q.CorrectIndex, &q.Explanation, &difficulty, &q.IsActive, &q.SortOrder, &q.CreatedAt); err != nil {
return nil, err
}
json.Unmarshal(options, &q.Options)
q.Difficulty = Difficulty(difficulty)
questions = append(questions, q)
}
if questions == nil {
questions = []QuizQuestion{}
}
return questions, nil
}
// CreateCheckpointQuizQuestion creates a quiz question linked to a checkpoint
func (s *Store) CreateCheckpointQuizQuestion(ctx context.Context, q *QuizQuestion, checkpointID uuid.UUID) error {
q.ID = uuid.New()
q.CreatedAt = time.Now().UTC()
q.IsActive = true
options, _ := json.Marshal(q.Options)
_, err := s.pool.Exec(ctx, `
INSERT INTO training_quiz_questions (id, module_id, checkpoint_id, question, options, correct_index, explanation, difficulty, is_active, sort_order, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
`, q.ID, q.ModuleID, checkpointID, q.Question, options, q.CorrectIndex, q.Explanation, string(q.Difficulty), q.IsActive, q.SortOrder, q.CreatedAt)
return err
}
// AreAllCheckpointsPassed checks if all checkpoints for a module are passed by an assignment
func (s *Store) AreAllCheckpointsPassed(ctx context.Context, assignmentID, moduleID uuid.UUID) (bool, error) {
var totalCheckpoints, passedCheckpoints int
err := s.pool.QueryRow(ctx, `
SELECT COUNT(*) FROM training_checkpoints WHERE module_id = $1
`, moduleID).Scan(&totalCheckpoints)
if err != nil {
return false, err
}
if totalCheckpoints == 0 {
return true, nil
}
err = s.pool.QueryRow(ctx, `
SELECT COUNT(*) FROM training_checkpoint_progress cp
JOIN training_checkpoints c ON cp.checkpoint_id = c.id
WHERE cp.assignment_id = $1 AND c.module_id = $2 AND cp.passed = true
`, assignmentID, moduleID).Scan(&passedCheckpoints)
if err != nil {
return false, err
}
return passedCheckpoints >= totalCheckpoints, nil
}
// ListCheckpointProgress returns all checkpoint progress for an assignment
func (s *Store) ListCheckpointProgress(ctx context.Context, assignmentID uuid.UUID) ([]CheckpointProgress, error) {
rows, err := s.pool.Query(ctx, `
SELECT id, assignment_id, checkpoint_id, passed, attempts, last_attempt_at, created_at
FROM training_checkpoint_progress
WHERE assignment_id = $1
ORDER BY created_at
`, assignmentID)
if err != nil {
return nil, err
}
defer rows.Close()
var progress []CheckpointProgress
for rows.Next() {
var cp CheckpointProgress
if err := rows.Scan(&cp.ID, &cp.AssignmentID, &cp.CheckpointID, &cp.Passed, &cp.Attempts, &cp.LastAttemptAt, &cp.CreatedAt); err != nil {
return nil, err
}
progress = append(progress, cp)
}
if progress == nil {
progress = []CheckpointProgress{}
}
return progress, nil
}