diff --git a/ai-compliance-sdk/internal/api/handlers/usecase_handler.go b/ai-compliance-sdk/internal/api/handlers/usecase_handler.go index edbf695..a72d11b 100644 --- a/ai-compliance-sdk/internal/api/handlers/usecase_handler.go +++ b/ai-compliance-sdk/internal/api/handlers/usecase_handler.go @@ -16,17 +16,16 @@ type UseCaseHandler struct { store *usecase.Store compiler *usecase.Compiler gapDetector *usecase.GapDetector - llmGen *usecase.LLMQuestionGenerator } // NewUseCaseHandler creates a new UseCaseHandler. func NewUseCaseHandler(pool *pgxpool.Pool, registry *llm.ProviderRegistry) *UseCaseHandler { store := usecase.NewStore(pool) + llmGen := usecase.NewLLMQuestionGenerator(registry) return &UseCaseHandler{ store: store, - compiler: usecase.NewCompiler(store), + compiler: usecase.NewCompiler(store, llmGen), gapDetector: usecase.NewGapDetector(store), - llmGen: usecase.NewLLMQuestionGenerator(registry), } } @@ -59,12 +58,11 @@ func (h *UseCaseHandler) GetTemplate(c *gin.Context) { // Compile generates questions from MC filters ad-hoc. 
// POST /sdk/v1/use-case/compile -// Optional: "mode": "llm" to use LLM-based generation +// Uses the full pipeline: doc_check → LLM → deterministic fallback func (h *UseCaseHandler) Compile(c *gin.Context) { var req struct { MCFilters []string `json:"mc_filters" binding:"required"` Regulations []string `json:"regulations"` - Mode string `json:"mode"` // "deterministic" (default) or "llm" } if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -77,29 +75,13 @@ func (h *UseCaseHandler) Compile(c *gin.Context) { Regulations: req.Regulations, } - if req.Mode == "llm" && h.llmGen != nil { - // Fetch MCs first, then generate via LLM - mcs, err := h.store.FetchMCsByFilters(req.MCFilters) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - questions, err := h.llmGen.GenerateQuestions(mcs, req.Regulations) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - c.JSON(http.StatusOK, gin.H{"questions": questions, "total": len(questions), "mode": "llm"}) - return - } - questions, err := h.compiler.Compile(tmpl) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - c.JSON(http.StatusOK, gin.H{"questions": questions, "total": len(questions), "mode": "deterministic"}) + c.JSON(http.StatusOK, gin.H{"questions": questions, "total": len(questions)}) } // CreateAudit starts a new audit from a template. diff --git a/ai-compliance-sdk/internal/usecase/compiler.go b/ai-compliance-sdk/internal/usecase/compiler.go index fd7d36f..ccea48a 100644 --- a/ai-compliance-sdk/internal/usecase/compiler.go +++ b/ai-compliance-sdk/internal/usecase/compiler.go @@ -2,6 +2,7 @@ package usecase import ( "fmt" + "log" "strings" "golang.org/x/text/cases" @@ -10,51 +11,57 @@ import ( // Compiler turns Master Controls into audit questionnaires. 
type Compiler struct { - store *Store + store *Store + llmGen *LLMQuestionGenerator } -// NewCompiler creates a Compiler. -func NewCompiler(store *Store) *Compiler { - return &Compiler{store: store} +// NewCompiler creates a Compiler with optional LLM generator. +func NewCompiler(store *Store, llmGen *LLMQuestionGenerator) *Compiler { + return &Compiler{store: store, llmGen: llmGen} } -// Compile generates questions for a template by combining pre-defined -// questions, existing doc_check_controls, and MC-derived questions. +// Compile generates questions for a template. +// +// Flow (per Plan): +// 1. Fetch MCs matching template filters from DB +// 2. For each MC: check doc_check_controls → Mode A (deterministic) +// 3. For remaining MCs: use LLM → Mode B +// 4. For remaining MCs: derive from MC name → Mode A fallback +// 5. Template hardcoded questions = absolute fallback if DB returns nothing func (c *Compiler) Compile(tmpl *Template) ([]Question, error) { - // 1. Start with pre-defined template questions - if len(tmpl.Questions) > 0 { - return c.enrichWithMCIDs(tmpl) - } - - // 2. Fetch MCs matching the template filters + // 1. Fetch MCs matching the template filters mcs, err := c.store.FetchMCsByFilters(tmpl.MCFilters) if err != nil { - return nil, fmt.Errorf("fetch MCs: %w", err) + log.Printf("usecase: MC fetch failed: %v, falling back to template questions", err) + return c.templateFallback(tmpl), nil } + if len(mcs) == 0 { + // No MCs in DB for these filters → use hardcoded template questions + if len(tmpl.Questions) > 0 { + return tmpl.Questions, nil + } return nil, fmt.Errorf("no Master Controls found for filters %v", tmpl.MCFilters) } - // 3. Check for existing doc_check_controls questions + // 2. 
Check for existing doc_check_controls mcIDs := make([]string, len(mcs)) for i, mc := range mcs { mcIDs[i] = mc.MasterControlID } - checkQuestions, err := c.store.FetchCheckQuestions(mcIDs) - if err != nil { - return nil, fmt.Errorf("fetch check questions: %w", err) - } + checkQuestions, _ := c.store.FetchCheckQuestions(mcIDs) - // 4. Generate questions from MCs + // 3. Build questions: doc_check → LLM → deterministic var questions []Question + var mcsWithoutQuestions []MCInfo qNum := 1 for _, mc := range mcs { - // Mode A: Use existing doc_check questions - if cqs, ok := checkQuestions[mc.MasterControlID]; ok { + // Mode A: existing doc_check_controls + if cqs, ok := checkQuestions[mc.MasterControlID]; ok && len(cqs) > 0 { for _, cq := range cqs { - q := Question{ + questions = append(questions, Question{ ID: fmt.Sprintf("Q%d", qNum), MCID: mc.MasterControlID, MCName: mc.CanonicalName, @@ -64,15 +71,33 @@ func (c *Compiler) Compile(tmpl *Template) ([]Question, error) { Regulation: mc.RegSource, PassCriteria: splitCriteria(cq.PassCriteria), FailCriteria: splitCriteria(cq.FailCriteria), - } - questions = append(questions, q) + }) qNum++ } continue } + mcsWithoutQuestions = append(mcsWithoutQuestions, mc) + } - // Mode A fallback: Derive question from MC name - q := Question{ + // Mode B: LLM for MCs without doc_check_controls + if len(mcsWithoutQuestions) > 0 && c.llmGen != nil { + llmQuestions, err := c.llmGen.GenerateQuestions(mcsWithoutQuestions, tmpl.Regulations) + if err == nil && len(llmQuestions) > 0 { + // Renumber + for i := range llmQuestions { + llmQuestions[i].ID = fmt.Sprintf("Q%d", qNum) + qNum++ + } + questions = append(questions, llmQuestions...) 
+ mcsWithoutQuestions = nil // all handled + } else if err != nil { + log.Printf("usecase: LLM generation failed: %v, using deterministic fallback", err) + } + } + + // Mode A fallback: deterministic derivation for remaining MCs + for _, mc := range mcsWithoutQuestions { + questions = append(questions, Question{ ID: fmt.Sprintf("Q%d", qNum), MCID: mc.MasterControlID, MCName: mc.CanonicalName, @@ -82,58 +107,92 @@ func (c *Compiler) Compile(tmpl *Template) ([]Question, error) { Regulation: mc.RegSource, PassCriteria: []string{"Anforderung erfuellt und dokumentiert"}, FailCriteria: []string{"Nicht implementiert oder nicht nachweisbar"}, - } - questions = append(questions, q) + }) qNum++ - // Cap at a reasonable number if qNum > 50 { break } } + // Merge: add template hardcoded questions that cover topics not yet covered + if len(tmpl.Questions) > 0 { + questions = mergeTemplateQuestions(questions, tmpl.Questions, qNum) + } + + if len(questions) == 0 { + return c.templateFallback(tmpl), nil + } + return questions, nil } -// enrichWithMCIDs links pre-defined questions to MCs. -func (c *Compiler) enrichWithMCIDs(tmpl *Template) ([]Question, error) { - mcs, err := c.store.FetchMCsByFilters(tmpl.MCFilters) - if err != nil { - return tmpl.Questions, nil // fallback to questions without MC linkage +// templateFallback returns hardcoded template questions or an error. +func (c *Compiler) templateFallback(tmpl *Template) []Question { + if len(tmpl.Questions) > 0 { + return tmpl.Questions } + return nil +} - mcByTopic := make(map[string]MCInfo) - for _, mc := range mcs { - mcByTopic[mc.CanonicalName] = mc - } - - questions := make([]Question, len(tmpl.Questions)) - copy(questions, tmpl.Questions) - - // Try to link questions to MCs by keyword matching - for i := range questions { - if questions[i].MCID != "" { - continue +// mergeTemplateQuestions adds template questions that aren't already +// covered by MC-compiled questions (matched by keyword overlap). 
+func mergeTemplateQuestions(compiled, template []Question, nextNum int) []Question { + // Build set of covered MC topics + coveredTopics := make(map[string]bool) + for _, q := range compiled { + if q.MCName != "" { + coveredTopics[q.MCName] = true } - qLower := strings.ToLower(questions[i].Text) - for _, mc := range mcs { - topic := strings.ReplaceAll(mc.CanonicalName, "_", " ") - words := strings.Fields(topic) - matched := 0 - for _, w := range words { - if strings.Contains(qLower, w) { - matched++ - } - } - if matched >= 2 { - questions[i].MCID = mc.MasterControlID - questions[i].MCName = mc.CanonicalName + // Also index key words from the question text + for _, w := range extractKeywords(q.Text) { + coveredTopics[w] = true + } + } + + qNum := nextNum + for _, tq := range template { + // Check if this template question's topic is already covered + keywords := extractKeywords(tq.Text) + covered := false + for _, kw := range keywords { + if coveredTopics[kw] { + covered = true break } } + if covered { + continue + } + + tq.ID = fmt.Sprintf("Q%d", qNum) + compiled = append(compiled, tq) + qNum++ } - return questions, nil + return compiled +} + +// extractKeywords pulls significant words from a question for dedup. 
//
// The text is lowercased and split on whitespace. Each word is stripped of
// surrounding punctuation and kept only if it is longer than three characters
// and is not listed in keywordStopwords. Length is measured in runes rather
// than bytes, so short words containing multi-byte letters (e.g. "für") are
// not mistaken for significant keywords. The result preserves word order and
// may contain duplicates; it is nil when no keyword is found.
func extractKeywords(text string) []string {
	var keywords []string
	for _, w := range strings.Fields(strings.ToLower(text)) {
		w = strings.Trim(w, "?.,;:!\"'()")
		// len(w) would count bytes; umlauts and sharp s take two bytes each.
		if len([]rune(w)) <= 3 || keywordStopwords[w] {
			continue
		}
		keywords = append(keywords, w)
	}
	return keywords
}

// keywordStopwords lists the German function words that extractKeywords
// ignores. It is only read after initialization and must not be modified.
var keywordStopwords = map[string]bool{
	"ist": true, "hat": true, "gibt": true, "es": true, "ein": true,
	"eine": true, "der": true, "die": true, "das": true, "den": true,
	"dem": true, "des": true, "oder": true, "und": true, "fuer": true,
	"nach": true, "mit": true, "von": true, "zu": true, "auf": true,
	"in": true, "an": true, "bei": true, "werden": true, "wird": true,
	"sind": true, "nicht": true, "nur": true, "auch": true,
}

// deriveQuestion generates a human-readable question from an MC name.