diff --git a/ai-compliance-sdk/internal/api/handlers/iace_handler_init.go b/ai-compliance-sdk/internal/api/handlers/iace_handler_init.go index 358f10a3..4543932a 100644 --- a/ai-compliance-sdk/internal/api/handlers/iace_handler_init.go +++ b/ai-compliance-sdk/internal/api/handlers/iace_handler_init.go @@ -140,6 +140,7 @@ func (h *IACEHandler) InitializeProject(c *gin.Context) { existingHazards, _ := h.store.ListHazards(ctx, projectID) hazardStep := InitStep{Name: "Gefaehrdungen erstellt", Status: "skipped"} hazardIDsByCategory := make(map[string][]uuid.UUID) + hazardPatternMeasures := make(map[uuid.UUID][]string) if len(existingHazards) == 0 && len(matchOutput.MatchedPatterns) > 0 { comps, _ := h.store.ListComponents(ctx, projectID) @@ -226,6 +227,10 @@ func (h *IACEHandler) InitializeProject(c *gin.Context) { created++ catCount[cat]++ hazardIDsByCategory[cat] = append(hazardIDsByCategory[cat], hz.ID) + // Remember this pattern's suggested measures for this hazard + if len(mp.SuggestedMeasureIDs) > 0 { + hazardPatternMeasures[hz.ID] = mp.SuggestedMeasureIDs + } } } } @@ -266,42 +271,44 @@ func (h *IACEHandler) InitializeProject(c *gin.Context) { } // For each hazard: assign up to maxMitigationsPerHazard measures - // Priority: pattern-suggested first, then category fallback - suggestedByMeasCat := make(map[string][]iace.ProtectiveMeasureEntry) - for _, sm := range matchOutput.SuggestedMeasures { - if entry, ok := measureByID[sm.MeasureID]; ok { - suggestedByMeasCat[entry.HazardCategory] = append(suggestedByMeasCat[entry.HazardCategory], entry) - } - } - + // Priority 1: Pattern-specific SuggestedMeasureIDs (from the pattern that created this hazard) + // Priority 2: Category fallback (generic measures for the hazard category) for _, hazID := range allHazardIDs { hazCat := hazardCatByID[hazID] measCat := patternCatToMeasureCat(hazCat) added := 0 + usedIDs := make(map[string]bool) - // First: pattern-suggested measures for this category - for _, entry := range suggestedByMeasCat[measCat] { - if added >= maxMitigationsPerHazard { - break - } - rt := iace.ReductionType(entry.ReductionType) - if rt == "" { - rt = iace.ReductionTypeInformation - } - _, cerr := h.store.CreateMitigation(ctx, iace.CreateMitigationRequest{ - HazardID: hazID, ReductionType: rt, - Name: entry.Name, Description: entry.Description, - }) - if cerr == nil { - created++ - added++ + // Priority 1: Pattern-specific measures + if patternMIDs, ok := hazardPatternMeasures[hazID]; ok { + for _, mid := range patternMIDs { + if added >= maxMitigationsPerHazard { + break + } + entry, ok := measureByID[mid] + if !ok { + continue + } + rt := iace.ReductionType(entry.ReductionType) + if rt == "" { + rt = iace.ReductionTypeInformation + } + _, cerr := h.store.CreateMitigation(ctx, iace.CreateMitigationRequest{ + HazardID: hazID, ReductionType: rt, + Name: entry.Name, Description: entry.Description, + }) + if cerr == nil { + created++ + added++ + usedIDs[mid] = true + } } } - // Then: category fallback if still under limit + // Priority 2: Category fallback (skip already-used IDs) for _, m := range measuresByCat[measCat] { - if added >= maxMitigationsPerHazard { - break + if added >= maxMitigationsPerHazard || usedIDs[m.ID] { + continue } rt := iace.ReductionType(m.ReductionType) if rt == "" { diff --git a/ai-compliance-sdk/internal/iace/pattern_engine.go b/ai-compliance-sdk/internal/iace/pattern_engine.go index d5cae5ef..3c8e4c1e 100644 --- a/ai-compliance-sdk/internal/iace/pattern_engine.go +++ b/ai-compliance-sdk/internal/iace/pattern_engine.go @@ -67,6 +67,7 @@ type PatternMatch struct { GeneratedHazardType string `json:"generated_hazard_type,omitempty"` MatchedFailureModes []string `json:"matched_failure_modes,omitempty"` ApplicableLifecycles []string `json:"applicable_lifecycles,omitempty"` + SuggestedMeasureIDs []string `json:"suggested_measure_ids,omitempty"` } // HazardSuggestion is a suggested hazard from pattern matching. @@ -220,6 +221,7 @@ func (e *PatternEngine) Match(input MatchInput) *MatchOutput { GeneratedHazardType: p.GeneratedHazardType, MatchedFailureModes: matchedFMs, ApplicableLifecycles: p.ApplicableLifecycles, + SuggestedMeasureIDs: p.SuggestedMeasureIDs, }) for _, cat := range p.GeneratedHazardCats {