From 06711bad1c37092063e370120d2564fa55ae311a Mon Sep 17 00:00:00 2001 From: Benjamin Boenisch Date: Wed, 25 Feb 2026 22:27:06 +0100 Subject: [PATCH] feat(sdk,iace): add Personalized Drafting Pipeline v2 and IACE engine Drafting Engine: 7-module pipeline with narrative tags, allowed facts governance, PII sanitizer, prose validator with repair loop, hash-based cache, and terminology guide. v1 fallback via ?v=1 query param. IACE: Initial AI-Act Conformity Engine with risk classifier, completeness checker, hazard library, and PostgreSQL store for AI system assessments. Co-Authored-By: Claude Opus 4.6 --- .../api/sdk/drafting-engine/draft/route.ts | 609 +++++- .../sdk/drafting-engine/allowed-facts-v2.ts | 85 + .../lib/sdk/drafting-engine/allowed-facts.ts | 257 +++ .../lib/sdk/drafting-engine/cache.ts | 303 +++ .../lib/sdk/drafting-engine/narrative-tags.ts | 139 ++ .../sdk/drafting-engine/prose-validator.ts | 485 +++++ .../lib/sdk/drafting-engine/sanitizer.ts | 298 +++ .../lib/sdk/drafting-engine/terminology.ts | 184 ++ ai-compliance-sdk/cmd/server/main.go | 226 +- .../internal/api/handlers/iace_handler.go | 1833 +++++++++++++++++ ai-compliance-sdk/internal/iace/classifier.go | 415 ++++ .../internal/iace/classifier_test.go | 553 +++++ .../internal/iace/completeness.go | 485 +++++ .../internal/iace/completeness_test.go | 678 ++++++ ai-compliance-sdk/internal/iace/engine.go | 202 ++ .../internal/iace/engine_test.go | 936 +++++++++ .../internal/iace/hazard_library.go | 606 ++++++ .../internal/iace/hazard_library_test.go | 293 +++ ai-compliance-sdk/internal/iace/models.go | 485 +++++ ai-compliance-sdk/internal/iace/store.go | 1777 ++++++++++++++++ 20 files changed, 10588 insertions(+), 261 deletions(-) create mode 100644 admin-compliance/lib/sdk/drafting-engine/allowed-facts-v2.ts create mode 100644 admin-compliance/lib/sdk/drafting-engine/allowed-facts.ts create mode 100644 admin-compliance/lib/sdk/drafting-engine/cache.ts create mode 100644 admin-compliance/lib/sdk/drafting-engine/narrative-tags.ts create mode 100644 admin-compliance/lib/sdk/drafting-engine/prose-validator.ts create mode 100644 admin-compliance/lib/sdk/drafting-engine/sanitizer.ts create mode 100644 admin-compliance/lib/sdk/drafting-engine/terminology.ts create mode 100644 ai-compliance-sdk/internal/api/handlers/iace_handler.go create mode 100644 ai-compliance-sdk/internal/iace/classifier.go create mode 100644 ai-compliance-sdk/internal/iace/classifier_test.go create mode 100644 ai-compliance-sdk/internal/iace/completeness.go create mode 100644 ai-compliance-sdk/internal/iace/completeness_test.go create mode 100644 ai-compliance-sdk/internal/iace/engine.go create mode 100644 ai-compliance-sdk/internal/iace/engine_test.go create mode 100644 ai-compliance-sdk/internal/iace/hazard_library.go create mode 100644 ai-compliance-sdk/internal/iace/hazard_library_test.go create mode 100644 ai-compliance-sdk/internal/iace/models.go create mode 100644 ai-compliance-sdk/internal/iace/store.go diff --git a/admin-compliance/app/api/sdk/drafting-engine/draft/route.ts b/admin-compliance/app/api/sdk/drafting-engine/draft/route.ts index c7a4a32..dde1950 100644 --- a/admin-compliance/app/api/sdk/drafting-engine/draft/route.ts +++ b/admin-compliance/app/api/sdk/drafting-engine/draft/route.ts @@ -1,9 +1,11 @@ /** - * Drafting Engine - Draft API + * Drafting Engine - Draft API v2 * - * Erstellt strukturierte Compliance-Dokument-Entwuerfe. - * Baut dokument-spezifische Prompts aus SOUL-Template + State-Projection. - * Gibt strukturiertes JSON zurueck. + * Erstellt personalisierte Compliance-Dokument-Entwuerfe. + * Pipeline: Constraint → Context → Sanitize → LLM → Validate → Repair → Merge + * + * v1-Modus: ?v=1 oder fehlender v2-Kontext → Legacy-Pipeline + * v2-Modus: Standard — Personalisierte Prosa mit Governance */ import { NextRequest, NextResponse } from 'next/server' @@ -11,7 +13,7 @@ import { NextRequest, NextResponse } from 'next/server' const OLLAMA_URL = process.env.OLLAMA_URL || 'http://host.docker.internal:11434' const LLM_MODEL = process.env.COMPLIANCE_LLM_MODEL || 'qwen2.5vl:32b' -// Import prompt builders +// v1 imports (Legacy) import { buildVVTDraftPrompt } from '@/lib/sdk/drafting-engine/prompts/draft-vvt' import { buildTOMDraftPrompt } from '@/lib/sdk/drafting-engine/prompts/draft-tom' import { buildDSFADraftPrompt } from '@/lib/sdk/drafting-engine/prompts/draft-dsfa' @@ -21,9 +23,32 @@ import type { DraftContext, DraftResponse, DraftRevision, DraftSection } from '@ import type { ScopeDocumentType } from '@/lib/sdk/compliance-scope-types' import { ConstraintEnforcer } from '@/lib/sdk/drafting-engine/constraint-enforcer' -const constraintEnforcer = new ConstraintEnforcer() +// v2 imports (Personalisierte Pipeline) +import { deriveNarrativeTags, extractScoresFromDraftContext, narrativeTagsToPromptString } from '@/lib/sdk/drafting-engine/narrative-tags' +import type { NarrativeTags } from '@/lib/sdk/drafting-engine/narrative-tags' +import { buildAllowedFactsFromDraftContext, allowedFactsToPromptString, disallowedTopicsToPromptString } from '@/lib/sdk/drafting-engine/allowed-facts-v2' +import { sanitizeAllowedFacts, validateNoRemainingPII, SanitizationError } from '@/lib/sdk/drafting-engine/sanitizer' +import { terminologyToPromptString, styleContractToPromptString } from '@/lib/sdk/drafting-engine/terminology' +import { executeRepairLoop, type ProseBlockOutput, type RepairAudit } from '@/lib/sdk/drafting-engine/prose-validator' +import { ProseCacheManager, computeChecksumSync, type CacheKeyParams } from '@/lib/sdk/drafting-engine/cache' -const DRAFTING_SYSTEM_PROMPT = `Du bist ein DSGVO-Compliance-Experte und erstellst strukturierte Dokument-Entwuerfe. +// ============================================================================ +// Shared State +// ============================================================================ + +const constraintEnforcer = new ConstraintEnforcer() +const proseCache = new ProseCacheManager({ maxEntries: 200, ttlHours: 24 }) + +// Template/Terminology Versionen (fuer Cache-Key) +const TEMPLATE_VERSION = '2.0.0' +const TERMINOLOGY_VERSION = '1.0.0' +const VALIDATOR_VERSION = '1.0.0' + +// ============================================================================ +// v1 Legacy Pipeline +// ============================================================================ + +const V1_SYSTEM_PROMPT = `Du bist ein DSGVO-Compliance-Experte und erstellst strukturierte Dokument-Entwuerfe. Du MUSST immer im JSON-Format antworten mit einem "sections" Array. Jede Section hat: id, title, content, schemaField. Halte die Tiefe strikt am vorgegebenen Level. @@ -60,10 +85,488 @@ Antworte als JSON mit "sections" Array.` } } +async function handleV1Draft(body: Record): Promise { + const { documentType, draftContext, instructions, existingDraft } = body as { + documentType: ScopeDocumentType + draftContext: DraftContext + instructions?: string + existingDraft?: DraftRevision + } + + const constraintCheck = constraintEnforcer.checkFromContext(documentType, draftContext) + if (!constraintCheck.allowed) { + return NextResponse.json({ + draft: null, + constraintCheck, + tokensUsed: 0, + error: 'Constraint-Verletzung: ' + constraintCheck.violations.join('; '), + }, { status: 403 }) + } + + const draftPrompt = buildPromptForDocumentType(documentType, draftContext, instructions) + const messages = [ + { role: 'system', content: V1_SYSTEM_PROMPT }, + ...(existingDraft ? [{ + role: 'assistant', + content: `Bisheriger Entwurf:\n${JSON.stringify(existingDraft.sections, null, 2)}`, + }] : []), + { role: 'user', content: draftPrompt }, + ] + + const ollamaResponse = await fetch(`${OLLAMA_URL}/api/chat`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + model: LLM_MODEL, + messages, + stream: false, + options: { temperature: 0.15, num_predict: 16384 }, + format: 'json', + }), + signal: AbortSignal.timeout(180000), + }) + + if (!ollamaResponse.ok) { + return NextResponse.json( + { error: `LLM nicht erreichbar (Status ${ollamaResponse.status})` }, + { status: 502 } + ) + } + + const result = await ollamaResponse.json() + const content = result.message?.content || '' + + let sections: DraftSection[] = [] + try { + const parsed = JSON.parse(content) + sections = (parsed.sections || []).map((s: Record, i: number) => ({ + id: String(s.id || `section-${i}`), + title: String(s.title || ''), + content: String(s.content || ''), + schemaField: s.schemaField ? String(s.schemaField) : undefined, + })) + } catch { + sections = [{ id: 'raw', title: 'Entwurf', content }] + } + + const draft: DraftRevision = { + id: `draft-${Date.now()}`, + content: sections.map(s => `## ${s.title}\n\n${s.content}`).join('\n\n'), + sections, + createdAt: new Date().toISOString(), + instruction: instructions as string | undefined, + } + + return NextResponse.json({ + draft, + constraintCheck, + tokensUsed: result.eval_count || 0, + } satisfies DraftResponse) +} + +// ============================================================================ +// v2 Personalisierte Pipeline +// ============================================================================ + +/** Prose block definitions per document type */ +const DOCUMENT_PROSE_BLOCKS: Record> = { + tom: [ + { blockId: 'tom-intro', blockType: 'introduction', sectionName: 'Einleitung TOM', targetWords: 120 }, + { blockId: 'tom-transition', blockType: 'transition', sectionName: 'Ueberleitung Massnahmen', targetWords: 40 }, + { blockId: 'tom-conclusion', blockType: 'conclusion', sectionName: 'Fazit TOM', targetWords: 80 }, + ], + dsfa: [ + { blockId: 'dsfa-intro', blockType: 'introduction', sectionName: 'Einleitung DSFA', targetWords: 150 }, + { blockId: 'dsfa-transition', blockType: 'transition', sectionName: 'Ueberleitung Risikobewertung', targetWords: 40 }, + { blockId: 'dsfa-appreciation', blockType: 'appreciation', sectionName: 'Wuerdigung bestehender Massnahmen', targetWords: 60 }, + { blockId: 'dsfa-conclusion', blockType: 'conclusion', sectionName: 'Fazit DSFA', targetWords: 100 }, + ], + vvt: [ + { blockId: 'vvt-intro', blockType: 'introduction', sectionName: 'Einleitung VVT', targetWords: 120 }, + { blockId: 'vvt-conclusion', blockType: 'conclusion', sectionName: 'Fazit VVT', targetWords: 80 }, + ], + dsi: [ + { blockId: 'dsi-intro', blockType: 'introduction', sectionName: 'Einleitung Datenschutzerklaerung', targetWords: 130 }, + { blockId: 'dsi-conclusion', blockType: 'conclusion', sectionName: 'Fazit Datenschutzerklaerung', targetWords: 80 }, + ], + lf: [ + { blockId: 'lf-intro', blockType: 'introduction', sectionName: 'Einleitung Loeschfristen', targetWords: 100 }, + { blockId: 'lf-conclusion', blockType: 'conclusion', sectionName: 'Fazit Loeschfristen', targetWords: 60 }, + ], +} + +function buildV2SystemPrompt( + sanitizedFactsString: string, + narrativeTagsString: string, + terminologyString: string, + styleString: string, + disallowedString: string, + companyName: string, + blockId: string, + blockType: string, + sectionName: string, + documentType: string, + targetWords: number +): string { + return `Du bist ein Compliance-Dokumenten-Redakteur. +Du schreibst einzelne Textabschnitte fuer offizielle Compliance-Dokumente. + +KUNDENPROFIL (ERLAUBTE FAKTEN — nur diese darfst du verwenden): +${sanitizedFactsString} + +BEWERTUNGSERGEBNIS (sprachliche Tags — verwende nur diese Begriffe): +${narrativeTagsString} + +TERMINOLOGIE (verwende ausschliesslich diese Fachbegriffe): +${terminologyString} + +STIL: +${styleString} + +VERBOTENE INHALTE: +${disallowedString} +- Keine konkreten Prozentwerte, Scores oder Zahlen +- Keine Compliance-Level-Bezeichnungen (L1, L2, L3, L4) +- Keine direkte Ansprache ("Sie", "Ihr") +- Kein Denglisch, keine Marketing-Sprache, keine Superlative + +STRIKTE REGELN: +1. Verwende den Firmennamen "${companyName}" — nie "Ihr Unternehmen" +2. Schreibe in der dritten Person ("Die ${companyName}...") +3. Beziehe dich auf die Branche und organisatorische Merkmale +4. Verwende NUR Fakten aus dem Kundenprofil oben +5. Verwende NUR die sprachlichen Tags aus dem Bewertungsergebnis +6. Erfinde KEINE zusaetzlichen Fakten oder Bewertungen +7. Halte dich an die Terminologie-Vorgaben +8. Dein Text wird ZWISCHEN feste Datentabellen eingefuegt + +OUTPUT-FORMAT: Antworte ausschliesslich als JSON: +{ + "blockId": "${blockId}", + "blockType": "${blockType}", + "language": "de", + "text": "...", + "assertions": { + "companyNameUsed": true/false, + "industryReferenced": true/false, + "structureReferenced": true/false, + "itLandscapeReferenced": true/false, + "narrativeTagsUsed": ["riskSummary", ...] + }, + "forbiddenContentDetected": [] +} + +DOKUMENTENTYP: ${documentType} +SEKTION: ${sectionName} +BLOCK-TYP: ${blockType} +ZIEL-LAENGE: ${targetWords} Woerter` +} + +function buildBlockSpecificPrompt(blockType: string, sectionName: string, documentType: string): string { + switch (blockType) { + case 'introduction': + return `Schreibe eine Einleitung fuer das Dokument "${documentType}" (Sektion: ${sectionName}). +Erklaere, warum dieses Dokument fuer das Unternehmen erstellt wurde. +Gehe auf die spezifische Situation des Unternehmens ein. +Erwaehne die Branche, die Organisationsform und die IT-Strategie.` + case 'transition': + return `Schreibe eine kurze Ueberleitung zur naechsten Sektion "${sectionName}". +Verknuepfe den vorherigen Abschnitt logisch mit dem folgenden.` + case 'conclusion': + return `Schreibe einen abschliessenden Absatz fuer die Sektion "${sectionName}". +Fasse die wesentlichen Punkte zusammen und verweise auf die fortlaufende Pflege.` + case 'appreciation': + return `Schreibe einen wertschaetzenden Satz ueber die bestehenden Massnahmen. +Verwende dabei die sprachlichen Tags aus dem Bewertungsergebnis. +Keine neuen Fakten erfinden — nur das Profil wuerdigen.` + default: + return `Schreibe einen Textabschnitt fuer "${sectionName}".` + } +} + +async function callOllama(systemPrompt: string, userPrompt: string): Promise { + const response = await fetch(`${OLLAMA_URL}/api/chat`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + model: LLM_MODEL, + messages: [ + { role: 'system', content: systemPrompt }, + { role: 'user', content: userPrompt }, + ], + stream: false, + options: { temperature: 0.15, num_predict: 4096 }, + format: 'json', + }), + signal: AbortSignal.timeout(120000), + }) + + if (!response.ok) { + throw new Error(`Ollama error: ${response.status}`) + } + + const result = await response.json() + return result.message?.content || '' +} + +async function handleV2Draft(body: Record): Promise { + const { documentType, draftContext, instructions } = body as { + documentType: ScopeDocumentType + draftContext: DraftContext + instructions?: string + } + + // Step 1: Constraint Check (Hard Gate) + const constraintCheck = constraintEnforcer.checkFromContext(documentType, draftContext) + if (!constraintCheck.allowed) { + return NextResponse.json({ + draft: null, + constraintCheck, + tokensUsed: 0, + error: 'Constraint-Verletzung: ' + constraintCheck.violations.join('; '), + }, { status: 403 }) + } + + // Step 2: Derive Narrative Tags (deterministisch) + const scores = extractScoresFromDraftContext(draftContext) + const narrativeTags: NarrativeTags = deriveNarrativeTags(scores) + + // Step 3: Build Allowed Facts + const allowedFacts = buildAllowedFactsFromDraftContext(draftContext, narrativeTags) + + // Step 4: PII Sanitization + let sanitizationResult + try { + sanitizationResult = sanitizeAllowedFacts(allowedFacts) + } catch (error) { + if (error instanceof SanitizationError) { + return NextResponse.json({ + error: `Sanitization Hard Abort: ${error.message} (Feld: ${error.field})`, + draft: null, + constraintCheck, + tokensUsed: 0, + }, { status: 422 }) + } + throw error + } + + const sanitizedFacts = sanitizationResult.facts + + // Verify no remaining PII + const piiWarnings = validateNoRemainingPII(sanitizedFacts) + if (piiWarnings.length > 0) { + console.warn('PII-Warnungen nach Sanitization:', piiWarnings) + } + + // Step 5: Build prompt components + const factsString = allowedFactsToPromptString(sanitizedFacts) + const tagsString = narrativeTagsToPromptString(narrativeTags) + const termsString = terminologyToPromptString() + const styleString = styleContractToPromptString() + const disallowedString = disallowedTopicsToPromptString() + + // Compute prompt hash for audit + const promptHash = computeChecksumSync({ factsString, tagsString, termsString, styleString, disallowedString }) + + // Step 6: Generate Prose Blocks (with cache + repair loop) + const proseBlocks = DOCUMENT_PROSE_BLOCKS[documentType] || DOCUMENT_PROSE_BLOCKS.tom + const generatedBlocks: ProseBlockOutput[] = [] + const repairAudits: RepairAudit[] = [] + let totalTokens = 0 + + for (const blockDef of proseBlocks) { + // Check cache + const cacheParams: CacheKeyParams = { + allowedFacts: sanitizedFacts, + templateVersion: TEMPLATE_VERSION, + terminologyVersion: TERMINOLOGY_VERSION, + narrativeTags, + promptHash, + blockType: blockDef.blockType, + sectionName: blockDef.sectionName, + } + + const cached = proseCache.getSync(cacheParams) + if (cached) { + generatedBlocks.push(cached) + repairAudits.push({ + repairAttempts: 0, + validatorFailures: [], + repairSuccessful: true, + fallbackUsed: false, + }) + continue + } + + // Build prompts + const systemPrompt = buildV2SystemPrompt( + factsString, tagsString, termsString, styleString, disallowedString, + sanitizedFacts.companyName, + blockDef.blockId, blockDef.blockType, blockDef.sectionName, + documentType, blockDef.targetWords + ) + const userPrompt = buildBlockSpecificPrompt( + blockDef.blockType, blockDef.sectionName, documentType + ) + (instructions ? `\n\nZusaetzliche Anweisungen: ${instructions}` : '') + + // Call LLM + Repair Loop + try { + const rawOutput = await callOllama(systemPrompt, userPrompt) + totalTokens += rawOutput.length / 4 // Rough token estimate + + const { block, audit } = await executeRepairLoop( + rawOutput, + sanitizedFacts, + narrativeTags, + blockDef.blockId, + blockDef.blockType, + async (repairPrompt) => callOllama(systemPrompt, repairPrompt), + documentType + ) + + generatedBlocks.push(block) + repairAudits.push(audit) + + // Cache successful blocks (not fallbacks) + if (!audit.fallbackUsed) { + proseCache.setSync(cacheParams, block) + } + } catch (error) { + // LLM unreachable → Fallback + const { buildFallbackBlock } = await import('@/lib/sdk/drafting-engine/prose-validator') + generatedBlocks.push( + buildFallbackBlock(blockDef.blockId, blockDef.blockType, sanitizedFacts, documentType) + ) + repairAudits.push({ + repairAttempts: 0, + validatorFailures: [[(error as Error).message]], + repairSuccessful: false, + fallbackUsed: true, + fallbackReason: `LLM-Fehler: ${(error as Error).message}`, + }) + } + } + + // Step 7: Build v1-compatible draft sections from prose blocks + original prompt + const draftPrompt = buildPromptForDocumentType(documentType, draftContext, instructions) + + // Also generate data sections via legacy pipeline + let dataSections: DraftSection[] = [] + try { + const dataResponse = await callOllama(V1_SYSTEM_PROMPT, draftPrompt) + const parsed = JSON.parse(dataResponse) + dataSections = (parsed.sections || []).map((s: Record, i: number) => ({ + id: String(s.id || `section-${i}`), + title: String(s.title || ''), + content: String(s.content || ''), + schemaField: s.schemaField ? String(s.schemaField) : undefined, + })) + totalTokens += dataResponse.length / 4 + } catch { + dataSections = [] + } + + // Merge: Prose intro → Data sections → Prose transitions/conclusion + const introBlock = generatedBlocks.find(b => b.blockType === 'introduction') + const transitionBlocks = generatedBlocks.filter(b => b.blockType === 'transition') + const appreciationBlocks = generatedBlocks.filter(b => b.blockType === 'appreciation') + const conclusionBlock = generatedBlocks.find(b => b.blockType === 'conclusion') + + const mergedSections: DraftSection[] = [] + + if (introBlock) { + mergedSections.push({ + id: introBlock.blockId, + title: 'Einleitung', + content: introBlock.text, + }) + } + + for (let i = 0; i < dataSections.length; i++) { + // Insert transition before data section (if available) + if (i > 0 && transitionBlocks[i - 1]) { + mergedSections.push({ + id: transitionBlocks[i - 1].blockId, + title: '', + content: transitionBlocks[i - 1].text, + }) + } + mergedSections.push(dataSections[i]) + } + + for (const block of appreciationBlocks) { + mergedSections.push({ + id: block.blockId, + title: 'Wuerdigung', + content: block.text, + }) + } + + if (conclusionBlock) { + mergedSections.push({ + id: conclusionBlock.blockId, + title: 'Fazit', + content: conclusionBlock.text, + }) + } + + // If no data sections generated, use prose blocks as sections + const finalSections = mergedSections.length > 0 ? mergedSections : generatedBlocks.map(b => ({ + id: b.blockId, + title: b.blockType === 'introduction' ? 'Einleitung' : + b.blockType === 'conclusion' ? 'Fazit' : + b.blockType === 'appreciation' ? 'Wuerdigung' : 'Ueberleitung', + content: b.text, + })) + + const draft: DraftRevision = { + id: `draft-v2-${Date.now()}`, + content: finalSections.map(s => s.title ? `## ${s.title}\n\n${s.content}` : s.content).join('\n\n'), + sections: finalSections, + createdAt: new Date().toISOString(), + instruction: instructions, + } + + // Step 8: Build Audit Trail + const auditTrail = { + documentType, + templateVersion: TEMPLATE_VERSION, + terminologyVersion: TERMINOLOGY_VERSION, + validatorVersion: VALIDATOR_VERSION, + promptHash, + llmModel: LLM_MODEL, + llmTemperature: 0.15, + llmProvider: 'ollama', + narrativeTags, + sanitization: sanitizationResult.audit, + repairAudits, + proseBlocks: generatedBlocks.map((b, i) => ({ + blockId: b.blockId, + blockType: b.blockType, + wordCount: b.text.split(/\s+/).filter(Boolean).length, + fallbackUsed: repairAudits[i]?.fallbackUsed ?? false, + repairAttempts: repairAudits[i]?.repairAttempts ?? 0, + })), + cacheStats: proseCache.getStats(), + } + + return NextResponse.json({ + draft, + constraintCheck, + tokensUsed: Math.round(totalTokens), + pipelineVersion: 'v2', + auditTrail, + }) +} + +// ============================================================================ +// Route Handler +// ============================================================================ + export async function POST(request: NextRequest) { try { const body = await request.json() - const { documentType, draftContext, instructions, existingDraft } = body + const { documentType, draftContext } = body if (!documentType || !draftContext) { return NextResponse.json( @@ -72,92 +575,14 @@ export async function POST(request: NextRequest) { ) } - // 1. Constraint Check (Hard Gate) - const constraintCheck = constraintEnforcer.checkFromContext(documentType, draftContext) - - if (!constraintCheck.allowed) { - return NextResponse.json({ - draft: null, - constraintCheck, - tokensUsed: 0, - error: 'Constraint-Verletzung: ' + constraintCheck.violations.join('; '), - }, { status: 403 }) + // v1 fallback: explicit ?v=1 parameter + const version = request.nextUrl.searchParams.get('v') + if (version === '1') { + return handleV1Draft(body) } - // 2. Build document-specific prompt - const draftPrompt = buildPromptForDocumentType(documentType, draftContext, instructions) - - // 3. Build messages - const messages = [ - { role: 'system', content: DRAFTING_SYSTEM_PROMPT }, - ...(existingDraft ? [{ - role: 'assistant', - content: `Bisheriger Entwurf:\n${JSON.stringify(existingDraft.sections, null, 2)}`, - }] : []), - { role: 'user', content: draftPrompt }, - ] - - // 4. Call LLM (non-streaming for structured output) - const ollamaResponse = await fetch(`${OLLAMA_URL}/api/chat`, { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - model: LLM_MODEL, - messages, - stream: false, - options: { - temperature: 0.15, - num_predict: 16384, - }, - format: 'json', - }), - signal: AbortSignal.timeout(180000), - }) - - if (!ollamaResponse.ok) { - return NextResponse.json( - { error: `LLM nicht erreichbar (Status ${ollamaResponse.status})` }, - { status: 502 } - ) - } - - const result = await ollamaResponse.json() - const content = result.message?.content || '' - - // 5. Parse JSON response - let sections: DraftSection[] = [] - try { - const parsed = JSON.parse(content) - sections = (parsed.sections || []).map((s: Record, i: number) => ({ - id: String(s.id || `section-${i}`), - title: String(s.title || ''), - content: String(s.content || ''), - schemaField: s.schemaField ? String(s.schemaField) : undefined, - })) - } catch { - // If not JSON, wrap raw content as single section - sections = [{ - id: 'raw', - title: 'Entwurf', - content: content, - }] - } - - const draft: DraftRevision = { - id: `draft-${Date.now()}`, - content: sections.map(s => `## ${s.title}\n\n${s.content}`).join('\n\n'), - sections, - createdAt: new Date().toISOString(), - instruction: instructions, - } - - const response: DraftResponse = { - draft, - constraintCheck, - tokensUsed: result.eval_count || 0, - } - - return NextResponse.json(response) + // Default: v2 pipeline + return handleV2Draft(body) } catch (error) { console.error('Draft generation error:', error) return NextResponse.json( diff --git a/admin-compliance/lib/sdk/drafting-engine/allowed-facts-v2.ts b/admin-compliance/lib/sdk/drafting-engine/allowed-facts-v2.ts new file mode 100644 index 0000000..2149199 --- /dev/null +++ b/admin-compliance/lib/sdk/drafting-engine/allowed-facts-v2.ts @@ -0,0 +1,85 @@ +/** + * Allowed Facts v2 Adapter — Baut AllowedFacts aus DraftContext + * + * Die Haupt-AllowedFacts Datei (allowed-facts.ts) erwartet SDKState, + * aber in der Draft API Route haben wir nur DraftContext. + * Dieser Adapter ueberbrueckt die Luecke. + * + * Re-exportiert auch die Serialisierungs-/Validierungsfunktionen. + */ + +import type { AllowedFacts, FactPolicy } from './allowed-facts' +import { + DEFAULT_FACT_POLICY, + allowedFactsToPromptString, + disallowedTopicsToPromptString, + checkForDisallowedContent, +} from './allowed-facts' +import type { NarrativeTags } from './narrative-tags' +import type { DraftContext } from './types' + +// Re-exports +export { allowedFactsToPromptString, disallowedTopicsToPromptString, checkForDisallowedContent } + +/** + * Baut AllowedFacts aus einem DraftContext (API Route Kontext). + * Der DraftContext hat bereits projizierte Firmendaten. + */ +export function buildAllowedFactsFromDraftContext( + context: DraftContext, + narrativeTags: NarrativeTags +): AllowedFacts { + const profile = context.companyProfile + + return { + companyName: profile.name || 'Unbekannt', + legalForm: '', // Nicht im DraftContext enthalten + industry: profile.industry || '', + location: '', // Nicht im DraftContext enthalten + employeeCount: profile.employeeCount || 0, + + teamStructure: deriveTeamStructure(profile.employeeCount), + itLandscape: deriveItLandscape(profile.businessModel, profile.isPublicSector), + specialFeatures: deriveSpecialFeatures(profile), + + triggeredRegulations: deriveRegulations(context), + primaryUseCases: [], // Nicht im DraftContext enthalten + + narrativeTags, + } +} + +// ============================================================================ +// Private Helpers +// ============================================================================ + +function deriveTeamStructure(employeeCount: number): string { + if (employeeCount > 500) return 'Konzernstruktur' + if (employeeCount > 50) return 'mittelstaendisch' + return 'Kleinunternehmen' +} + +function deriveItLandscape(businessModel: string, isPublicSector: boolean): string { + if (businessModel?.includes('SaaS') || businessModel?.includes('Cloud')) return 'Cloud-First' + if (isPublicSector) return 'On-Premise' + return 'Hybrid' +} + +function deriveSpecialFeatures(profile: DraftContext['companyProfile']): string[] { + const features: string[] = [] + if (profile.isPublicSector) features.push('Oeffentlicher Sektor') + if (profile.employeeCount > 250) features.push('Grossunternehmen') + if (profile.dataProtectionOfficer) features.push('Interner DSB benannt') + return features +} + +function deriveRegulations(context: DraftContext): string[] { + const regs = new Set(['DSGVO']) + const triggers = context.decisions.hardTriggers || [] + for (const t of triggers) { + if (t.id.includes('ai_act') || t.id.includes('ai-act')) regs.add('AI Act') + if (t.id.includes('nis2') || t.id.includes('NIS2')) regs.add('NIS2') + if (t.id.includes('ttdsg') || t.id.includes('TTDSG')) regs.add('TTDSG') + } + return Array.from(regs) +} diff --git a/admin-compliance/lib/sdk/drafting-engine/allowed-facts.ts b/admin-compliance/lib/sdk/drafting-engine/allowed-facts.ts new file mode 100644 index 0000000..1ef2e71 --- /dev/null +++ b/admin-compliance/lib/sdk/drafting-engine/allowed-facts.ts @@ -0,0 +1,257 @@ +/** + * Allowed Facts Governance — Kontrolliertes Faktenbudget fuer LLM + * + * Definiert welche Fakten das LLM in Prosa-Bloecken verwenden darf + * und welche Themen explizit verboten sind. + * + * Verhindert Halluzinationen durch explizite Whitelisting. + */ + +import type { SDKState, CompanyProfile } from '../types' +import type { NarrativeTags } from './narrative-tags' + +// ============================================================================ +// Types +// ============================================================================ + +/** Explizites Faktenbudget fuer das LLM */ +export interface AllowedFacts { + // Firmenprofil + companyName: string + legalForm: string + industry: string + location: string + employeeCount: number + + // Organisation + teamStructure: string + itLandscape: string + specialFeatures: string[] + + // Compliance-Kontext + triggeredRegulations: string[] + primaryUseCases: string[] + + // Narrative Tags (deterministisch) + narrativeTags: NarrativeTags +} + +/** Regeln welche Themen erlaubt/verboten sind */ +export interface FactPolicy { + allowedTopics: string[] + disallowedTopics: string[] +} + +// ============================================================================ +// Default Policy +// ============================================================================ + +export const DEFAULT_FACT_POLICY: FactPolicy = { + allowedTopics: [ + 'Branche', + 'Unternehmensgroesse', + 'Teamstruktur', + 'IT-Strategie', + 'Regulatorischer Kontext', + 'Anwendungsfaelle', + 'Organisationsform', + 'Standort', + 'Rechtsform', + ], + disallowedTopics: [ + 'Umsatz', + 'Gewinn', + 'Kundenzahlen', + 'konkrete Zertifizierungen', + 'interne Tool-Namen', + 'Personennamen', + 'E-Mail-Adressen', + 'Telefonnummern', + 'IP-Adressen', + 'konkrete Prozentwerte', + 'konkrete Scores', + 'Compliance-Level-Bezeichnungen', + 'interne Projektnamen', + 'Passwoerter', + 'API-Keys', + 'Vertragsinhalte', + 'Gehaltsinformationen', + ], +} + +// ============================================================================ +// Builder +// ============================================================================ + +/** + * Extrahiert AllowedFacts aus dem SDKState. + * Nur explizit freigegebene Felder werden uebernommen. + */ +export function buildAllowedFacts( + state: SDKState, + narrativeTags: NarrativeTags +): AllowedFacts { + const profile = state.companyProfile + const scope = state.complianceScope + + return { + companyName: profile?.name ?? 'Unbekannt', + legalForm: profile?.legalForm ?? '', + industry: profile?.industry ?? '', + location: profile?.location ?? '', + employeeCount: profile?.employeeCount ?? 0, + + teamStructure: deriveTeamStructure(profile), + itLandscape: deriveItLandscape(profile), + specialFeatures: deriveSpecialFeatures(profile), + + triggeredRegulations: deriveTriggeredRegulations(scope), + primaryUseCases: derivePrimaryUseCases(state), + + narrativeTags, + } +} + +// ============================================================================ +// Serialization +// ============================================================================ + +/** + * Serialisiert AllowedFacts fuer den LLM-Prompt. + */ +export function allowedFactsToPromptString(facts: AllowedFacts): string { + const lines = [ + `- Firma: ${facts.companyName}${facts.legalForm ? ` (${facts.legalForm})` : ''}`, + `- Branche: ${facts.industry || 'nicht angegeben'}`, + `- Standort: ${facts.location || 'nicht angegeben'}`, + `- Mitarbeiter: ${facts.employeeCount || 'nicht angegeben'}`, + `- Teamstruktur: ${facts.teamStructure || 'nicht angegeben'}`, + `- IT-Umgebung: ${facts.itLandscape || 'nicht angegeben'}`, + ] + + if (facts.triggeredRegulations.length > 0) { + lines.push(`- Relevante Regulierungen: ${facts.triggeredRegulations.join(', ')}`) + } + if (facts.primaryUseCases.length > 0) { + lines.push(`- Anwendungsfaelle: ${facts.primaryUseCases.join(', ')}`) + } + if (facts.specialFeatures.length > 0) { + lines.push(`- Besonderheiten: ${facts.specialFeatures.join(', ')}`) + } + + return lines.join('\n') +} + +/** + * Serialisiert die Disallowed Topics fuer den LLM-Prompt. + */ +export function disallowedTopicsToPromptString(policy: FactPolicy = DEFAULT_FACT_POLICY): string { + return policy.disallowedTopics.map(t => `- ${t}`).join('\n') +} + +// ============================================================================ +// Validation +// ============================================================================ + +/** + * Prueft ob ein Text potentiell verbotene Themen enthaelt. + * Gibt eine Liste der erkannten Verstoesse zurueck. + */ +export function checkForDisallowedContent( + text: string, + policy: FactPolicy = DEFAULT_FACT_POLICY +): string[] { + const violations: string[] = [] + const lower = text.toLowerCase() + + // Prozentwerte + if (/\d+\s*%/.test(text)) { + violations.push('Konkrete Prozentwerte gefunden') + } + + // Score-Muster + if (/score[:\s]*\d+/i.test(text)) { + violations.push('Konkrete Scores gefunden') + } + + // Compliance-Level Bezeichnungen + if (/\b(L1|L2|L3|L4)\b/.test(text)) { + violations.push('Compliance-Level-Bezeichnungen (L1-L4) gefunden') + } + + // E-Mail-Adressen + if (/[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}/.test(text)) { + violations.push('E-Mail-Adresse gefunden') + } + + // Telefonnummern + if (/(\+?\d{1,3}[-.\s]?)?\(?\d{2,5}\)?[-.\s]?\d{3,10}/.test(text)) { + // Nur wenn es nicht die Mitarbeiterzahl ist (einstellig/zweistellig) + const matches = text.match(/(\+?\d{1,3}[-.\s]?)?\(?\d{2,5}\)?[-.\s]?\d{3,10}/g) || [] + for (const m of matches) { + if (m.replace(/\D/g, '').length >= 6) { + violations.push('Telefonnummer gefunden') + break + } + } + } + + // IP-Adressen + if (/\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b/.test(text)) { + violations.push('IP-Adresse gefunden') + } + + // Direkte Ansprache + if (/\b(Sie|Ihr|Ihnen|Ihrem|Ihrer)\b/.test(text)) { + violations.push('Direkte Ansprache (Sie/Ihr) gefunden') + } + + return violations +} + +// ============================================================================ +// Private Helpers +// ============================================================================ + +function deriveTeamStructure(profile: CompanyProfile | null): string { + if (!profile) return '' + // Ableitung aus verfuegbaren Profildaten + if (profile.employeeCount > 500) return 'Konzernstruktur' + if (profile.employeeCount > 50) return 'mittelstaendisch' + return 'Kleinunternehmen' +} + +function deriveItLandscape(profile: CompanyProfile | null): string { + if (!profile) return '' + return profile.businessModel?.includes('SaaS') ? 'Cloud-First' : + profile.businessModel?.includes('Cloud') ? 'Cloud-First' : + profile.isPublicSector ? 'On-Premise' : 'Hybrid' +} + +function deriveSpecialFeatures(profile: CompanyProfile | null): string[] { + if (!profile) return [] + const features: string[] = [] + if (profile.isPublicSector) features.push('Oeffentlicher Sektor') + if (profile.employeeCount > 250) features.push('Grossunternehmen') + if (profile.dataProtectionOfficer) features.push('Interner DSB benannt') + return features +} + +function deriveTriggeredRegulations( + scope: import('../compliance-scope-types').ComplianceScopeState | null +): string[] { + if (!scope?.decision) return ['DSGVO'] + const regs = new Set(['DSGVO']) + const triggers = scope.decision.triggeredHardTriggers || [] + for (const t of triggers) { + if (t.rule.id.includes('ai_act') || t.rule.id.includes('ai-act')) regs.add('AI Act') + if (t.rule.id.includes('nis2') || t.rule.id.includes('NIS2')) regs.add('NIS2') + if (t.rule.id.includes('ttdsg') || t.rule.id.includes('TTDSG')) regs.add('TTDSG') + } + return Array.from(regs) +} + +function derivePrimaryUseCases(state: SDKState): string[] { + if (!state.useCases || state.useCases.length === 0) return [] + return state.useCases.slice(0, 3).map(uc => uc.name || uc.title || 'Unbenannt') +} diff --git a/admin-compliance/lib/sdk/drafting-engine/cache.ts b/admin-compliance/lib/sdk/drafting-engine/cache.ts new file mode 100644 index 0000000..7cbb0f1 --- /dev/null +++ b/admin-compliance/lib/sdk/drafting-engine/cache.ts @@ -0,0 +1,303 @@ +/** + * Cache Manager — Hash-basierte Prose-Block-Cache + * + * Deterministischer Cache fuer LLM-generierte Prosa-Bloecke. + * Kein TTL-basiertes Raten — stattdessen Hash-basierte Invalidierung. + * + * Cache-Key = SHA-256 ueber alle Eingabeparameter. + * Aendert sich ein Eingabewert → neuer Hash → Cache-Miss → Neu-Generierung. + */ + +import type { AllowedFacts } from './allowed-facts' +import type { NarrativeTags } from './narrative-tags' +import type { ProseBlockOutput } from './prose-validator' + +// ============================================================================ +// Types +// ============================================================================ + +export interface CacheEntry { + block: ProseBlockOutput + createdAt: string + hitCount: number + cacheKey: string +} + +export interface CacheKeyParams { + allowedFacts: AllowedFacts + templateVersion: string + terminologyVersion: string + narrativeTags: NarrativeTags + promptHash: string + blockType: string + sectionName: string +} + +export interface CacheStats { + totalEntries: number + totalHits: number + totalMisses: number + hitRate: number + oldestEntry: string | null + newestEntry: string | null +} + +// ============================================================================ +// SHA-256 (Browser-kompatibel via SubtleCrypto) +// ============================================================================ + +/** + * Berechnet SHA-256 Hash eines Strings. + * Nutzt SubtleCrypto (verfuegbar in Node.js 15+ und allen modernen Browsern). + */ +async function sha256(input: string): Promise { + // In Next.js API Routes laeuft Node.js — nutze crypto + if (typeof globalThis.crypto?.subtle !== 'undefined') { + const encoder = new TextEncoder() + const data = encoder.encode(input) + const hashBuffer = await globalThis.crypto.subtle.digest('SHA-256', data) + const hashArray = Array.from(new Uint8Array(hashBuffer)) + return hashArray.map(b => b.toString(16).padStart(2, '0')).join('') + } + + // Fallback: Node.js crypto + try { + const { createHash } = await import('crypto') + return createHash('sha256').update(input).digest('hex') + } catch { + // Letzer Fallback: Einfacher Hash (nicht kryptographisch) + return simpleHash(input) + } +} + +/** + * Synchrone SHA-256 Berechnung (Node.js only). + */ +function sha256Sync(input: string): string { + try { + // eslint-disable-next-line @typescript-eslint/no-require-imports + const crypto = require('crypto') + return crypto.createHash('sha256').update(input).digest('hex') + } catch { + return simpleHash(input) + } +} + +/** + * Einfacher nicht-kryptographischer Hash als Fallback. + */ +function simpleHash(input: string): string { + let hash = 0 + for (let i = 0; i < input.length; i++) { + const char = input.charCodeAt(i) + hash = ((hash << 5) - hash) + char + hash = hash & hash + } + return Math.abs(hash).toString(16).padStart(16, '0') +} + +// ============================================================================ +// Cache Key Computation +// ============================================================================ + +/** + * Berechnet den deterministischen Cache-Key. + * Sortiert Keys um konsistente Serialisierung zu gewaehrleisten. + */ +export async function computeCacheKey(params: CacheKeyParams): Promise { + const payload = JSON.stringify(params, Object.keys(params).sort()) + return sha256(payload) +} + +/** + * Synchrone Variante fuer Cache-Key (Node.js). + */ +export function computeCacheKeySync(params: CacheKeyParams): string { + const payload = JSON.stringify(params, Object.keys(params).sort()) + return sha256Sync(payload) +} + +// ============================================================================ +// In-Memory Cache +// ============================================================================ + +/** + * In-Memory Cache fuer Prose-Bloecke. + * + * Sicherheitsmechanismen: + * - Max Eintraege (Speicher-Limit) + * - TTL als zusaetzlicher Sicherheitsmechanismus (24h default) + * - LRU-artige Bereinigung bei Overflow + */ +export class ProseCacheManager { + private cache = new Map() + private hits = 0 + private misses = 0 + private readonly maxEntries: number + private readonly ttlMs: number + + constructor(options?: { maxEntries?: number; ttlHours?: number }) { + this.maxEntries = options?.maxEntries ?? 500 + this.ttlMs = (options?.ttlHours ?? 24) * 60 * 60 * 1000 + } + + /** + * Sucht einen gecachten Block. + */ + async get(params: CacheKeyParams): Promise { + const key = await computeCacheKey(params) + return this.getByKey(key) + } + + /** + * Sucht synchron (Node.js). + */ + getSync(params: CacheKeyParams): ProseBlockOutput | null { + const key = computeCacheKeySync(params) + return this.getByKey(key) + } + + /** + * Speichert einen Block im Cache. + */ + async set(params: CacheKeyParams, block: ProseBlockOutput): Promise { + const key = await computeCacheKey(params) + this.setByKey(key, block) + } + + /** + * Speichert synchron (Node.js). + */ + setSync(params: CacheKeyParams, block: ProseBlockOutput): void { + const key = computeCacheKeySync(params) + this.setByKey(key, block) + } + + /** + * Gibt Cache-Statistiken zurueck. + */ + getStats(): CacheStats { + const entries = Array.from(this.cache.values()) + const total = this.hits + this.misses + + return { + totalEntries: this.cache.size, + totalHits: this.hits, + totalMisses: this.misses, + hitRate: total > 0 ? this.hits / total : 0, + oldestEntry: entries.length > 0 + ? entries.reduce((a, b) => a.createdAt < b.createdAt ? a : b).createdAt + : null, + newestEntry: entries.length > 0 + ? entries.reduce((a, b) => a.createdAt > b.createdAt ? a : b).createdAt + : null, + } + } + + /** + * Loescht alle Eintraege. + */ + clear(): void { + this.cache.clear() + this.hits = 0 + this.misses = 0 + } + + /** + * Loescht abgelaufene Eintraege. + */ + cleanup(): number { + const now = Date.now() + let removed = 0 + for (const [key, entry] of this.cache.entries()) { + if (now - new Date(entry.createdAt).getTime() > this.ttlMs) { + this.cache.delete(key) + removed++ + } + } + return removed + } + + // ======================================================================== + // Private + // ======================================================================== + + private getByKey(key: string): ProseBlockOutput | null { + const entry = this.cache.get(key) + + if (!entry) { + this.misses++ + return null + } + + // TTL pruefen + if (Date.now() - new Date(entry.createdAt).getTime() > this.ttlMs) { + this.cache.delete(key) + this.misses++ + return null + } + + entry.hitCount++ + this.hits++ + return entry.block + } + + private setByKey(key: string, block: ProseBlockOutput): void { + // Bei Overflow: aeltesten Eintrag entfernen + if (this.cache.size >= this.maxEntries) { + this.evictOldest() + } + + this.cache.set(key, { + block, + createdAt: new Date().toISOString(), + hitCount: 0, + cacheKey: key, + }) + } + + private evictOldest(): void { + let oldestKey: string | null = null + let oldestTime = Infinity + + for (const [key, entry] of this.cache.entries()) { + const time = new Date(entry.createdAt).getTime() + if (time < oldestTime) { + oldestTime = time + oldestKey = key + } + } + + if (oldestKey) { + this.cache.delete(oldestKey) + } + } +} + +// ============================================================================ +// Checksum Utils (fuer Data Block Integritaet) +// ============================================================================ + +/** + * Berechnet Integritaets-Checksum ueber Daten. + */ +export async function computeChecksum(data: unknown): Promise { + const serialized = JSON.stringify(data, Object.keys(data as Record).sort()) + return sha256(serialized) +} + +/** + * Synchrone Checksum-Berechnung. + */ +export function computeChecksumSync(data: unknown): string { + const serialized = JSON.stringify(data, Object.keys(data as Record).sort()) + return sha256Sync(serialized) +} + +/** + * Verifiziert eine Checksum gegen Daten. + */ +export async function verifyChecksum(data: unknown, expectedChecksum: string): Promise { + const actual = await computeChecksum(data) + return actual === expectedChecksum +} diff --git a/admin-compliance/lib/sdk/drafting-engine/narrative-tags.ts b/admin-compliance/lib/sdk/drafting-engine/narrative-tags.ts new file mode 100644 index 0000000..81cee7c --- /dev/null +++ b/admin-compliance/lib/sdk/drafting-engine/narrative-tags.ts @@ -0,0 +1,139 @@ +/** + * Narrative Tags — Deterministische Score-zu-Sprache Ableitung + * + * Der Data Layer erzeugt aus berechneten Scores sprachliche Tags. + * Das LLM darf NUR diese Tags verwenden — niemals echte Scores oder Prozentwerte. + * + * Alle Funktionen sind 100% deterministisch: gleiche Eingabe = gleiche Ausgabe. + */ + +// ============================================================================ +// Types +// ============================================================================ + +export interface NarrativeTags { + /** Sprachliche Risiko-Einschaetzung */ + riskSummary: 'niedrig' | 'moderat' | 'erhoht' + /** Reifegrad der bestehenden Massnahmen */ + maturity: 'ausbaufahig' | 'solide' | 'hoch' + /** Handlungsprioritaet */ + priority: 'kurzfristig' | 'mittelfristig' | 'langfristig' + /** Abdeckungsgrad der Controls */ + coverageLevel: 'grundlegend' | 'umfassend' | 'vollstaendig' + /** Dringlichkeit */ + urgency: 'planbar' | 'zeitnah' | 'dringend' +} + +/** Eingabe-Scores fuer die Tag-Ableitung */ +export interface NarrativeTagScores { + /** Gesamt-Risikoscore (0-100) */ + overallRisk: number + /** Reife-Score (0-100) */ + maturityScore: number + /** Anzahl identifizierter Luecken */ + gapCount: number + /** Anzahl kritischer Luecken */ + criticalGaps: number + /** Control-Abdeckung (0-100) */ + controlCoverage: number + /** Anzahl kritischer Findings */ + criticalFindings: number + /** Anzahl hoher Findings */ + highFindings: number +} + +// ============================================================================ +// Tag Derivation (deterministisch) +// ============================================================================ + +/** + * Leitet aus numerischen Scores sprachliche Narrative Tags ab. + * 100% deterministisch — gleiche Scores = gleiche Tags. + */ +export function deriveNarrativeTags(scores: NarrativeTagScores): NarrativeTags { + return { + riskSummary: + scores.overallRisk <= 30 ? 'niedrig' : + scores.overallRisk <= 65 ? 'moderat' : 'erhoht', + + maturity: + scores.maturityScore <= 40 ? 'ausbaufahig' : + scores.maturityScore <= 75 ? 'solide' : 'hoch', + + priority: + scores.gapCount === 0 ? 'langfristig' : + scores.criticalGaps > 0 ? 'kurzfristig' : 'mittelfristig', + + coverageLevel: + scores.controlCoverage <= 50 ? 'grundlegend' : + scores.controlCoverage <= 80 ? 'umfassend' : 'vollstaendig', + + urgency: + scores.criticalFindings > 0 ? 'dringend' : + scores.highFindings > 0 ? 'zeitnah' : 'planbar', + } +} + +/** + * Extrahiert NarrativeTagScores aus einem DraftContext. + * Falls Werte fehlen, werden sichere Defaults (konservativ) verwendet. + */ +export function extractScoresFromDraftContext(context: { + decisions: { + scores: { + risk_score: number + complexity_score: number + assurance_score: number + composite_score: number + } + } + constraints: { + riskFlags: Array<{ severity: string }> + } +}): NarrativeTagScores { + const { scores } = context.decisions + const riskFlags = context.constraints.riskFlags + + const criticalFindings = riskFlags.filter(f => f.severity === 'critical').length + const highFindings = riskFlags.filter(f => f.severity === 'high').length + + return { + overallRisk: scores.risk_score ?? 50, + maturityScore: scores.assurance_score ?? 50, + gapCount: riskFlags.length, + criticalGaps: criticalFindings, + controlCoverage: scores.assurance_score ?? 50, + criticalFindings, + highFindings, + } +} + +// ============================================================================ +// Serialization +// ============================================================================ + +/** + * Serialisiert NarrativeTags fuer den LLM-Prompt. + */ +export function narrativeTagsToPromptString(tags: NarrativeTags): string { + return [ + `- Risikoprofil: ${tags.riskSummary}`, + `- Reifegrad: ${tags.maturity}`, + `- Prioritaet: ${tags.priority}`, + `- Abdeckungsgrad: ${tags.coverageLevel}`, + `- Dringlichkeit: ${tags.urgency}`, + ].join('\n') +} + +/** + * Gibt die erlaubten Tag-Werte als flache Liste zurueck (fuer Validierung). + */ +export function getAllAllowedTagValues(): string[] { + return [ + 'niedrig', 'moderat', 'erhoht', + 'ausbaufahig', 'solide', 'hoch', + 'kurzfristig', 'mittelfristig', 'langfristig', + 'grundlegend', 'umfassend', 'vollstaendig', + 'planbar', 'zeitnah', 'dringend', + ] +} diff --git a/admin-compliance/lib/sdk/drafting-engine/prose-validator.ts b/admin-compliance/lib/sdk/drafting-engine/prose-validator.ts new file mode 100644 index 0000000..af819d8 --- /dev/null +++ b/admin-compliance/lib/sdk/drafting-engine/prose-validator.ts @@ -0,0 +1,485 @@ +/** + * Prose Validator + Repair Loop — Governance Layer + * + * Validiert LLM-generierte Prosa-Bloecke gegen das Regelwerk. + * Orchestriert den Repair-Loop (max 2 Versuche) mit Fallback. + * + * 12 Pruefregeln, davon 10 reparierbar und 2 Hard Aborts. + */ + +import type { NarrativeTags } from './narrative-tags' +import { getAllAllowedTagValues } from './narrative-tags' +import type { AllowedFacts } from './allowed-facts' +import { checkForDisallowedContent } from './allowed-facts' +import { checkStyleViolations, checkTerminologyUsage } from './terminology' +import type { SanitizedFacts } from './sanitizer' +import { isSanitized } from './sanitizer' + +// ============================================================================ +// Types +// ============================================================================ + +/** Strukturierter LLM-Output (Pflicht-Format) */ +export interface ProseBlockOutput { + blockId: string + blockType: 'introduction' | 'transition' | 'conclusion' | 'appreciation' + language: 'de' + text: string + + assertions: { + companyNameUsed: boolean + industryReferenced: boolean + structureReferenced: boolean + itLandscapeReferenced: boolean + narrativeTagsUsed: string[] + } + + forbiddenContentDetected: string[] +} + +/** Einzelner Validierungsfehler */ +export interface ProseValidationError { + rule: string + severity: 'error' | 'warning' + message: string + repairable: boolean +} + +/** Validierungsergebnis */ +export interface ProseValidatorResult { + valid: boolean + errors: ProseValidationError[] + repairable: boolean +} + +/** Repair-Loop Audit */ +export interface RepairAudit { + repairAttempts: number + validatorFailures: string[][] + repairSuccessful: boolean + fallbackUsed: boolean + fallbackReason?: string +} + +/** Word count limits per block type */ +const WORD_COUNT_LIMITS: Record = { + introduction: { min: 30, max: 200 }, + transition: { min: 10, max: 80 }, + conclusion: { min: 20, max: 150 }, + appreciation: { min: 15, max: 100 }, +} + +// ============================================================================ +// Prose Validator +// ============================================================================ + +/** + * Validiert einen ProseBlockOutput gegen alle 12 Regeln. + */ +export function validateProseBlock( + block: ProseBlockOutput, + facts: AllowedFacts | SanitizedFacts, + expectedTags: NarrativeTags +): ProseValidatorResult { + const errors: ProseValidationError[] = [] + + // Rule 1: JSON_VALID — wird extern geprueft (Parsing vor Aufruf) + // Wenn wir hier sind, ist JSON bereits valide + + // Rule 2: COMPANY_NAME_PRESENT + if (!block.text.includes(facts.companyName) && facts.companyName !== 'Unbekannt') { + errors.push({ + rule: 'COMPANY_NAME_PRESENT', + severity: 'error', + message: `Firmenname "${facts.companyName}" nicht im Text gefunden`, + repairable: true, + }) + } + + // Rule 3: INDUSTRY_REFERENCED + if (facts.industry && !block.text.toLowerCase().includes(facts.industry.toLowerCase())) { + errors.push({ + rule: 'INDUSTRY_REFERENCED', + severity: 'warning', + message: `Branche "${facts.industry}" nicht im Text referenziert`, + repairable: true, + }) + } + + // Rule 4: NO_NUMERIC_SCORES + if (/\d+\s*%/.test(block.text)) { + errors.push({ + rule: 'NO_NUMERIC_SCORES', + severity: 'error', + message: 'Prozentwerte im Text gefunden', + repairable: true, + }) + } + if (/score[:\s]*\d+/i.test(block.text)) { + errors.push({ + rule: 'NO_NUMERIC_SCORES', + severity: 'error', + message: 'Score-Werte im Text gefunden', + repairable: true, + }) + } + if (/\b(L1|L2|L3|L4)\b/.test(block.text)) { + errors.push({ + rule: 'NO_NUMERIC_SCORES', + severity: 'error', + message: 'Compliance-Level-Bezeichnungen (L1-L4) im Text gefunden', + repairable: true, + }) + } + + // Rule 5: NO_DISALLOWED_TOPICS + const disallowedViolations = checkForDisallowedContent(block.text) + for (const violation of disallowedViolations) { + errors.push({ + rule: 'NO_DISALLOWED_TOPICS', + severity: 'error', + message: violation, + repairable: true, + }) + } + + // Rule 6: WORD_COUNT_IN_RANGE + const wordCount = block.text.split(/\s+/).filter(Boolean).length + const limits = WORD_COUNT_LIMITS[block.blockType] + if (limits) { + if (wordCount < limits.min) { + errors.push({ + rule: 'WORD_COUNT_IN_RANGE', + severity: 'warning', + message: `Wortanzahl ${wordCount} unter Minimum ${limits.min} fuer ${block.blockType}`, + repairable: true, + }) + } + if (wordCount > limits.max) { + errors.push({ + rule: 'WORD_COUNT_IN_RANGE', + severity: 'error', + message: `Wortanzahl ${wordCount} ueber Maximum ${limits.max} fuer ${block.blockType}`, + repairable: true, + }) + } + } + + // Rule 7: NO_DIRECT_ADDRESS + if (/\b(Sie|Ihr|Ihnen|Ihrem|Ihrer)\b/.test(block.text)) { + errors.push({ + rule: 'NO_DIRECT_ADDRESS', + severity: 'error', + message: 'Direkte Ansprache (Sie/Ihr) gefunden', + repairable: true, + }) + } + + // Rule 8: NARRATIVE_TAGS_CONSISTENT + const allowedTags = getAllAllowedTagValues() + if (block.assertions.narrativeTagsUsed) { + for (const tag of block.assertions.narrativeTagsUsed) { + if (!allowedTags.includes(tag)) { + errors.push({ + rule: 'NARRATIVE_TAGS_CONSISTENT', + severity: 'error', + message: `Unbekannter Narrative Tag "${tag}" in assertions`, + repairable: true, + }) + } + } + } + // Pruefen ob Text Tags enthaelt die nicht zu den erwarteten gehoeren + const expectedTagValues = Object.values(expectedTags) + const allTagValues = getAllAllowedTagValues() + for (const tagValue of allTagValues) { + if (block.text.includes(tagValue) && !expectedTagValues.includes(tagValue)) { + errors.push({ + rule: 'NARRATIVE_TAGS_CONSISTENT', + severity: 'error', + message: `Tag "${tagValue}" im Text, aber nicht im erwarteten Tag-Set`, + repairable: true, + }) + } + } + + // Rule 9: TERMINOLOGY_CORRECT + const termViolations = checkTerminologyUsage(block.text) + for (const warning of termViolations) { + errors.push({ + rule: 'TERMINOLOGY_CORRECT', + severity: 'warning', + message: warning, + repairable: true, + }) + } + + // Rule 10: Style violations + const styleViolations = checkStyleViolations(block.text) + for (const violation of styleViolations) { + errors.push({ + rule: 'STYLE_VIOLATION', + severity: 'warning', + message: violation, + repairable: true, + }) + } + + // Rule 11: SANITIZATION_PASSED (Hard Abort) + if ('__sanitized' in facts && !isSanitized(facts)) { + errors.push({ + rule: 'SANITIZATION_PASSED', + severity: 'error', + message: 'Sanitization-Flag gesetzt aber nicht valide', + repairable: false, + }) + } + + // Rule 12: Self-reported forbidden content + if (block.forbiddenContentDetected && block.forbiddenContentDetected.length > 0) { + errors.push({ + rule: 'SELF_REPORTED_FORBIDDEN', + severity: 'error', + message: `LLM meldet verbotene Inhalte: ${block.forbiddenContentDetected.join(', ')}`, + repairable: true, + }) + } + + const hasHardAbort = errors.some(e => !e.repairable) + const hasErrors = errors.some(e => e.severity === 'error') + + return { + valid: !hasErrors, + errors, + repairable: hasErrors && !hasHardAbort, + } +} + +// ============================================================================ +// JSON Parsing +// ============================================================================ + +/** + * Parst und validiert LLM-Output als ProseBlockOutput. + * Gibt null zurueck wenn JSON nicht parsebar ist. + */ +export function parseProseBlockOutput(rawContent: string): ProseBlockOutput | null { + try { + const parsed = JSON.parse(rawContent) + + // Pflichtfelder pruefen + if ( + typeof parsed.blockId !== 'string' || + typeof parsed.text !== 'string' || + !['introduction', 'transition', 'conclusion', 'appreciation'].includes(parsed.blockType) + ) { + return null + } + + return { + blockId: parsed.blockId, + blockType: parsed.blockType, + language: parsed.language || 'de', + text: parsed.text, + assertions: { + companyNameUsed: parsed.assertions?.companyNameUsed ?? false, + industryReferenced: parsed.assertions?.industryReferenced ?? false, + structureReferenced: parsed.assertions?.structureReferenced ?? false, + itLandscapeReferenced: parsed.assertions?.itLandscapeReferenced ?? false, + narrativeTagsUsed: parsed.assertions?.narrativeTagsUsed ?? [], + }, + forbiddenContentDetected: parsed.forbiddenContentDetected ?? [], + } + } catch { + return null + } +} + +// ============================================================================ +// Repair Prompt Builder +// ============================================================================ + +/** + * Baut den Repair-Prompt fuer einen fehlgeschlagenen Block. + */ +export function buildRepairPrompt( + originalBlock: ProseBlockOutput, + validationErrors: ProseValidationError[] +): string { + const errorList = validationErrors + .filter(e => e.severity === 'error') + .map(e => `- ${e.rule}: ${e.message}`) + .join('\n') + + return `Der vorherige Text enthielt Fehler. Ueberarbeite ihn unter Beibehaltung der Aussage. + +FEHLER: +${errorList} + +REGELN: +- Entferne alle unerlaubten Inhalte +- Behalte den Firmenkontext bei +- Erzeuge ausschliesslich JSON im vorgegebenen Format +- Aendere KEINE Fakten, ergaenze KEINE neuen Informationen +- Verwende KEINE direkte Ansprache (Sie/Ihr) +- Verwende KEINE konkreten Prozentwerte oder Scores + +ORIGINALTEXT: +${JSON.stringify(originalBlock, null, 2)}` +} + +// ============================================================================ +// Fallback Templates +// ============================================================================ + +const FALLBACK_TEMPLATES: Record = { + introduction: 'Die {{companyName}} dokumentiert im Folgenden die {{documentType}}-relevanten Massnahmen und Bewertungen. Die nachstehenden Ausfuehrungen basieren auf der aktuellen Analyse der organisatorischen und technischen Gegebenheiten.', + transition: 'Auf Grundlage der vorstehenden Daten ergeben sich die folgenden Detailbewertungen.', + conclusion: 'Die {{companyName}} verfuegt ueber die dokumentierten Massnahmen und Strukturen. Die Einhaltung der regulatorischen Anforderungen wird fortlaufend ueberprueft und angepasst.', + appreciation: 'Die bestehende Organisationsstruktur der {{companyName}} bildet eine {{maturity}} Grundlage fuer die nachfolgend dokumentierten Massnahmen.', +} + +/** + * Erzeugt einen Fallback-Block wenn der Repair-Loop fehlschlaegt. + */ +export function buildFallbackBlock( + blockId: string, + blockType: ProseBlockOutput['blockType'], + facts: AllowedFacts, + documentType?: string +): ProseBlockOutput { + let text = FALLBACK_TEMPLATES[blockType] + .replace(/\{\{companyName\}\}/g, facts.companyName) + .replace(/\{\{maturity\}\}/g, facts.narrativeTags.maturity) + .replace(/\{\{documentType\}\}/g, documentType || 'Compliance') + + return { + blockId, + blockType, + language: 'de', + text, + assertions: { + companyNameUsed: true, + industryReferenced: false, + structureReferenced: false, + itLandscapeReferenced: false, + narrativeTagsUsed: blockType === 'appreciation' ? ['maturity'] : [], + }, + forbiddenContentDetected: [], + } +} + +// ============================================================================ +// Repair Loop Orchestrator +// ============================================================================ + +/** Callback fuer LLM-Aufruf (wird von der Route injiziert) */ +export type LLMCallFn = (prompt: string) => Promise + +/** + * Orchestriert den Repair-Loop fuer einen einzelnen Prosa-Block. + * + * 1. Parse + Validate + * 2. Bei Fehler: Repair-Prompt → LLM → Parse + Validate (max 2x) + * 3. Bei weiterem Fehler: Fallback Template + * + * @returns Validierter ProseBlockOutput + RepairAudit + */ +export async function executeRepairLoop( + rawLLMOutput: string, + facts: AllowedFacts | SanitizedFacts, + expectedTags: NarrativeTags, + blockId: string, + blockType: ProseBlockOutput['blockType'], + llmCall: LLMCallFn, + documentType?: string, + maxRepairAttempts = 2 +): Promise<{ block: ProseBlockOutput; audit: RepairAudit }> { + const audit: RepairAudit = { + repairAttempts: 0, + validatorFailures: [], + repairSuccessful: false, + fallbackUsed: false, + } + + // Versuch 0: Original-Output parsen + validieren + let parsed = parseProseBlockOutput(rawLLMOutput) + + if (!parsed) { + // JSON invalid → Regeneration zaehlt als Repair-Versuch + audit.validatorFailures.push(['JSON_VALID: LLM-Output konnte nicht als JSON geparst werden']) + audit.repairAttempts++ + + if (audit.repairAttempts <= maxRepairAttempts) { + const repairPrompt = `Der vorherige Output war kein valides JSON. Erzeuge ausschliesslich ein JSON-Objekt mit den Feldern: blockId, blockType, language, text, assertions, forbiddenContentDetected.\n\nOriginal-Output:\n${rawLLMOutput.slice(0, 500)}` + try { + const repaired = await llmCall(repairPrompt) + parsed = parseProseBlockOutput(repaired) + } catch { + // LLM-Fehler → weiter zum Fallback + } + } + } + + if (!parsed) { + audit.fallbackUsed = true + audit.fallbackReason = 'JSON konnte nach Repair nicht geparst werden' + return { + block: buildFallbackBlock(blockId, blockType, facts, documentType), + audit, + } + } + + // Validierungs-Schleife + for (let attempt = audit.repairAttempts; attempt <= maxRepairAttempts; attempt++) { + const result = validateProseBlock(parsed, facts, expectedTags) + + if (result.valid) { + audit.repairSuccessful = attempt === 0 ? true : true + return { block: parsed, audit } + } + + // Hard Abort? → Fallback sofort + if (!result.repairable) { + audit.fallbackUsed = true + audit.fallbackReason = `Hard Abort: ${result.errors.filter(e => !e.repairable).map(e => e.rule).join(', ')}` + audit.validatorFailures.push(result.errors.map(e => `${e.rule}: ${e.message}`)) + return { + block: buildFallbackBlock(blockId, blockType, facts, documentType), + audit, + } + } + + // Fehler protokollieren + audit.validatorFailures.push(result.errors.map(e => `${e.rule}: ${e.message}`)) + + // Noch Repair-Versuche uebrig? + if (attempt >= maxRepairAttempts) { + break + } + + // Repair-Prompt senden + audit.repairAttempts++ + try { + const repairPrompt = buildRepairPrompt(parsed, result.errors) + const repairedOutput = await llmCall(repairPrompt) + const repairedParsed = parseProseBlockOutput(repairedOutput) + if (!repairedParsed) { + // Parsing fehlgeschlagen nach Repair + continue + } + parsed = repairedParsed + } catch { + // LLM-Fehler → naechster Versuch oder Fallback + continue + } + } + + // Alle Versuche erschoepft → Fallback + audit.fallbackUsed = true + audit.fallbackReason = `${maxRepairAttempts} Repair-Versuche erschoepft` + return { + block: buildFallbackBlock(blockId, blockType, facts, documentType), + audit, + } +} diff --git a/admin-compliance/lib/sdk/drafting-engine/sanitizer.ts b/admin-compliance/lib/sdk/drafting-engine/sanitizer.ts new file mode 100644 index 0000000..ccdad14 --- /dev/null +++ b/admin-compliance/lib/sdk/drafting-engine/sanitizer.ts @@ -0,0 +1,298 @@ +/** + * PII Sanitizer — Bereinigt Kontextdaten vor LLM-Aufruf + * + * Entfernt personenbezogene Daten (PII) aus AllowedFacts + * bevor sie an das LLM weitergegeben werden. + * + * Bei Fehler: Hard Abort — kein LLM-Aufruf ohne erfolgreiche Sanitization. + */ + +import type { AllowedFacts } from './allowed-facts' + +// ============================================================================ +// Types +// ============================================================================ + +/** Bereinigtes Faktenbudget (PII-frei) */ +export type SanitizedFacts = AllowedFacts & { + __sanitized: true +} + +/** Audit-Protokoll der Sanitization */ +export interface SanitizationAudit { + sanitizationApplied: boolean + redactedFieldsCount: number + redactedFieldNames: string[] +} + +/** Ergebnis der Sanitization */ +export interface SanitizationResult { + facts: SanitizedFacts + audit: SanitizationAudit +} + +/** Sanitization-Fehler (loest Hard Abort aus) */ +export class SanitizationError extends Error { + constructor( + message: string, + public readonly field: string, + public readonly reason: string + ) { + super(message) + this.name = 'SanitizationError' + } +} + +// ============================================================================ +// PII Detection Patterns +// ============================================================================ + +const PII_PATTERNS = { + email: /[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}/g, + phone: /(\+?\d{1,3}[-.\s]?)?\(?\d{2,5}\)?[-.\s]?\d{3,10}/g, + ipAddress: /\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b/g, + internalId: /\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b/gi, + apiKey: /\b(sk-|pk-|api[_-]?key[_-]?)[a-zA-Z0-9]{20,}\b/gi, +} as const + +// ============================================================================ +// Sanitizer +// ============================================================================ + +/** + * Bereinigt AllowedFacts von PII vor dem LLM-Aufruf. + * + * @throws {SanitizationError} Wenn ein Feld nicht bereinigt werden kann + */ +export function sanitizeAllowedFacts(facts: AllowedFacts): SanitizationResult { + const redactedFields: string[] = [] + + // Kopie erstellen + const sanitized: AllowedFacts = { + ...facts, + specialFeatures: [...facts.specialFeatures], + triggeredRegulations: [...facts.triggeredRegulations], + primaryUseCases: [...facts.primaryUseCases], + narrativeTags: { ...facts.narrativeTags }, + } + + // Firmenname: erlaubt (wird benoetigt), aber PII darin pruefen + sanitized.companyName = sanitizeString(facts.companyName, 'companyName', redactedFields) + + // Rechtsform: erlaubt, kurzer Wert + sanitized.legalForm = sanitizeString(facts.legalForm, 'legalForm', redactedFields) + + // Branche: erlaubt + sanitized.industry = sanitizeString(facts.industry, 'industry', redactedFields) + + // Standort: erlaubt (Stadt/Region), aber keine Strasse/Hausnummer + sanitized.location = sanitizeAddress(facts.location, 'location', redactedFields) + + // Mitarbeiterzahl: erlaubt (kein PII) + // employeeCount bleibt unveraendert + + // Teamstruktur: erlaubt (generisch) + sanitized.teamStructure = sanitizeString(facts.teamStructure, 'teamStructure', redactedFields) + + // IT-Landschaft: erlaubt (generisch) + sanitized.itLandscape = sanitizeString(facts.itLandscape, 'itLandscape', redactedFields) + + // Besonderheiten: pruefen + sanitized.specialFeatures = facts.specialFeatures.map((f, i) => + sanitizeString(f, `specialFeatures[${i}]`, redactedFields) + ) + + // Regulierungen: erlaubt (generisch) + sanitized.triggeredRegulations = facts.triggeredRegulations.map((r, i) => + sanitizeString(r, `triggeredRegulations[${i}]`, redactedFields) + ) + + // Use Cases: pruefen + sanitized.primaryUseCases = facts.primaryUseCases.map((uc, i) => + sanitizeString(uc, `primaryUseCases[${i}]`, redactedFields) + ) + + // Narrative Tags: deterministisch, kein PII moeglich + // Bleiben unveraendert + + return { + facts: { ...sanitized, __sanitized: true } as SanitizedFacts, + audit: { + sanitizationApplied: true, + redactedFieldsCount: redactedFields.length, + redactedFieldNames: redactedFields, + }, + } +} + +/** + * Prueft ob ein SanitizedFacts-Objekt tatsaechlich bereinigt wurde. + */ +export function isSanitized(facts: unknown): facts is SanitizedFacts { + return ( + typeof facts === 'object' && + facts !== null && + '__sanitized' in facts && + (facts as SanitizedFacts).__sanitized === true + ) +} + +// ============================================================================ +// Private Helpers +// ============================================================================ + +/** + * Bereinigt einen String-Wert von PII. + * Gibt den bereinigten String zurueck und fuegt redacted Fields hinzu. + */ +function sanitizeString( + value: string, + fieldName: string, + redactedFields: string[] +): string { + if (!value) return value + + let result = value + let wasRedacted = false + + // E-Mail-Adressen entfernen + if (PII_PATTERNS.email.test(result)) { + result = result.replace(PII_PATTERNS.email, '[REDACTED]') + wasRedacted = true + } + // Reset regex lastIndex + PII_PATTERNS.email.lastIndex = 0 + + // Telefonnummern entfernen (nur wenn >= 6 Ziffern) + const phoneMatches = result.match(PII_PATTERNS.phone) + if (phoneMatches) { + for (const match of phoneMatches) { + if (match.replace(/\D/g, '').length >= 6) { + result = result.replace(match, '[REDACTED]') + wasRedacted = true + } + } + } + PII_PATTERNS.phone.lastIndex = 0 + + // IP-Adressen entfernen + if (PII_PATTERNS.ipAddress.test(result)) { + result = result.replace(PII_PATTERNS.ipAddress, '[REDACTED]') + wasRedacted = true + } + PII_PATTERNS.ipAddress.lastIndex = 0 + + // Interne IDs (UUIDs) entfernen + if (PII_PATTERNS.internalId.test(result)) { + result = result.replace(PII_PATTERNS.internalId, '[REDACTED]') + wasRedacted = true + } + PII_PATTERNS.internalId.lastIndex = 0 + + // API Keys entfernen + if (PII_PATTERNS.apiKey.test(result)) { + result = result.replace(PII_PATTERNS.apiKey, '[REDACTED]') + wasRedacted = true + } + PII_PATTERNS.apiKey.lastIndex = 0 + + if (wasRedacted) { + redactedFields.push(fieldName) + } + + return result +} + +/** + * Bereinigt Adress-Felder: behaelt Stadt/Region, entfernt Strasse/Hausnummer. + */ +function sanitizeAddress( + value: string, + fieldName: string, + redactedFields: string[] +): string { + if (!value) return value + + // Zuerst generische PII-Bereinigung + let result = sanitizeString(value, fieldName, redactedFields) + + // Strasse + Hausnummer Pattern (deutsch) + const streetPattern = /\b[A-ZÄÖÜ][a-zäöüß]+(?:straße|str\.|weg|gasse|platz|allee|ring|damm)\s*\d+[a-z]?\b/gi + if (streetPattern.test(result)) { + result = result.replace(streetPattern, '') + if (!redactedFields.includes(fieldName)) { + redactedFields.push(fieldName) + } + } + + // PLZ-Pattern (5-stellig deutsch) + const plzPattern = /\b\d{5}\s+/g + if (plzPattern.test(result)) { + result = result.replace(plzPattern, '') + if (!redactedFields.includes(fieldName)) { + redactedFields.push(fieldName) + } + } + + return result.trim() +} + +/** + * Validiert das gesamte SanitizedFacts-Objekt auf verbleibende PII. + * Gibt Warnungen zurueck wenn doch noch PII gefunden wird. + */ +export function validateNoRemainingPII(facts: SanitizedFacts): string[] { + const warnings: string[] = [] + const allValues = extractAllStringValues(facts) + + for (const { path, value } of allValues) { + if (path === '__sanitized') continue + + PII_PATTERNS.email.lastIndex = 0 + if (PII_PATTERNS.email.test(value)) { + warnings.push(`Verbleibende E-Mail in ${path}`) + } + + PII_PATTERNS.ipAddress.lastIndex = 0 + if (PII_PATTERNS.ipAddress.test(value)) { + warnings.push(`Verbleibende IP-Adresse in ${path}`) + } + + PII_PATTERNS.apiKey.lastIndex = 0 + if (PII_PATTERNS.apiKey.test(value)) { + warnings.push(`Verbleibender API-Key in ${path}`) + } + } + + return warnings +} + +/** + * Extrahiert alle String-Werte aus einem Objekt (rekursiv). + */ +function extractAllStringValues( + obj: Record, + prefix = '' +): Array<{ path: string; value: string }> { + const results: Array<{ path: string; value: string }> = [] + + for (const [key, val] of Object.entries(obj)) { + const path = prefix ? `${prefix}.${key}` : key + + if (typeof val === 'string') { + results.push({ path, value: val }) + } else if (Array.isArray(val)) { + for (let i = 0; i < val.length; i++) { + if (typeof val[i] === 'string') { + results.push({ path: `${path}[${i}]`, value: val[i] }) + } else if (typeof val[i] === 'object' && val[i] !== null) { + results.push(...extractAllStringValues(val[i] as Record, `${path}[${i}]`)) + } + } + } else if (typeof val === 'object' && val !== null) { + results.push(...extractAllStringValues(val as Record, path)) + } + } + + return results +} diff --git a/admin-compliance/lib/sdk/drafting-engine/terminology.ts b/admin-compliance/lib/sdk/drafting-engine/terminology.ts new file mode 100644 index 0000000..6877f4f --- /dev/null +++ b/admin-compliance/lib/sdk/drafting-engine/terminology.ts @@ -0,0 +1,184 @@ +/** + * Terminology Guide & Style Contract — Konsistente Fachbegriffe + * + * Stellt sicher, dass alle Prosa-Bloecke eines Dokuments + * dieselben Fachbegriffe und denselben Schreibstil verwenden. + * + * 100% deterministisch. + */ + +// ============================================================================ +// Terminology Guide +// ============================================================================ + +export interface TerminologyGuide { + /** DSGVO-Begriffe */ + dsgvo: Record + /** TOM-Begriffe */ + tom: Record + /** Allgemeine Compliance-Begriffe */ + general: Record +} + +export const DEFAULT_TERMINOLOGY: TerminologyGuide = { + dsgvo: { + controller: 'Verantwortlicher', + processor: 'Auftragsverarbeiter', + data_subject: 'betroffene Person', + processing: 'Verarbeitung', + personal_data: 'personenbezogene Daten', + consent: 'Einwilligung', + dpia: 'Datenschutz-Folgenabschaetzung (DSFA)', + legitimate_interest: 'berechtigtes Interesse', + data_breach: 'Verletzung des Schutzes personenbezogener Daten', + dpo: 'Datenschutzbeauftragter (DSB)', + supervisory_authority: 'Aufsichtsbehoerde', + ropa: 'Verzeichnis von Verarbeitungstaetigkeiten (VVT)', + retention_period: 'Aufbewahrungsfrist', + erasure: 'Loeschung', + restriction: 'Einschraenkung der Verarbeitung', + portability: 'Datenportabilitaet', + third_country: 'Drittland', + adequacy_decision: 'Angemessenheitsbeschluss', + scc: 'Standardvertragsklauseln (SCC)', + }, + tom: { + access_control: 'Zutrittskontrolle', + access_management: 'Zugangskontrolle', + authorization: 'Zugriffskontrolle', + encryption: 'Verschluesselung', + pseudonymization: 'Pseudonymisierung', + availability: 'Verfuegbarkeitskontrolle', + resilience: 'Belastbarkeit', + recoverability: 'Wiederherstellbarkeit', + audit_logging: 'Protokollierung', + separation: 'Trennungsgebot', + input_control: 'Eingabekontrolle', + transport_control: 'Weitergabekontrolle', + order_control: 'Auftragskontrolle', + }, + general: { + risk_assessment: 'Risikobewertung', + audit_trail: 'Pruefpfad', + compliance_level: 'Compliance-Tiefe', + gap_analysis: 'Lueckenanalyse', + remediation: 'Massnahmenplan', + incident_response: 'Vorfallreaktion', + business_continuity: 'Geschaeftskontinuitaet', + vendor_management: 'Dienstleistermanagement', + awareness_training: 'Sensibilisierungsschulung', + }, +} + +// ============================================================================ +// Style Contract +// ============================================================================ + +export interface StyleContract { + /** Anrede-Stil */ + addressing: '3rd_person_company' + /** Tonalitaet */ + tone: 'formal_legal_plain' + /** Verbotene Formulierungen */ + forbid: string[] +} + +export const DEFAULT_STYLE_CONTRACT: StyleContract = { + addressing: '3rd_person_company', + tone: 'formal_legal_plain', + forbid: [ + 'Denglisch', + 'Marketing-Sprache', + 'Superlative', + 'Direkte Ansprache', + 'Umgangssprache', + 'Konjunktiv-Ketten', + ], +} + +/** Konkrete Regex-Muster fuer verbotene Formulierungen */ +export const STYLE_VIOLATION_PATTERNS: Array<{ name: string; pattern: RegExp }> = [ + { name: 'Direkte Ansprache', pattern: /\b(Sie|Ihr|Ihnen|Ihrem|Ihrer)\b/ }, + { name: 'Superlative', pattern: /\b(bestmoeglich|hoechstmoeglich|optimal|perfekt|einzigartig)\b/i }, + { name: 'Marketing-Sprache', pattern: /\b(revolutionaer|bahnbrechend|innovativ|fuehrend|erstklassig)\b/i }, + { name: 'Umgangssprache', pattern: /\b(super|toll|mega|krass|cool|easy)\b/i }, + { name: 'Denglisch', pattern: /\b(State of the Art|Best Practice|Compliance Journey|Data Driven)\b/i }, +] + +// ============================================================================ +// Serialization +// ============================================================================ + +/** + * Serialisiert den Terminology Guide fuer den LLM-Prompt. + * Gibt nur die haeufigsten Begriffe aus (Token-Budget). + */ +export function terminologyToPromptString(guide: TerminologyGuide = DEFAULT_TERMINOLOGY): string { + const keyTerms = [ + ...Object.entries(guide.dsgvo).slice(0, 10), + ...Object.entries(guide.tom).slice(0, 6), + ...Object.entries(guide.general).slice(0, 4), + ] + return keyTerms.map(([key, value]) => ` ${key}: "${value}"`).join('\n') +} + +/** + * Serialisiert den Style Contract fuer den LLM-Prompt. + */ +export function styleContractToPromptString(style: StyleContract = DEFAULT_STYLE_CONTRACT): string { + return [ + `Anrede: Dritte Person ("Die [Firmenname]...", NICHT "Sie...")`, + `Ton: Professionell, juristisch korrekt, aber verstaendlich`, + `Verboten: ${style.forbid.join(', ')}`, + ].join('\n') +} + +// ============================================================================ +// Validation +// ============================================================================ + +/** + * Prueft einen Text auf Style-Verstoesse. + * Gibt eine Liste der gefundenen Verstoesse zurueck. + */ +export function checkStyleViolations(text: string): string[] { + const violations: string[] = [] + for (const { name, pattern } of STYLE_VIOLATION_PATTERNS) { + if (pattern.test(text)) { + violations.push(`Style-Verstoss: ${name}`) + } + } + return violations +} + +/** + * Prueft ob die Terminologie korrekt verwendet wird. + * Gibt Warnungen zurueck wenn falsche Begriffe erkannt werden. + */ +export function checkTerminologyUsage( + text: string, + guide: TerminologyGuide = DEFAULT_TERMINOLOGY +): string[] { + const warnings: string[] = [] + const lower = text.toLowerCase() + + // Prüfe ob englische Begriffe statt deutscher verwendet werden + const termChecks: Array<{ wrong: string; correct: string }> = [ + { wrong: 'data controller', correct: guide.dsgvo.controller }, + { wrong: 'data processor', correct: guide.dsgvo.processor }, + { wrong: 'data subject', correct: guide.dsgvo.data_subject }, + { wrong: 'personal data', correct: guide.dsgvo.personal_data }, + { wrong: 'data breach', correct: guide.dsgvo.data_breach }, + { wrong: 'encryption', correct: guide.tom.encryption }, + { wrong: 'pseudonymization', correct: guide.tom.pseudonymization }, + { wrong: 'risk assessment', correct: guide.general.risk_assessment }, + ] + + for (const { wrong, correct } of termChecks) { + if (lower.includes(wrong.toLowerCase())) { + warnings.push(`Englischer Begriff "${wrong}" gefunden — verwende "${correct}"`) + } + } + + return warnings +} diff --git a/ai-compliance-sdk/cmd/server/main.go b/ai-compliance-sdk/cmd/server/main.go index 2c98b9b..54923ef 100644 --- a/ai-compliance-sdk/cmd/server/main.go +++ b/ai-compliance-sdk/cmd/server/main.go @@ -20,15 +20,10 @@ import ( "github.com/breakpilot/ai-compliance-sdk/internal/roadmap" "github.com/breakpilot/ai-compliance-sdk/internal/ucca" "github.com/breakpilot/ai-compliance-sdk/internal/whistleblower" - "github.com/breakpilot/ai-compliance-sdk/internal/dsb" - "github.com/breakpilot/ai-compliance-sdk/internal/multitenant" - "github.com/breakpilot/ai-compliance-sdk/internal/reporting" - "github.com/breakpilot/ai-compliance-sdk/internal/sso" + "github.com/breakpilot/ai-compliance-sdk/internal/iace" "github.com/breakpilot/ai-compliance-sdk/internal/vendor" "github.com/breakpilot/ai-compliance-sdk/internal/workshop" "github.com/breakpilot/ai-compliance-sdk/internal/portfolio" - "github.com/breakpilot/ai-compliance-sdk/internal/gci" - "github.com/breakpilot/ai-compliance-sdk/internal/training" "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" "github.com/jackc/pgx/v5/pgxpool" @@ -73,10 +68,7 @@ func main() { whistleblowerStore := whistleblower.NewStore(pool) incidentStore := incidents.NewStore(pool) vendorStore := vendor.NewStore(pool) - reportingStore := reporting.NewStore(pool, dsgvoStore, vendorStore, incidentStore, whistleblowerStore, academyStore) - ssoStore := sso.NewStore(pool) - multitenantStore := multitenant.NewStore(pool, rbacStore, reportingStore) - dsbStore := dsb.NewStore(pool, reportingStore) + iaceStore := iace.NewStore(pool) // Initialize services rbacService := rbac.NewService(rbacStore) @@ -120,24 +112,7 @@ func main() { whistleblowerHandlers := handlers.NewWhistleblowerHandlers(whistleblowerStore) incidentHandlers := handlers.NewIncidentHandlers(incidentStore) vendorHandlers := handlers.NewVendorHandlers(vendorStore) - reportingHandlers := handlers.NewReportingHandlers(reportingStore) - ssoHandlers := handlers.NewSSOHandlers(ssoStore, cfg.JWTSecret) - multitenantHandlers := handlers.NewMultiTenantHandlers(multitenantStore, rbacStore) - industryHandlers := handlers.NewIndustryHandlers() - dsbHandlers := handlers.NewDSBHandlers(dsbStore) - - // Initialize GCI engine and handlers - gciEngine := gci.NewEngine() - gciHandlers := handlers.NewGCIHandlers(gciEngine) - - // Initialize Training Engine - trainingStore := training.NewStore(pool) - ttsClient := training.NewTTSClient(cfg.TTSServiceURL) - contentGenerator := training.NewContentGenerator(providerRegistry, piiDetector, trainingStore, ttsClient) - trainingHandlers := handlers.NewTrainingHandlers(trainingStore, contentGenerator) - - // Initialize RAG handlers - ragHandlers := handlers.NewRAGHandlers() + iaceHandler := handlers.NewIACEHandler(iaceStore) // Initialize middleware rbacMiddleware := rbac.NewMiddleware(rbacService, policyEngine) @@ -494,7 +469,6 @@ func main() { // Certificates academyRoutes.GET("/certificates/:id", academyHandlers.GetCertificate) - academyRoutes.GET("/certificates/:id/pdf", academyHandlers.DownloadCertificatePDF) academyRoutes.POST("/enrollments/:id/certificate", academyHandlers.GenerateCertificate) // Quiz @@ -600,159 +574,73 @@ func main() { vendorRoutes.GET("/stats", vendorHandlers.GetStatistics) } - // Reporting routes - Executive Compliance Reporting Dashboard - reportingRoutes := v1.Group("/reporting") + // IACE routes - Industrial AI Compliance Engine (CE-Risikobeurteilung SW/FW/KI) + iaceRoutes := v1.Group("/iace") { - reportingRoutes.GET("/executive", reportingHandlers.GetExecutiveReport) - reportingRoutes.GET("/score", reportingHandlers.GetComplianceScore) - reportingRoutes.GET("/deadlines", reportingHandlers.GetUpcomingDeadlines) - reportingRoutes.GET("/risks", reportingHandlers.GetRiskOverview) - } + // Hazard Library (project-independent) + iaceRoutes.GET("/hazard-library", iaceHandler.ListHazardLibrary) - // SSO routes - Single Sign-On (SAML/OIDC) - ssoRoutes := v1.Group("/sso") - { - // Config CRUD - ssoRoutes.POST("/configs", ssoHandlers.CreateConfig) - ssoRoutes.GET("/configs", ssoHandlers.ListConfigs) - ssoRoutes.GET("/configs/:id", ssoHandlers.GetConfig) - ssoRoutes.PUT("/configs/:id", ssoHandlers.UpdateConfig) - ssoRoutes.DELETE("/configs/:id", ssoHandlers.DeleteConfig) + // Project Management + iaceRoutes.POST("/projects", iaceHandler.CreateProject) + iaceRoutes.GET("/projects", iaceHandler.ListProjects) + iaceRoutes.GET("/projects/:id", iaceHandler.GetProject) + iaceRoutes.PUT("/projects/:id", iaceHandler.UpdateProject) + iaceRoutes.DELETE("/projects/:id", iaceHandler.ArchiveProject) - // SSO Users - ssoRoutes.GET("/users", ssoHandlers.ListUsers) + // Onboarding + iaceRoutes.POST("/projects/:id/init-from-profile", iaceHandler.InitFromProfile) + iaceRoutes.POST("/projects/:id/completeness-check", iaceHandler.CheckCompleteness) - // OIDC Flow - ssoRoutes.GET("/oidc/login", ssoHandlers.InitiateOIDCLogin) - ssoRoutes.GET("/oidc/callback", ssoHandlers.HandleOIDCCallback) - } + // Components + iaceRoutes.POST("/projects/:id/components", iaceHandler.CreateComponent) + iaceRoutes.GET("/projects/:id/components", iaceHandler.ListComponents) + iaceRoutes.PUT("/projects/:id/components/:cid", iaceHandler.UpdateComponent) + iaceRoutes.DELETE("/projects/:id/components/:cid", iaceHandler.DeleteComponent) - // Multi-Tenant Administration routes - mtRoutes := v1.Group("/multi-tenant") - { - mtRoutes.GET("/overview", multitenantHandlers.GetOverview) - mtRoutes.POST("/tenants", multitenantHandlers.CreateTenant) - mtRoutes.GET("/tenants/:id", multitenantHandlers.GetTenantDetail) - mtRoutes.PUT("/tenants/:id", multitenantHandlers.UpdateTenant) - mtRoutes.GET("/tenants/:id/namespaces", multitenantHandlers.ListNamespaces) - mtRoutes.POST("/tenants/:id/namespaces", multitenantHandlers.CreateNamespace) - mtRoutes.POST("/switch", multitenantHandlers.SwitchTenant) - } + // Regulatory Classification + iaceRoutes.POST("/projects/:id/classify", iaceHandler.Classify) + iaceRoutes.GET("/projects/:id/classifications", iaceHandler.GetClassifications) + iaceRoutes.POST("/projects/:id/classify/:regulation", iaceHandler.ClassifySingle) - // Industry-Specific Templates routes (Phase 3.3) - industryRoutes := v1.Group("/industry/templates") - { - industryRoutes.GET("", industryHandlers.ListIndustries) - industryRoutes.GET("/:slug", industryHandlers.GetIndustry) - industryRoutes.GET("/:slug/vvt", industryHandlers.GetVVTTemplates) - industryRoutes.GET("/:slug/tom", industryHandlers.GetTOMRecommendations) - industryRoutes.GET("/:slug/risks", industryHandlers.GetRiskScenarios) - } + // Hazards + iaceRoutes.POST("/projects/:id/hazards", iaceHandler.CreateHazard) + iaceRoutes.GET("/projects/:id/hazards", iaceHandler.ListHazards) + iaceRoutes.PUT("/projects/:id/hazards/:hid", iaceHandler.UpdateHazard) + iaceRoutes.POST("/projects/:id/hazards/suggest", iaceHandler.SuggestHazards) - // DSB-as-a-Service Portal routes (Phase 3.4) - dsbRoutes := v1.Group("/dsb") - { - dsbRoutes.GET("/dashboard", dsbHandlers.GetDashboard) - dsbRoutes.POST("/assignments", dsbHandlers.CreateAssignment) - dsbRoutes.GET("/assignments", dsbHandlers.ListAssignments) - dsbRoutes.GET("/assignments/:id", dsbHandlers.GetAssignment) - dsbRoutes.PUT("/assignments/:id", dsbHandlers.UpdateAssignment) - dsbRoutes.POST("/assignments/:id/hours", dsbHandlers.CreateHourEntry) - dsbRoutes.GET("/assignments/:id/hours", dsbHandlers.ListHours) - dsbRoutes.GET("/assignments/:id/hours/summary", dsbHandlers.GetHoursSummary) - dsbRoutes.POST("/assignments/:id/tasks", dsbHandlers.CreateTask) - dsbRoutes.GET("/assignments/:id/tasks", dsbHandlers.ListTasks) - dsbRoutes.PUT("/tasks/:taskId", dsbHandlers.UpdateTask) - dsbRoutes.POST("/tasks/:taskId/complete", dsbHandlers.CompleteTask) - dsbRoutes.POST("/assignments/:id/communications", dsbHandlers.CreateCommunication) - dsbRoutes.GET("/assignments/:id/communications", dsbHandlers.ListCommunications) - } + // Risk Assessment + iaceRoutes.POST("/projects/:id/hazards/:hid/assess", iaceHandler.AssessRisk) + iaceRoutes.GET("/projects/:id/risk-summary", iaceHandler.GetRiskSummary) + iaceRoutes.POST("/projects/:id/hazards/:hid/reassess", iaceHandler.ReassessRisk) - // GCI routes - Gesamt-Compliance-Index - gciRoutes := v1.Group("/gci") - { - // Core GCI endpoints - gciRoutes.GET("/score", gciHandlers.GetScore) - gciRoutes.GET("/score/breakdown", gciHandlers.GetScoreBreakdown) - gciRoutes.GET("/score/history", gciHandlers.GetHistory) - gciRoutes.GET("/matrix", gciHandlers.GetMatrix) - gciRoutes.GET("/audit-trail", gciHandlers.GetAuditTrail) - gciRoutes.GET("/profiles", gciHandlers.GetWeightProfiles) + // Mitigations + iaceRoutes.POST("/projects/:id/hazards/:hid/mitigations", iaceHandler.CreateMitigation) + iaceRoutes.PUT("/mitigations/:mid", iaceHandler.UpdateMitigation) + iaceRoutes.POST("/mitigations/:mid/verify", iaceHandler.VerifyMitigation) - // NIS2 sub-routes - gciRoutes.GET("/nis2/score", gciHandlers.GetNIS2Score) - gciRoutes.GET("/nis2/roles", gciHandlers.ListNIS2Roles) - gciRoutes.POST("/nis2/roles/assign", gciHandlers.AssignNIS2Role) + // Evidence + iaceRoutes.POST("/projects/:id/evidence", iaceHandler.UploadEvidence) + iaceRoutes.GET("/projects/:id/evidence", iaceHandler.ListEvidence) - // ISO 27001 sub-routes - gciRoutes.GET("/iso/gap-analysis", gciHandlers.GetISOGapAnalysis) - gciRoutes.GET("/iso/mappings", gciHandlers.ListISOMappings) - gciRoutes.GET("/iso/mappings/:controlId", gciHandlers.GetISOMapping) - } + // Verification Plans + iaceRoutes.POST("/projects/:id/verification-plan", iaceHandler.CreateVerificationPlan) + iaceRoutes.PUT("/verification-plan/:vid", iaceHandler.UpdateVerificationPlan) + iaceRoutes.POST("/verification-plan/:vid/complete", iaceHandler.CompleteVerification) + // CE Technical File + iaceRoutes.POST("/projects/:id/tech-file/generate", iaceHandler.GenerateTechFile) + iaceRoutes.GET("/projects/:id/tech-file", iaceHandler.ListTechFileSections) + iaceRoutes.PUT("/projects/:id/tech-file/:section", iaceHandler.UpdateTechFileSection) + iaceRoutes.POST("/projects/:id/tech-file/:section/approve", iaceHandler.ApproveTechFileSection) + iaceRoutes.GET("/projects/:id/tech-file/export", iaceHandler.ExportTechFile) - // Training Engine routes - Compliance Training Management - trainingRoutes := v1.Group("/training") - { - // Modules - trainingRoutes.GET("/modules", trainingHandlers.ListModules) - trainingRoutes.GET("/modules/:id", trainingHandlers.GetModule) - trainingRoutes.POST("/modules", trainingHandlers.CreateModule) - trainingRoutes.PUT("/modules/:id", trainingHandlers.UpdateModule) + // Monitoring + iaceRoutes.POST("/projects/:id/monitoring", iaceHandler.CreateMonitoringEvent) + iaceRoutes.GET("/projects/:id/monitoring", iaceHandler.ListMonitoringEvents) + iaceRoutes.PUT("/projects/:id/monitoring/:eid", iaceHandler.UpdateMonitoringEvent) - // Training Matrix (CTM) - trainingRoutes.GET("/matrix", trainingHandlers.GetMatrix) - trainingRoutes.GET("/matrix/:role", trainingHandlers.GetMatrixForRole) - trainingRoutes.POST("/matrix", trainingHandlers.SetMatrixEntry) - trainingRoutes.DELETE("/matrix/:role/:moduleId", trainingHandlers.DeleteMatrixEntry) - - // Assignments - trainingRoutes.POST("/assignments/compute", trainingHandlers.ComputeAssignments) - trainingRoutes.GET("/assignments", trainingHandlers.ListAssignments) - trainingRoutes.GET("/assignments/:id", trainingHandlers.GetAssignment) - trainingRoutes.POST("/assignments/:id/start", trainingHandlers.StartAssignment) - trainingRoutes.POST("/assignments/:id/progress", trainingHandlers.UpdateAssignmentProgress) - trainingRoutes.POST("/assignments/:id/complete", trainingHandlers.CompleteAssignment) - - // Quiz - trainingRoutes.GET("/quiz/:moduleId", trainingHandlers.GetQuiz) - trainingRoutes.POST("/quiz/:moduleId/submit", trainingHandlers.SubmitQuiz) - trainingRoutes.GET("/quiz/attempts/:assignmentId", trainingHandlers.GetQuizAttempts) - - // Content Generation - trainingRoutes.POST("/content/generate", trainingHandlers.GenerateContent) - trainingRoutes.POST("/content/generate-quiz", trainingHandlers.GenerateQuiz) - trainingRoutes.POST("/content/generate-all", trainingHandlers.GenerateAllContent) - trainingRoutes.POST("/content/generate-all-quiz", trainingHandlers.GenerateAllQuizzes) - trainingRoutes.GET("/content/:moduleId", trainingHandlers.GetContent) - trainingRoutes.POST("/content/publish/:id", trainingHandlers.PublishContent) - - // Audio/Media - trainingRoutes.POST("/content/:moduleId/generate-audio", trainingHandlers.GenerateAudio) - trainingRoutes.GET("/media/module/:moduleId", trainingHandlers.GetModuleMedia) - trainingRoutes.GET("/media/:id/url", trainingHandlers.GetMediaURL) - trainingRoutes.POST("/media/:id/publish", trainingHandlers.PublishMedia) - - // Video - trainingRoutes.POST("/content/:moduleId/generate-video", trainingHandlers.GenerateVideo) - trainingRoutes.POST("/content/:moduleId/preview-script", trainingHandlers.PreviewVideoScript) - - // Deadlines and Escalation - trainingRoutes.GET("/deadlines", trainingHandlers.GetDeadlines) - trainingRoutes.GET("/deadlines/overdue", trainingHandlers.GetOverdueDeadlines) - trainingRoutes.POST("/escalation/check", trainingHandlers.CheckEscalation) - - // Audit and Stats - trainingRoutes.GET("/audit-log", trainingHandlers.GetAuditLog) - trainingRoutes.GET("/stats", trainingHandlers.GetStats) - trainingRoutes.GET("/certificates/:id/verify", trainingHandlers.VerifyCertificate) - } - - // RAG Search routes - Compliance Regulation Corpus - ragRoutes := v1.Group("/rag") - { - ragRoutes.POST("/search", ragHandlers.Search) - ragRoutes.GET("/regulations", ragHandlers.ListRegulations) + // Audit Trail + iaceRoutes.GET("/projects/:id/audit-trail", iaceHandler.GetAuditTrail) } } diff --git a/ai-compliance-sdk/internal/api/handlers/iace_handler.go b/ai-compliance-sdk/internal/api/handlers/iace_handler.go new file mode 100644 index 0000000..68c0028 --- /dev/null +++ b/ai-compliance-sdk/internal/api/handlers/iace_handler.go @@ -0,0 +1,1833 @@ +package handlers + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/breakpilot/ai-compliance-sdk/internal/iace" + "github.com/breakpilot/ai-compliance-sdk/internal/rbac" + "github.com/gin-gonic/gin" + "github.com/google/uuid" +) + +// ============================================================================ +// Handler Struct & Constructor +// ============================================================================ + +// IACEHandler handles HTTP requests for the IACE module (Inherent-risk Adjusted +// Control Effectiveness). It provides endpoints for project management, component +// onboarding, regulatory classification, hazard/risk analysis, evidence management, +// CE technical file generation, and post-market monitoring. +type IACEHandler struct { + store *iace.Store + engine *iace.RiskEngine + classifier *iace.Classifier + checker *iace.CompletenessChecker +} + +// NewIACEHandler creates a new IACEHandler with all required dependencies. +func NewIACEHandler(store *iace.Store) *IACEHandler { + return &IACEHandler{ + store: store, + engine: iace.NewRiskEngine(), + classifier: iace.NewClassifier(), + checker: iace.NewCompletenessChecker(), + } +} + +// ============================================================================ +// Helper: Tenant ID extraction +// ============================================================================ + +// getTenantID extracts the tenant UUID from the X-Tenant-Id header. +// It first checks the rbac middleware context; if not present, falls back to the +// raw header value. +func getTenantID(c *gin.Context) (uuid.UUID, error) { + // Prefer value set by RBAC middleware + tid := rbac.GetTenantID(c) + if tid != uuid.Nil { + return tid, nil + } + + tenantStr := c.GetHeader("X-Tenant-Id") + if tenantStr == "" { + return uuid.Nil, fmt.Errorf("X-Tenant-Id header required") + } + return uuid.Parse(tenantStr) +} + +// ============================================================================ +// Project Management +// ============================================================================ + +// CreateProject handles POST /projects +// Creates a new IACE compliance project for a machine or system. +func (h *IACEHandler) CreateProject(c *gin.Context) { + tenantID, err := getTenantID(c) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + var req iace.CreateProjectRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + project, err := h.store.CreateProject(c.Request.Context(), tenantID, req) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusCreated, gin.H{"project": project}) +} + +// ListProjects handles GET /projects +// Lists all IACE projects for the authenticated tenant. +func (h *IACEHandler) ListProjects(c *gin.Context) { + tenantID, err := getTenantID(c) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + projects, err := h.store.ListProjects(c.Request.Context(), tenantID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if projects == nil { + projects = []iace.Project{} + } + + c.JSON(http.StatusOK, iace.ProjectListResponse{ + Projects: projects, + Total: len(projects), + }) +} + +// GetProject handles GET /projects/:id +// Returns a project with its components, classifications, and completeness gates. +func (h *IACEHandler) GetProject(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + project, err := h.store.GetProject(c.Request.Context(), projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if project == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "project not found"}) + return + } + + components, _ := h.store.ListComponents(c.Request.Context(), projectID) + classifications, _ := h.store.GetClassifications(c.Request.Context(), projectID) + + if components == nil { + components = []iace.Component{} + } + if classifications == nil { + classifications = []iace.RegulatoryClassification{} + } + + // Build completeness context to compute gates + ctx := h.buildCompletenessContext(c, project, components, classifications) + result := h.checker.Check(ctx) + + c.JSON(http.StatusOK, iace.ProjectDetailResponse{ + Project: *project, + Components: components, + Classifications: classifications, + CompletenessGates: result.Gates, + }) +} + +// UpdateProject handles PUT /projects/:id +// Partially updates a project's mutable fields. +func (h *IACEHandler) UpdateProject(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + var req iace.UpdateProjectRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + project, err := h.store.UpdateProject(c.Request.Context(), projectID, req) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if project == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "project not found"}) + return + } + + c.JSON(http.StatusOK, gin.H{"project": project}) +} + +// ArchiveProject handles DELETE /projects/:id +// Archives a project by setting its status to archived. +func (h *IACEHandler) ArchiveProject(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + if err := h.store.ArchiveProject(c.Request.Context(), projectID); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "project archived"}) +} + +// ============================================================================ +// Onboarding +// ============================================================================ + +// InitFromProfile handles POST /projects/:id/init-from-profile +// Initializes a project from a company profile and compliance scope JSON payload. +func (h *IACEHandler) InitFromProfile(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + project, err := h.store.GetProject(c.Request.Context(), projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if project == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "project not found"}) + return + } + + var req iace.InitFromProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Store the profile and scope in project metadata + profileData := map[string]json.RawMessage{ + "company_profile": req.CompanyProfile, + "compliance_scope": req.ComplianceScope, + } + metadataBytes, _ := json.Marshal(profileData) + metadataRaw := json.RawMessage(metadataBytes) + updateReq := iace.UpdateProjectRequest{ + Metadata: &metadataRaw, + } + + project, err = h.store.UpdateProject(c.Request.Context(), projectID, updateReq) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Advance project status to onboarding + if err := h.store.UpdateProjectStatus(c.Request.Context(), projectID, iace.ProjectStatusOnboarding); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Add audit trail entry + userID := rbac.GetUserID(c) + h.store.AddAuditEntry( + c.Request.Context(), projectID, "project", projectID, + iace.AuditActionUpdate, userID.String(), nil, metadataBytes, + ) + + c.JSON(http.StatusOK, gin.H{ + "message": "project initialized from profile", + "project": project, + }) +} + +// CreateComponent handles POST /projects/:id/components +// Adds a new component to a project. +func (h *IACEHandler) CreateComponent(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + var req iace.CreateComponentRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Override project ID from URL path + req.ProjectID = projectID + + component, err := h.store.CreateComponent(c.Request.Context(), req) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Audit trail + userID := rbac.GetUserID(c) + newVals, _ := json.Marshal(component) + h.store.AddAuditEntry( + c.Request.Context(), projectID, "component", component.ID, + iace.AuditActionCreate, userID.String(), nil, newVals, + ) + + c.JSON(http.StatusCreated, gin.H{"component": component}) +} + +// ListComponents handles GET /projects/:id/components +// Lists all components for a project. +func (h *IACEHandler) ListComponents(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + components, err := h.store.ListComponents(c.Request.Context(), projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if components == nil { + components = []iace.Component{} + } + + c.JSON(http.StatusOK, gin.H{ + "components": components, + "total": len(components), + }) +} + +// UpdateComponent handles PUT /projects/:id/components/:cid +// Updates a component with the provided fields. +func (h *IACEHandler) UpdateComponent(c *gin.Context) { + _, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + componentID, err := uuid.Parse(c.Param("cid")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid component ID"}) + return + } + + var updates map[string]interface{} + if err := c.ShouldBindJSON(&updates); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + component, err := h.store.UpdateComponent(c.Request.Context(), componentID, updates) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if component == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "component not found"}) + return + } + + c.JSON(http.StatusOK, gin.H{"component": component}) +} + +// DeleteComponent handles DELETE /projects/:id/components/:cid +// Deletes a component from a project. +func (h *IACEHandler) DeleteComponent(c *gin.Context) { + _, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + componentID, err := uuid.Parse(c.Param("cid")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid component ID"}) + return + } + + if err := h.store.DeleteComponent(c.Request.Context(), componentID); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "component deleted"}) +} + +// CheckCompleteness handles POST /projects/:id/completeness-check +// Loads all project data, evaluates all 25 CE completeness gates, updates the +// project's completeness score, and returns the result. +func (h *IACEHandler) CheckCompleteness(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + project, err := h.store.GetProject(c.Request.Context(), projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if project == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "project not found"}) + return + } + + // Load all related entities + components, _ := h.store.ListComponents(c.Request.Context(), projectID) + classifications, _ := h.store.GetClassifications(c.Request.Context(), projectID) + hazards, _ := h.store.ListHazards(c.Request.Context(), projectID) + + // Collect all assessments and mitigations across all hazards + var allAssessments []iace.RiskAssessment + var allMitigations []iace.Mitigation + for _, hazard := range hazards { + assessments, _ := h.store.ListAssessments(c.Request.Context(), hazard.ID) + allAssessments = append(allAssessments, assessments...) + + mitigations, _ := h.store.ListMitigations(c.Request.Context(), hazard.ID) + allMitigations = append(allMitigations, mitigations...) + } + + evidence, _ := h.store.ListEvidence(c.Request.Context(), projectID) + techFileSections, _ := h.store.ListTechFileSections(c.Request.Context(), projectID) + + // Determine if the project has AI components + hasAI := false + for _, comp := range components { + if comp.ComponentType == iace.ComponentTypeAIModel { + hasAI = true + break + } + } + + // Build completeness context + completenessCtx := &iace.CompletenessContext{ + Project: project, + Components: components, + Classifications: classifications, + Hazards: hazards, + Assessments: allAssessments, + Mitigations: allMitigations, + Evidence: evidence, + TechFileSections: techFileSections, + HasAI: hasAI, + } + + // Run the checker + result := h.checker.Check(completenessCtx) + + // Build risk summary for the project update + riskSummary := map[string]int{ + "total_hazards": len(hazards), + } + for _, a := range allAssessments { + riskSummary[string(a.RiskLevel)]++ + } + + // Update project completeness score and risk summary + if err := h.store.UpdateProjectCompleteness( + c.Request.Context(), projectID, result.Score, riskSummary, + ); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "completeness": result, + }) +} + +// ============================================================================ +// Classification +// ============================================================================ + +// Classify handles POST /projects/:id/classify +// Runs all regulatory classifiers (AI Act, Machinery Regulation, CRA, NIS2), +// upserts each result into the store, and returns classifications. +func (h *IACEHandler) Classify(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + project, err := h.store.GetProject(c.Request.Context(), projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if project == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "project not found"}) + return + } + + components, err := h.store.ListComponents(c.Request.Context(), projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Run all classifiers + results := h.classifier.ClassifyAll(project, components) + + // Upsert each classification result into the store + var classifications []iace.RegulatoryClassification + for _, r := range results { + reqsJSON, _ := json.Marshal(r.Requirements) + + classification, err := h.store.UpsertClassification( + c.Request.Context(), + projectID, + r.Regulation, + r.ClassificationResult, + r.RiskLevel, + r.Confidence, + r.Reasoning, + nil, // ragSources - not available from rule-based classifier + reqsJSON, + ) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if classification != nil { + classifications = append(classifications, *classification) + } + } + + // Advance project status to classification + h.store.UpdateProjectStatus(c.Request.Context(), projectID, iace.ProjectStatusClassification) + + // Audit trail + userID := rbac.GetUserID(c) + newVals, _ := json.Marshal(classifications) + h.store.AddAuditEntry( + c.Request.Context(), projectID, "classification", projectID, + iace.AuditActionCreate, userID.String(), nil, newVals, + ) + + c.JSON(http.StatusOK, gin.H{ + "classifications": classifications, + "total": len(classifications), + }) +} + +// GetClassifications handles GET /projects/:id/classifications +// Returns all regulatory classifications for a project. +func (h *IACEHandler) GetClassifications(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + classifications, err := h.store.GetClassifications(c.Request.Context(), projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if classifications == nil { + classifications = []iace.RegulatoryClassification{} + } + + c.JSON(http.StatusOK, gin.H{ + "classifications": classifications, + "total": len(classifications), + }) +} + +// ClassifySingle handles POST /projects/:id/classify/:regulation +// Runs a single regulatory classifier for the specified regulation type. +func (h *IACEHandler) ClassifySingle(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + regulation := iace.RegulationType(c.Param("regulation")) + + project, err := h.store.GetProject(c.Request.Context(), projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if project == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "project not found"}) + return + } + + components, err := h.store.ListComponents(c.Request.Context(), projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Run the appropriate classifier + var result iace.ClassificationResult + switch regulation { + case iace.RegulationAIAct: + result = h.classifier.ClassifyAIAct(project, components) + case iace.RegulationMachineryRegulation: + result = h.classifier.ClassifyMachineryRegulation(project, components) + case iace.RegulationCRA: + result = h.classifier.ClassifyCRA(project, components) + case iace.RegulationNIS2: + result = h.classifier.ClassifyNIS2(project, components) + default: + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("unknown regulation type: %s", regulation)}) + return + } + + // Upsert the classification result + reqsJSON, _ := json.Marshal(result.Requirements) + + classification, err := h.store.UpsertClassification( + c.Request.Context(), + projectID, + result.Regulation, + result.ClassificationResult, + result.RiskLevel, + result.Confidence, + result.Reasoning, + nil, + reqsJSON, + ) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"classification": classification}) +} + +// ============================================================================ +// Hazard & Risk +// ============================================================================ + +// ListHazardLibrary handles GET /hazard-library +// Returns built-in hazard library entries merged with any custom DB entries, +// optionally filtered by ?category and ?componentType. +func (h *IACEHandler) ListHazardLibrary(c *gin.Context) { + category := c.Query("category") + componentType := c.Query("componentType") + + // Start with built-in templates from Go code + builtinEntries := iace.GetBuiltinHazardLibrary() + + // Apply filters to built-in entries + var entries []iace.HazardLibraryEntry + for _, entry := range builtinEntries { + if category != "" && entry.Category != category { + continue + } + if componentType != "" && !containsString(entry.ApplicableComponentTypes, componentType) { + continue + } + entries = append(entries, entry) + } + + // Merge with custom DB entries (tenant-specific) + dbEntries, err := h.store.ListHazardLibrary(c.Request.Context(), category, componentType) + if err == nil && len(dbEntries) > 0 { + // Add DB entries that are not built-in (avoid duplicates) + builtinIDs := make(map[string]bool) + for _, e := range entries { + builtinIDs[e.ID.String()] = true + } + for _, dbEntry := range dbEntries { + if !builtinIDs[dbEntry.ID.String()] { + entries = append(entries, dbEntry) + } + } + } + + if entries == nil { + entries = []iace.HazardLibraryEntry{} + } + + c.JSON(http.StatusOK, gin.H{ + "hazard_library": entries, + "total": len(entries), + }) +} + +// containsString checks if a string slice contains the given value. +func containsString(slice []string, val string) bool { + for _, s := range slice { + if s == val { + return true + } + } + return false +} + +// CreateHazard handles POST /projects/:id/hazards +// Creates a new hazard within a project. +func (h *IACEHandler) CreateHazard(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + var req iace.CreateHazardRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Override project ID from URL path + req.ProjectID = projectID + + hazard, err := h.store.CreateHazard(c.Request.Context(), req) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Audit trail + userID := rbac.GetUserID(c) + newVals, _ := json.Marshal(hazard) + h.store.AddAuditEntry( + c.Request.Context(), projectID, "hazard", hazard.ID, + iace.AuditActionCreate, userID.String(), nil, newVals, + ) + + c.JSON(http.StatusCreated, gin.H{"hazard": hazard}) +} + +// ListHazards handles GET /projects/:id/hazards +// Lists all hazards for a project. +func (h *IACEHandler) ListHazards(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + hazards, err := h.store.ListHazards(c.Request.Context(), projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if hazards == nil { + hazards = []iace.Hazard{} + } + + c.JSON(http.StatusOK, gin.H{ + "hazards": hazards, + "total": len(hazards), + }) +} + +// UpdateHazard handles PUT /projects/:id/hazards/:hid +// Updates a hazard with the provided fields. +func (h *IACEHandler) UpdateHazard(c *gin.Context) { + _, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + hazardID, err := uuid.Parse(c.Param("hid")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid hazard ID"}) + return + } + + var updates map[string]interface{} + if err := c.ShouldBindJSON(&updates); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + hazard, err := h.store.UpdateHazard(c.Request.Context(), hazardID, updates) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if hazard == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "hazard not found"}) + return + } + + c.JSON(http.StatusOK, gin.H{"hazard": hazard}) +} + +// SuggestHazards handles POST /projects/:id/hazards/suggest +// Returns hazard library matches based on the project's components. +// TODO: Enhance with LLM-based suggestions for more intelligent matching. +func (h *IACEHandler) SuggestHazards(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + components, err := h.store.ListComponents(c.Request.Context(), projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Collect unique component types from the project + componentTypes := make(map[string]bool) + for _, comp := range components { + componentTypes[string(comp.ComponentType)] = true + } + + // Match built-in hazard templates against project component types + var suggestions []iace.HazardLibraryEntry + seen := make(map[uuid.UUID]bool) + + builtinEntries := iace.GetBuiltinHazardLibrary() + for _, entry := range builtinEntries { + for _, applicableType := range entry.ApplicableComponentTypes { + if componentTypes[applicableType] && !seen[entry.ID] { + seen[entry.ID] = true + suggestions = append(suggestions, entry) + break + } + } + } + + // Also check DB for custom tenant-specific hazard templates + for compType := range componentTypes { + dbEntries, err := h.store.ListHazardLibrary(c.Request.Context(), "", compType) + if err != nil { + continue + } + for _, entry := range dbEntries { + if !seen[entry.ID] { + seen[entry.ID] = true + suggestions = append(suggestions, entry) + } + } + } + + if suggestions == nil { + suggestions = []iace.HazardLibraryEntry{} + } + + c.JSON(http.StatusOK, gin.H{ + "suggestions": suggestions, + "total": len(suggestions), + "component_types": componentTypeKeys(componentTypes), + "_note": "TODO: LLM-based suggestion ranking not yet implemented", + }) +} + +// AssessRisk handles POST /projects/:id/hazards/:hid/assess +// Performs a quantitative risk assessment for a hazard using the IACE risk engine. +func (h *IACEHandler) AssessRisk(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + hazardID, err := uuid.Parse(c.Param("hid")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid hazard ID"}) + return + } + + // Verify hazard exists + hazard, err := h.store.GetHazard(c.Request.Context(), hazardID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if hazard == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "hazard not found"}) + return + } + + var req iace.AssessRiskRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Override hazard ID from URL path + req.HazardID = hazardID + + userID := rbac.GetUserID(c) + + // Calculate risk using the engine + inherentRisk := h.engine.CalculateInherentRisk(req.Severity, req.Exposure, req.Probability) + controlEff := h.engine.CalculateControlEffectiveness(req.ControlMaturity, req.ControlCoverage, req.TestEvidenceStrength) + residualRisk := h.engine.CalculateResidualRisk(req.Severity, req.Exposure, req.Probability, controlEff) + riskLevel := h.engine.DetermineRiskLevel(residualRisk) + acceptable, acceptanceReason := h.engine.IsAcceptable(residualRisk, false, req.AcceptanceJustification != "") + + // Determine version by checking existing assessments + existingAssessments, _ := h.store.ListAssessments(c.Request.Context(), hazardID) + version := len(existingAssessments) + 1 + + assessment := &iace.RiskAssessment{ + HazardID: hazardID, + Version: version, + AssessmentType: iace.AssessmentTypeInitial, + Severity: req.Severity, + Exposure: req.Exposure, + Probability: req.Probability, + InherentRisk: inherentRisk, + ControlMaturity: req.ControlMaturity, + ControlCoverage: req.ControlCoverage, + TestEvidenceStrength: req.TestEvidenceStrength, + CEff: controlEff, + ResidualRisk: residualRisk, + RiskLevel: riskLevel, + IsAcceptable: acceptable, + AcceptanceJustification: req.AcceptanceJustification, + AssessedBy: userID, + } + + if err := h.store.CreateRiskAssessment(c.Request.Context(), assessment); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Update hazard status + h.store.UpdateHazard(c.Request.Context(), hazardID, map[string]interface{}{ + "status": string(iace.HazardStatusAssessed), + }) + + // Audit trail + newVals, _ := json.Marshal(assessment) + h.store.AddAuditEntry( + c.Request.Context(), projectID, "risk_assessment", assessment.ID, + iace.AuditActionCreate, userID.String(), nil, newVals, + ) + + c.JSON(http.StatusCreated, gin.H{ + "assessment": assessment, + "acceptable": acceptable, + "acceptance_reason": acceptanceReason, + }) +} + +// GetRiskSummary handles GET /projects/:id/risk-summary +// Returns an aggregated risk overview for a project. +func (h *IACEHandler) GetRiskSummary(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + summary, err := h.store.GetRiskSummary(c.Request.Context(), projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"risk_summary": summary}) +} + +// CreateMitigation handles POST /projects/:id/hazards/:hid/mitigations +// Creates a new mitigation measure for a hazard. +func (h *IACEHandler) CreateMitigation(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + hazardID, err := uuid.Parse(c.Param("hid")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid hazard ID"}) + return + } + + var req iace.CreateMitigationRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Override hazard ID from URL path + req.HazardID = hazardID + + mitigation, err := h.store.CreateMitigation(c.Request.Context(), req) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Update hazard status to mitigated + h.store.UpdateHazard(c.Request.Context(), hazardID, map[string]interface{}{ + "status": string(iace.HazardStatusMitigated), + }) + + // Audit trail + userID := rbac.GetUserID(c) + newVals, _ := json.Marshal(mitigation) + h.store.AddAuditEntry( + c.Request.Context(), projectID, "mitigation", mitigation.ID, + iace.AuditActionCreate, userID.String(), nil, newVals, + ) + + c.JSON(http.StatusCreated, gin.H{"mitigation": mitigation}) +} + +// UpdateMitigation handles PUT /mitigations/:mid +// Updates a mitigation measure with the provided fields. +func (h *IACEHandler) UpdateMitigation(c *gin.Context) { + mitigationID, err := uuid.Parse(c.Param("mid")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid mitigation ID"}) + return + } + + var updates map[string]interface{} + if err := c.ShouldBindJSON(&updates); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + mitigation, err := h.store.UpdateMitigation(c.Request.Context(), mitigationID, updates) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if mitigation == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "mitigation not found"}) + return + } + + c.JSON(http.StatusOK, gin.H{"mitigation": mitigation}) +} + +// VerifyMitigation handles POST /mitigations/:mid/verify +// Marks a mitigation as verified with a verification result. +func (h *IACEHandler) VerifyMitigation(c *gin.Context) { + mitigationID, err := uuid.Parse(c.Param("mid")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid mitigation ID"}) + return + } + + var req struct { + VerificationResult string `json:"verification_result" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + userID := rbac.GetUserID(c) + + if err := h.store.VerifyMitigation( + c.Request.Context(), mitigationID, req.VerificationResult, userID.String(), + ); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "mitigation verified"}) +} + +// ReassessRisk handles POST /projects/:id/hazards/:hid/reassess +// Creates a post-mitigation risk reassessment for a hazard. +func (h *IACEHandler) ReassessRisk(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + hazardID, err := uuid.Parse(c.Param("hid")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid hazard ID"}) + return + } + + // Verify hazard exists + hazard, err := h.store.GetHazard(c.Request.Context(), hazardID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if hazard == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "hazard not found"}) + return + } + + var req iace.AssessRiskRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + userID := rbac.GetUserID(c) + + // Calculate risk using the engine + inherentRisk := h.engine.CalculateInherentRisk(req.Severity, req.Exposure, req.Probability) + controlEff := h.engine.CalculateControlEffectiveness(req.ControlMaturity, req.ControlCoverage, req.TestEvidenceStrength) + residualRisk := h.engine.CalculateResidualRisk(req.Severity, req.Exposure, req.Probability, controlEff) + riskLevel := h.engine.DetermineRiskLevel(residualRisk) + + // For reassessment, check if all reduction steps have been applied + mitigations, _ := h.store.ListMitigations(c.Request.Context(), hazardID) + allReductionStepsApplied := len(mitigations) > 0 + for _, m := range mitigations { + if m.Status != iace.MitigationStatusVerified { + allReductionStepsApplied = false + break + } + } + + acceptable, acceptanceReason := h.engine.IsAcceptable(residualRisk, allReductionStepsApplied, req.AcceptanceJustification != "") + + // Determine version + existingAssessments, _ := h.store.ListAssessments(c.Request.Context(), hazardID) + version := len(existingAssessments) + 1 + + assessment := &iace.RiskAssessment{ + HazardID: hazardID, + Version: version, + AssessmentType: iace.AssessmentTypePostMitigation, + Severity: req.Severity, + Exposure: req.Exposure, + Probability: req.Probability, + InherentRisk: inherentRisk, + ControlMaturity: req.ControlMaturity, + ControlCoverage: req.ControlCoverage, + TestEvidenceStrength: req.TestEvidenceStrength, + CEff: controlEff, + ResidualRisk: residualRisk, + RiskLevel: riskLevel, + IsAcceptable: acceptable, + AcceptanceJustification: req.AcceptanceJustification, + AssessedBy: userID, + } + + if err := h.store.CreateRiskAssessment(c.Request.Context(), assessment); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Audit trail + newVals, _ := json.Marshal(assessment) + h.store.AddAuditEntry( + c.Request.Context(), projectID, "risk_assessment", assessment.ID, + iace.AuditActionCreate, userID.String(), nil, newVals, + ) + + c.JSON(http.StatusCreated, gin.H{ + "assessment": assessment, + "acceptable": acceptable, + "acceptance_reason": acceptanceReason, + "all_reduction_steps_applied": allReductionStepsApplied, + }) +} + +// ============================================================================ +// Evidence & Verification +// ============================================================================ + +// UploadEvidence handles POST /projects/:id/evidence +// Creates a new evidence record for a project. +func (h *IACEHandler) UploadEvidence(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + var req struct { + MitigationID *uuid.UUID `json:"mitigation_id,omitempty"` + VerificationPlanID *uuid.UUID `json:"verification_plan_id,omitempty"` + FileName string `json:"file_name" binding:"required"` + FilePath string `json:"file_path" binding:"required"` + FileHash string `json:"file_hash" binding:"required"` + FileSize int64 `json:"file_size" binding:"required"` + MimeType string `json:"mime_type" binding:"required"` + Description string `json:"description,omitempty"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + userID := rbac.GetUserID(c) + + evidence := &iace.Evidence{ + ProjectID: projectID, + MitigationID: req.MitigationID, + VerificationPlanID: req.VerificationPlanID, + FileName: req.FileName, + FilePath: req.FilePath, + FileHash: req.FileHash, + FileSize: req.FileSize, + MimeType: req.MimeType, + Description: req.Description, + UploadedBy: userID, + } + + if err := h.store.CreateEvidence(c.Request.Context(), evidence); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Audit trail + newVals, _ := json.Marshal(evidence) + h.store.AddAuditEntry( + c.Request.Context(), projectID, "evidence", evidence.ID, + iace.AuditActionCreate, userID.String(), nil, newVals, + ) + + c.JSON(http.StatusCreated, gin.H{"evidence": evidence}) +} + +// ListEvidence handles GET /projects/:id/evidence +// Lists all evidence records for a project. +func (h *IACEHandler) ListEvidence(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + evidence, err := h.store.ListEvidence(c.Request.Context(), projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if evidence == nil { + evidence = []iace.Evidence{} + } + + c.JSON(http.StatusOK, gin.H{ + "evidence": evidence, + "total": len(evidence), + }) +} + +// CreateVerificationPlan handles POST /projects/:id/verification-plan +// Creates a new verification plan for a project. +func (h *IACEHandler) CreateVerificationPlan(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + var req iace.CreateVerificationPlanRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Override project ID from URL path + req.ProjectID = projectID + + plan, err := h.store.CreateVerificationPlan(c.Request.Context(), req) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Audit trail + userID := rbac.GetUserID(c) + newVals, _ := json.Marshal(plan) + h.store.AddAuditEntry( + c.Request.Context(), projectID, "verification_plan", plan.ID, + iace.AuditActionCreate, userID.String(), nil, newVals, + ) + + c.JSON(http.StatusCreated, gin.H{"verification_plan": plan}) +} + +// UpdateVerificationPlan handles PUT /verification-plan/:vid +// Updates a verification plan with the provided fields. +func (h *IACEHandler) UpdateVerificationPlan(c *gin.Context) { + planID, err := uuid.Parse(c.Param("vid")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid verification plan ID"}) + return + } + + var updates map[string]interface{} + if err := c.ShouldBindJSON(&updates); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + plan, err := h.store.UpdateVerificationPlan(c.Request.Context(), planID, updates) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if plan == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "verification plan not found"}) + return + } + + c.JSON(http.StatusOK, gin.H{"verification_plan": plan}) +} + +// CompleteVerification handles POST /verification-plan/:vid/complete +// Marks a verification plan as completed with a result. +func (h *IACEHandler) CompleteVerification(c *gin.Context) { + planID, err := uuid.Parse(c.Param("vid")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid verification plan ID"}) + return + } + + var req struct { + Result string `json:"result" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + userID := rbac.GetUserID(c) + + if err := h.store.CompleteVerification( + c.Request.Context(), planID, req.Result, userID.String(), + ); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "verification completed"}) +} + +// ============================================================================ +// CE Technical File +// ============================================================================ + +// GenerateTechFile handles POST /projects/:id/tech-file/generate +// Generates technical file sections for a project. +// TODO: Integrate LLM for intelligent content generation based on project data. +func (h *IACEHandler) GenerateTechFile(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + project, err := h.store.GetProject(c.Request.Context(), projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if project == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "project not found"}) + return + } + + // Define the standard CE technical file sections to generate + sectionDefinitions := []struct { + SectionType string + Title string + }{ + {"general_description", "General Description of the Machinery"}, + {"risk_assessment_report", "Risk Assessment Report"}, + {"hazard_log_combined", "Combined Hazard Log"}, + {"essential_requirements", "Essential Health and Safety Requirements"}, + {"design_specifications", "Design Specifications and Drawings"}, + {"test_reports", "Test Reports and Verification Results"}, + {"standards_applied", "Applied Harmonised Standards"}, + {"declaration_of_conformity", "EU Declaration of Conformity"}, + } + + // Check if project has AI components for additional sections + components, _ := h.store.ListComponents(c.Request.Context(), projectID) + hasAI := false + for _, comp := range components { + if comp.ComponentType == iace.ComponentTypeAIModel { + hasAI = true + break + } + } + + if hasAI { + sectionDefinitions = append(sectionDefinitions, + struct { + SectionType string + Title string + }{"ai_intended_purpose", "AI System Intended Purpose"}, + struct { + SectionType string + Title string + }{"ai_model_description", "AI Model Description and Training Data"}, + struct { + SectionType string + Title string + }{"ai_risk_management", "AI Risk Management System"}, + struct { + SectionType string + Title string + }{"ai_human_oversight", "AI Human Oversight Measures"}, + ) + } + + // Generate each section with placeholder content + // TODO: Replace placeholder content with LLM-generated content based on project data + var sections []iace.TechFileSection + existingSections, _ := h.store.ListTechFileSections(c.Request.Context(), projectID) + existingMap := make(map[string]bool) + for _, s := range existingSections { + existingMap[s.SectionType] = true + } + + for _, def := range sectionDefinitions { + // Skip sections that already exist + if existingMap[def.SectionType] { + continue + } + + content := fmt.Sprintf( + "[Auto-generated placeholder for '%s']\n\n"+ + "Machine: %s\nManufacturer: %s\nType: %s\n\n"+ + "TODO: Replace this placeholder with actual content. "+ + "LLM-based generation will be integrated in a future release.", + def.Title, + project.MachineName, + project.Manufacturer, + project.MachineType, + ) + + section, err := h.store.CreateTechFileSection( + c.Request.Context(), projectID, def.SectionType, def.Title, content, + ) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + sections = append(sections, *section) + } + + // Update project status + h.store.UpdateProjectStatus(c.Request.Context(), projectID, iace.ProjectStatusTechFile) + + // Audit trail + userID := rbac.GetUserID(c) + h.store.AddAuditEntry( + c.Request.Context(), projectID, "tech_file", projectID, + iace.AuditActionCreate, userID.String(), nil, nil, + ) + + c.JSON(http.StatusCreated, gin.H{ + "sections_created": len(sections), + "sections": sections, + "_note": "TODO: LLM-based content generation not yet implemented", + }) +} + +// ListTechFileSections handles GET /projects/:id/tech-file +// Lists all technical file sections for a project. +func (h *IACEHandler) ListTechFileSections(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + sections, err := h.store.ListTechFileSections(c.Request.Context(), projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if sections == nil { + sections = []iace.TechFileSection{} + } + + c.JSON(http.StatusOK, gin.H{ + "sections": sections, + "total": len(sections), + }) +} + +// UpdateTechFileSection handles PUT /projects/:id/tech-file/:section +// Updates the content of a technical file section (identified by section_type). +func (h *IACEHandler) UpdateTechFileSection(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + sectionType := c.Param("section") + if sectionType == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "section type required"}) + return + } + + var req struct { + Content string `json:"content" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Find the section by project ID and section type + sections, err := h.store.ListTechFileSections(c.Request.Context(), projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + var sectionID uuid.UUID + found := false + for _, s := range sections { + if s.SectionType == sectionType { + sectionID = s.ID + found = true + break + } + } + + if !found { + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("tech file section '%s' not found", sectionType)}) + return + } + + if err := h.store.UpdateTechFileSection(c.Request.Context(), sectionID, req.Content); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Audit trail + userID := rbac.GetUserID(c) + h.store.AddAuditEntry( + c.Request.Context(), projectID, "tech_file_section", sectionID, + iace.AuditActionUpdate, userID.String(), nil, nil, + ) + + c.JSON(http.StatusOK, gin.H{"message": "tech file section updated"}) +} + +// ApproveTechFileSection handles POST /projects/:id/tech-file/:section/approve +// Marks a technical file section as approved. +func (h *IACEHandler) ApproveTechFileSection(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + sectionType := c.Param("section") + if sectionType == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "section type required"}) + return + } + + // Find the section by project ID and section type + sections, err := h.store.ListTechFileSections(c.Request.Context(), projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + var sectionID uuid.UUID + found := false + for _, s := range sections { + if s.SectionType == sectionType { + sectionID = s.ID + found = true + break + } + } + + if !found { + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("tech file section '%s' not found", sectionType)}) + return + } + + userID := rbac.GetUserID(c) + + if err := h.store.ApproveTechFileSection(c.Request.Context(), sectionID, userID.String()); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Audit trail + h.store.AddAuditEntry( + c.Request.Context(), projectID, "tech_file_section", sectionID, + iace.AuditActionApprove, userID.String(), nil, nil, + ) + + c.JSON(http.StatusOK, gin.H{"message": "tech file section approved"}) +} + +// ExportTechFile handles GET /projects/:id/tech-file/export +// Exports all tech file sections as a combined JSON document. +// TODO: Implement PDF export with proper formatting. +func (h *IACEHandler) ExportTechFile(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + project, err := h.store.GetProject(c.Request.Context(), projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if project == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "project not found"}) + return + } + + sections, err := h.store.ListTechFileSections(c.Request.Context(), projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Check if all sections are approved + allApproved := true + for _, s := range sections { + if s.Status != iace.TechFileSectionStatusApproved { + allApproved = false + break + } + } + + classifications, _ := h.store.GetClassifications(c.Request.Context(), projectID) + riskSummary, _ := h.store.GetRiskSummary(c.Request.Context(), projectID) + + c.JSON(http.StatusOK, gin.H{ + "project": project, + "sections": sections, + "classifications": classifications, + "risk_summary": riskSummary, + "all_approved": allApproved, + "export_format": "json", + "_note": "PDF export will be available in a future release", + }) +} + +// ============================================================================ +// Monitoring +// ============================================================================ + +// CreateMonitoringEvent handles POST /projects/:id/monitoring +// Creates a new post-market monitoring event. +func (h *IACEHandler) CreateMonitoringEvent(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + var req iace.CreateMonitoringEventRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Override project ID from URL path + req.ProjectID = projectID + + event, err := h.store.CreateMonitoringEvent(c.Request.Context(), req) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Audit trail + userID := rbac.GetUserID(c) + newVals, _ := json.Marshal(event) + h.store.AddAuditEntry( + c.Request.Context(), projectID, "monitoring_event", event.ID, + iace.AuditActionCreate, userID.String(), nil, newVals, + ) + + c.JSON(http.StatusCreated, gin.H{"monitoring_event": event}) +} + +// ListMonitoringEvents handles GET /projects/:id/monitoring +// Lists all monitoring events for a project. +func (h *IACEHandler) ListMonitoringEvents(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + events, err := h.store.ListMonitoringEvents(c.Request.Context(), projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if events == nil { + events = []iace.MonitoringEvent{} + } + + c.JSON(http.StatusOK, gin.H{ + "monitoring_events": events, + "total": len(events), + }) +} + +// UpdateMonitoringEvent handles PUT /projects/:id/monitoring/:eid +// Updates a monitoring event with the provided fields. +func (h *IACEHandler) UpdateMonitoringEvent(c *gin.Context) { + _, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + eventID, err := uuid.Parse(c.Param("eid")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid monitoring event ID"}) + return + } + + var updates map[string]interface{} + if err := c.ShouldBindJSON(&updates); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + event, err := h.store.UpdateMonitoringEvent(c.Request.Context(), eventID, updates) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if event == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "monitoring event not found"}) + return + } + + c.JSON(http.StatusOK, gin.H{"monitoring_event": event}) +} + +// GetAuditTrail handles GET /projects/:id/audit-trail +// Returns all audit trail entries for a project, newest first. +func (h *IACEHandler) GetAuditTrail(c *gin.Context) { + projectID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid project ID"}) + return + } + + entries, err := h.store.ListAuditTrail(c.Request.Context(), projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if entries == nil { + entries = []iace.AuditTrailEntry{} + } + + c.JSON(http.StatusOK, gin.H{ + "audit_trail": entries, + "total": len(entries), + }) +} + +// ============================================================================ +// Internal Helpers +// ============================================================================ + +// buildCompletenessContext constructs the CompletenessContext needed by the checker +// by loading all related entities for a project. +func (h *IACEHandler) buildCompletenessContext( + c *gin.Context, + project *iace.Project, + components []iace.Component, + classifications []iace.RegulatoryClassification, +) *iace.CompletenessContext { + projectID := project.ID + + hazards, _ := h.store.ListHazards(c.Request.Context(), projectID) + + var allAssessments []iace.RiskAssessment + var allMitigations []iace.Mitigation + for _, hazard := range hazards { + assessments, _ := h.store.ListAssessments(c.Request.Context(), hazard.ID) + allAssessments = append(allAssessments, assessments...) + + mitigations, _ := h.store.ListMitigations(c.Request.Context(), hazard.ID) + allMitigations = append(allMitigations, mitigations...) + } + + evidence, _ := h.store.ListEvidence(c.Request.Context(), projectID) + techFileSections, _ := h.store.ListTechFileSections(c.Request.Context(), projectID) + + hasAI := false + for _, comp := range components { + if comp.ComponentType == iace.ComponentTypeAIModel { + hasAI = true + break + } + } + + return &iace.CompletenessContext{ + Project: project, + Components: components, + Classifications: classifications, + Hazards: hazards, + Assessments: allAssessments, + Mitigations: allMitigations, + Evidence: evidence, + TechFileSections: techFileSections, + HasAI: hasAI, + } +} + +// componentTypeKeys extracts keys from a map[string]bool and returns them as a sorted slice. +func componentTypeKeys(m map[string]bool) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + // Sort for deterministic output + sortStrings(keys) + return keys +} + +// sortStrings sorts a slice of strings in place using a simple insertion sort. +func sortStrings(s []string) { + for i := 1; i < len(s); i++ { + for j := i; j > 0 && strings.Compare(s[j-1], s[j]) > 0; j-- { + s[j-1], s[j] = s[j], s[j-1] + } + } +} diff --git a/ai-compliance-sdk/internal/iace/classifier.go b/ai-compliance-sdk/internal/iace/classifier.go new file mode 100644 index 0000000..7e9c7dc --- /dev/null +++ b/ai-compliance-sdk/internal/iace/classifier.go @@ -0,0 +1,415 @@ +package iace + +import ( + "encoding/json" + "fmt" + "strings" +) + +// ============================================================================ +// Classifier Types +// ============================================================================ + +// ClassificationResult holds the output of a single regulatory classification. +type ClassificationResult struct { + Regulation RegulationType `json:"regulation"` + ClassificationResult string `json:"classification_result"` + RiskLevel string `json:"risk_level"` + Confidence float64 `json:"confidence"` + Reasoning string `json:"reasoning"` + Requirements []string `json:"requirements"` +} + +// ============================================================================ +// Classifier +// ============================================================================ + +// Classifier determines which EU regulations apply to a machine or product +// based on project metadata and component analysis. +type Classifier struct{} + +// NewClassifier creates a new Classifier instance. +func NewClassifier() *Classifier { return &Classifier{} } + +// ============================================================================ +// Public Methods +// ============================================================================ + +// ClassifyAll runs all four regulatory classifiers (AI Act, Machinery Regulation, +// CRA, NIS2) and returns the combined results. +func (c *Classifier) ClassifyAll(project *Project, components []Component) []ClassificationResult { + return []ClassificationResult{ + c.ClassifyAIAct(project, components), + c.ClassifyMachineryRegulation(project, components), + c.ClassifyCRA(project, components), + c.ClassifyNIS2(project, components), + } +} + +// ClassifyAIAct determines the AI Act classification based on whether the system +// contains AI components and their safety relevance. +// +// Classification logic: +// - Has safety-relevant AI component: "high_risk" +// - Has AI component (not safety-relevant): "limited_risk" +// - No AI components: "not_applicable" +func (c *Classifier) ClassifyAIAct(project *Project, components []Component) ClassificationResult { + result := ClassificationResult{ + Regulation: RegulationAIAct, + } + + hasAI := false + hasSafetyRelevantAI := false + var aiComponentNames []string + + for _, comp := range components { + if comp.ComponentType == ComponentTypeAIModel { + hasAI = true + aiComponentNames = append(aiComponentNames, comp.Name) + if comp.IsSafetyRelevant { + hasSafetyRelevantAI = true + } + } + } + + switch { + case hasSafetyRelevantAI: + result.ClassificationResult = "high_risk" + result.RiskLevel = "high" + result.Confidence = 0.9 + result.Reasoning = fmt.Sprintf( + "System contains safety-relevant AI component(s): %s. "+ + "Under EU AI Act Art. 6, AI systems that are safety components of products "+ + "covered by Union harmonisation legislation are classified as high-risk.", + strings.Join(aiComponentNames, ", "), + ) + result.Requirements = []string{ + "Risk management system (Art. 9)", + "Data governance and management (Art. 10)", + "Technical documentation (Art. 11)", + "Record-keeping / logging (Art. 12)", + "Transparency and information to deployers (Art. 13)", + "Human oversight measures (Art. 14)", + "Accuracy, robustness and cybersecurity (Art. 15)", + "Quality management system (Art. 17)", + "Conformity assessment before placing on market", + } + + case hasAI: + result.ClassificationResult = "limited_risk" + result.RiskLevel = "medium" + result.Confidence = 0.85 + result.Reasoning = fmt.Sprintf( + "System contains AI component(s): %s, but none are marked as safety-relevant. "+ + "Classified as limited-risk under EU AI Act with transparency obligations.", + strings.Join(aiComponentNames, ", "), + ) + result.Requirements = []string{ + "Transparency obligations (Art. 52)", + "AI literacy measures (Art. 4)", + "Technical documentation recommended", + } + + default: + result.ClassificationResult = "not_applicable" + result.RiskLevel = "none" + result.Confidence = 0.95 + result.Reasoning = "No AI model components found in the system. EU AI Act does not apply." + result.Requirements = nil + } + + return result +} + +// ClassifyMachineryRegulation determines the Machinery Regulation (EU) 2023/1230 +// classification based on CE marking intent and component analysis. +// +// Classification logic: +// - CE marking target set: "applicable" (standard machinery) +// - Has safety-relevant software/firmware: "annex_iii" (high-risk, Annex III machinery) +// - Otherwise: "standard" +func (c *Classifier) ClassifyMachineryRegulation(project *Project, components []Component) ClassificationResult { + result := ClassificationResult{ + Regulation: RegulationMachineryRegulation, + } + + hasCETarget := project.CEMarkingTarget != "" + hasSafetyRelevantSoftware := false + var safetySwNames []string + + for _, comp := range components { + if (comp.ComponentType == ComponentTypeSoftware || comp.ComponentType == ComponentTypeFirmware) && comp.IsSafetyRelevant { + hasSafetyRelevantSoftware = true + safetySwNames = append(safetySwNames, comp.Name) + } + } + + switch { + case hasSafetyRelevantSoftware: + result.ClassificationResult = "annex_iii" + result.RiskLevel = "high" + result.Confidence = 0.9 + result.Reasoning = fmt.Sprintf( + "Machine contains safety-relevant software/firmware component(s): %s. "+ + "Under Machinery Regulation (EU) 2023/1230 Annex III, machinery with safety-relevant "+ + "digital components requires third-party conformity assessment.", + strings.Join(safetySwNames, ", "), + ) + result.Requirements = []string{ + "Third-party conformity assessment (Annex III)", + "Essential health and safety requirements (Annex III EHSR)", + "Technical documentation per Annex IV", + "Risk assessment per ISO 12100", + "Software validation (IEC 62443 / IEC 61508)", + "EU Declaration of Conformity", + "CE marking", + } + + case hasCETarget: + result.ClassificationResult = "applicable" + result.RiskLevel = "medium" + result.Confidence = 0.85 + result.Reasoning = fmt.Sprintf( + "CE marking target is set (%s). Machinery Regulation (EU) 2023/1230 applies. "+ + "No safety-relevant software/firmware detected; standard conformity assessment path.", + project.CEMarkingTarget, + ) + result.Requirements = []string{ + "Essential health and safety requirements (EHSR)", + "Technical documentation per Annex IV", + "Risk assessment per ISO 12100", + "EU Declaration of Conformity", + "CE marking", + } + + default: + result.ClassificationResult = "standard" + result.RiskLevel = "low" + result.Confidence = 0.7 + result.Reasoning = "No CE marking target specified and no safety-relevant software/firmware detected. " + + "Standard machinery regulation requirements may still apply depending on product placement." + result.Requirements = []string{ + "Risk assessment recommended", + "Technical documentation recommended", + "Verify if CE marking is required for intended market", + } + } + + return result +} + +// ClassifyCRA determines the Cyber Resilience Act (CRA) classification based on +// whether the system contains networked components and their criticality. +// +// Classification logic: +// - Safety-relevant + networked: "class_ii" (highest CRA category) +// - Networked with critical component types: "class_i" +// - Networked (other): "default" (self-assessment) +// - No networked components: "not_applicable" +func (c *Classifier) ClassifyCRA(project *Project, components []Component) ClassificationResult { + result := ClassificationResult{ + Regulation: RegulationCRA, + } + + hasNetworked := false + hasSafetyRelevantNetworked := false + hasCriticalType := false + var networkedNames []string + + // Component types considered critical under CRA + criticalTypes := map[ComponentType]bool{ + ComponentTypeController: true, + ComponentTypeNetwork: true, + ComponentTypeSensor: true, + } + + for _, comp := range components { + if comp.IsNetworked { + hasNetworked = true + networkedNames = append(networkedNames, comp.Name) + + if comp.IsSafetyRelevant { + hasSafetyRelevantNetworked = true + } + if criticalTypes[comp.ComponentType] { + hasCriticalType = true + } + } + } + + switch { + case hasSafetyRelevantNetworked: + result.ClassificationResult = "class_ii" + result.RiskLevel = "high" + result.Confidence = 0.9 + result.Reasoning = fmt.Sprintf( + "System contains safety-relevant networked component(s): %s. "+ + "Under CRA, products with digital elements that are safety-relevant and networked "+ + "fall into Class II, requiring third-party conformity assessment.", + strings.Join(networkedNames, ", "), + ) + result.Requirements = []string{ + "Third-party conformity assessment", + "Vulnerability handling process", + "Security updates for product lifetime (min. 5 years)", + "SBOM (Software Bill of Materials)", + "Incident reporting to ENISA within 24h", + "Coordinated vulnerability disclosure", + "Secure by default configuration", + "Technical documentation with cybersecurity risk assessment", + } + + case hasCriticalType: + result.ClassificationResult = "class_i" + result.RiskLevel = "medium" + result.Confidence = 0.85 + result.Reasoning = fmt.Sprintf( + "System contains networked component(s) of critical type: %s. "+ + "Under CRA Class I, these products require conformity assessment via harmonised "+ + "standards or third-party assessment.", + strings.Join(networkedNames, ", "), + ) + result.Requirements = []string{ + "Conformity assessment (self or third-party with harmonised standards)", + "Vulnerability handling process", + "Security updates for product lifetime (min. 5 years)", + "SBOM (Software Bill of Materials)", + "Incident reporting to ENISA within 24h", + "Coordinated vulnerability disclosure", + "Secure by default configuration", + } + + case hasNetworked: + result.ClassificationResult = "default" + result.RiskLevel = "low" + result.Confidence = 0.85 + result.Reasoning = fmt.Sprintf( + "System contains networked component(s): %s. "+ + "CRA default category applies; self-assessment is sufficient.", + strings.Join(networkedNames, ", "), + ) + result.Requirements = []string{ + "Self-assessment conformity", + "Vulnerability handling process", + "Security updates for product lifetime (min. 5 years)", + "SBOM (Software Bill of Materials)", + "Incident reporting to ENISA within 24h", + } + + default: + result.ClassificationResult = "not_applicable" + result.RiskLevel = "none" + result.Confidence = 0.9 + result.Reasoning = "No networked components found. The Cyber Resilience Act applies to " + + "products with digital elements that have a network connection. Currently not applicable." + result.Requirements = nil + } + + return result +} + +// ClassifyNIS2 determines the NIS2 Directive classification based on project +// metadata indicating whether the manufacturer supplies critical infrastructure sectors. +// +// Classification logic: +// - Project metadata indicates KRITIS supplier: "indirect_obligation" +// - Otherwise: "not_applicable" +func (c *Classifier) ClassifyNIS2(project *Project, components []Component) ClassificationResult { + result := ClassificationResult{ + Regulation: RegulationNIS2, + } + + isKRITISSupplier := c.isKRITISSupplier(project) + + if isKRITISSupplier { + result.ClassificationResult = "indirect_obligation" + result.RiskLevel = "medium" + result.Confidence = 0.8 + result.Reasoning = "Project metadata indicates this product/system is supplied to clients " + + "in critical infrastructure sectors (KRITIS). Under NIS2, suppliers to essential and " + + "important entities have indirect obligations for supply chain security." + result.Requirements = []string{ + "Supply chain security measures", + "Incident notification support for customers", + "Cybersecurity risk management documentation", + "Security-by-design evidence", + "Contractual security requirements with KRITIS customers", + "Regular security assessments and audits", + } + } else { + result.ClassificationResult = "not_applicable" + result.RiskLevel = "none" + result.Confidence = 0.75 + result.Reasoning = "No indication in project metadata that this product is supplied to " + + "critical infrastructure (KRITIS) sectors. NIS2 indirect obligations do not currently apply. " + + "Re-evaluate if customer base changes." + result.Requirements = nil + } + + return result +} + +// ============================================================================ +// Helper Methods +// ============================================================================ + +// isKRITISSupplier checks project metadata for indicators that the manufacturer +// supplies critical infrastructure sectors. +func (c *Classifier) isKRITISSupplier(project *Project) bool { + if project.Metadata == nil { + return false + } + + var metadata map[string]interface{} + if err := json.Unmarshal(project.Metadata, &metadata); err != nil { + return false + } + + // Check for explicit KRITIS flag + if kritis, ok := metadata["kritis_supplier"]; ok { + if val, ok := kritis.(bool); ok && val { + return true + } + } + + // Check for critical sector clients + if sectors, ok := metadata["critical_sector_clients"]; ok { + switch v := sectors.(type) { + case []interface{}: + return len(v) > 0 + case bool: + return v + } + } + + // Check for NIS2-relevant target sectors + if targetSectors, ok := metadata["target_sectors"]; ok { + kriticalSectors := map[string]bool{ + "energy": true, + "transport": true, + "banking": true, + "health": true, + "water": true, + "digital_infra": true, + "public_admin": true, + "space": true, + "food": true, + "manufacturing": true, + "waste_management": true, + "postal": true, + "chemicals": true, + } + + if sectorList, ok := targetSectors.([]interface{}); ok { + for _, s := range sectorList { + if str, ok := s.(string); ok { + if kriticalSectors[strings.ToLower(str)] { + return true + } + } + } + } + } + + return false +} diff --git a/ai-compliance-sdk/internal/iace/classifier_test.go b/ai-compliance-sdk/internal/iace/classifier_test.go new file mode 100644 index 0000000..9982ad0 --- /dev/null +++ b/ai-compliance-sdk/internal/iace/classifier_test.go @@ -0,0 +1,553 @@ +package iace + +import ( + "encoding/json" + "testing" +) + +func TestClassifyAIAct(t *testing.T) { + c := NewClassifier() + + tests := []struct { + name string + project *Project + components []Component + wantResult string + wantRiskLevel string + wantReqsEmpty bool + wantConfidence float64 + }{ + { + name: "no AI components returns not_applicable", + project: &Project{MachineName: "TestMachine"}, + components: []Component{ + {Name: "PLC", ComponentType: ComponentTypeSoftware}, + {Name: "Ethernet", ComponentType: ComponentTypeNetwork}, + }, + wantResult: "not_applicable", + wantRiskLevel: "none", + wantReqsEmpty: true, + wantConfidence: 0.95, + }, + { + name: "no components at all returns not_applicable", + project: &Project{MachineName: "EmptyMachine"}, + components: []Component{}, + wantResult: "not_applicable", + wantRiskLevel: "none", + wantReqsEmpty: true, + wantConfidence: 0.95, + }, + { + name: "AI model not safety relevant returns limited_risk", + project: &Project{MachineName: "VisionMachine"}, + components: []Component{ + {Name: "QualityChecker", ComponentType: ComponentTypeAIModel, IsSafetyRelevant: false}, + }, + wantResult: "limited_risk", + wantRiskLevel: "medium", + wantReqsEmpty: false, + wantConfidence: 0.85, + }, + { + name: "safety-relevant AI model returns high_risk", + project: &Project{MachineName: "SafetyMachine"}, + components: []Component{ + {Name: "SafetyAI", ComponentType: ComponentTypeAIModel, IsSafetyRelevant: true}, + }, + wantResult: "high_risk", + wantRiskLevel: "high", + wantReqsEmpty: false, + wantConfidence: 0.9, + }, + { + name: "mixed components with safety-relevant AI returns high_risk", + project: &Project{MachineName: "ComplexMachine"}, + components: []Component{ + {Name: "PLC", ComponentType: ComponentTypeSoftware}, + {Name: "BasicAI", ComponentType: ComponentTypeAIModel, IsSafetyRelevant: false}, + {Name: "SafetyAI", ComponentType: ComponentTypeAIModel, IsSafetyRelevant: true}, + {Name: "Cam", ComponentType: ComponentTypeSensor}, + }, + wantResult: "high_risk", + wantRiskLevel: "high", + wantReqsEmpty: false, + wantConfidence: 0.9, + }, + { + name: "non-AI safety-relevant component does not trigger AI act", + project: &Project{MachineName: "SafetySoftwareMachine"}, + components: []Component{ + {Name: "SafetyPLC", ComponentType: ComponentTypeSoftware, IsSafetyRelevant: true}, + }, + wantResult: "not_applicable", + wantRiskLevel: "none", + wantReqsEmpty: true, + wantConfidence: 0.95, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := c.ClassifyAIAct(tt.project, tt.components) + + if result.Regulation != RegulationAIAct { + t.Errorf("Regulation = %q, want %q", result.Regulation, RegulationAIAct) + } + if result.ClassificationResult != tt.wantResult { + t.Errorf("ClassificationResult = %q, want %q", result.ClassificationResult, tt.wantResult) + } + if result.RiskLevel != tt.wantRiskLevel { + t.Errorf("RiskLevel = %q, want %q", result.RiskLevel, tt.wantRiskLevel) + } + if (result.Requirements == nil || len(result.Requirements) == 0) != tt.wantReqsEmpty { + t.Errorf("Requirements empty = %v, want %v", result.Requirements == nil || len(result.Requirements) == 0, tt.wantReqsEmpty) + } + if result.Confidence != tt.wantConfidence { + t.Errorf("Confidence = %f, want %f", result.Confidence, tt.wantConfidence) + } + if result.Reasoning == "" { + t.Error("Reasoning should not be empty") + } + }) + } +} + +func TestClassifyMachineryRegulation(t *testing.T) { + c := NewClassifier() + + tests := []struct { + name string + project *Project + components []Component + wantResult string + wantRiskLevel string + wantReqsLen int + }{ + { + name: "no CE target and no safety SW returns standard", + project: &Project{MachineName: "BasicMachine", CEMarkingTarget: ""}, + components: []Component{{Name: "App", ComponentType: ComponentTypeSoftware}}, + wantResult: "standard", + wantRiskLevel: "low", + wantReqsLen: 3, + }, + { + name: "CE target set returns applicable", + project: &Project{MachineName: "CEMachine", CEMarkingTarget: "2023/1230"}, + components: []Component{{Name: "App", ComponentType: ComponentTypeSoftware}}, + wantResult: "applicable", + wantRiskLevel: "medium", + wantReqsLen: 5, + }, + { + name: "safety-relevant software overrides CE target to annex_iii", + project: &Project{MachineName: "SafetyMachine", CEMarkingTarget: "2023/1230"}, + components: []Component{{Name: "SafetyPLC", ComponentType: ComponentTypeSoftware, IsSafetyRelevant: true}}, + wantResult: "annex_iii", + wantRiskLevel: "high", + wantReqsLen: 7, + }, + { + name: "safety-relevant firmware returns annex_iii", + project: &Project{MachineName: "FirmwareMachine", CEMarkingTarget: ""}, + components: []Component{{Name: "SafetyFW", ComponentType: ComponentTypeFirmware, IsSafetyRelevant: true}}, + wantResult: "annex_iii", + wantRiskLevel: "high", + wantReqsLen: 7, + }, + { + name: "safety-relevant non-SW component does not trigger annex_iii", + project: &Project{MachineName: "SensorMachine", CEMarkingTarget: ""}, + components: []Component{ + {Name: "SafetySensor", ComponentType: ComponentTypeSensor, IsSafetyRelevant: true}, + }, + wantResult: "standard", + wantRiskLevel: "low", + wantReqsLen: 3, + }, + { + name: "AI model safety-relevant does not trigger annex_iii (not software/firmware type)", + project: &Project{MachineName: "AIModelMachine", CEMarkingTarget: ""}, + components: []Component{ + {Name: "SafetyAI", ComponentType: ComponentTypeAIModel, IsSafetyRelevant: true}, + }, + wantResult: "standard", + wantRiskLevel: "low", + wantReqsLen: 3, + }, + { + name: "empty components with no CE target returns standard", + project: &Project{MachineName: "EmptyMachine"}, + components: []Component{}, + wantResult: "standard", + wantRiskLevel: "low", + wantReqsLen: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := c.ClassifyMachineryRegulation(tt.project, tt.components) + + if result.Regulation != RegulationMachineryRegulation { + t.Errorf("Regulation = %q, want %q", result.Regulation, RegulationMachineryRegulation) + } + if result.ClassificationResult != tt.wantResult { + t.Errorf("ClassificationResult = %q, want %q", result.ClassificationResult, tt.wantResult) + } + if result.RiskLevel != tt.wantRiskLevel { + t.Errorf("RiskLevel = %q, want %q", result.RiskLevel, tt.wantRiskLevel) + } + if len(result.Requirements) != tt.wantReqsLen { + t.Errorf("Requirements length = %d, want %d", len(result.Requirements), tt.wantReqsLen) + } + if result.Reasoning == "" { + t.Error("Reasoning should not be empty") + } + }) + } +} + +func TestClassifyCRA(t *testing.T) { + c := NewClassifier() + + tests := []struct { + name string + project *Project + components []Component + wantResult string + wantRiskLevel string + wantReqsNil bool + }{ + { + name: "no networked components returns not_applicable", + project: &Project{MachineName: "OfflineMachine"}, + components: []Component{{Name: "PLC", ComponentType: ComponentTypeSoftware, IsNetworked: false}}, + wantResult: "not_applicable", + wantRiskLevel: "none", + wantReqsNil: true, + }, + { + name: "empty components returns not_applicable", + project: &Project{MachineName: "EmptyMachine"}, + components: []Component{}, + wantResult: "not_applicable", + wantRiskLevel: "none", + wantReqsNil: true, + }, + { + name: "networked generic software returns default", + project: &Project{MachineName: "GenericNetworkedMachine"}, + components: []Component{ + {Name: "App", ComponentType: ComponentTypeSoftware, IsNetworked: true}, + }, + wantResult: "default", + wantRiskLevel: "low", + wantReqsNil: false, + }, + { + name: "networked controller returns class_i", + project: &Project{MachineName: "ControllerMachine"}, + components: []Component{ + {Name: "MainPLC", ComponentType: ComponentTypeController, IsNetworked: true}, + }, + wantResult: "class_i", + wantRiskLevel: "medium", + wantReqsNil: false, + }, + { + name: "networked network component returns class_i", + project: &Project{MachineName: "NetworkMachine"}, + components: []Component{ + {Name: "Switch", ComponentType: ComponentTypeNetwork, IsNetworked: true}, + }, + wantResult: "class_i", + wantRiskLevel: "medium", + wantReqsNil: false, + }, + { + name: "networked sensor returns class_i", + project: &Project{MachineName: "SensorMachine"}, + components: []Component{ + {Name: "IoTSensor", ComponentType: ComponentTypeSensor, IsNetworked: true}, + }, + wantResult: "class_i", + wantRiskLevel: "medium", + wantReqsNil: false, + }, + { + name: "safety-relevant networked component returns class_ii", + project: &Project{MachineName: "SafetyNetworkedMachine"}, + components: []Component{ + {Name: "SafetyNet", ComponentType: ComponentTypeSoftware, IsNetworked: true, IsSafetyRelevant: true}, + }, + wantResult: "class_ii", + wantRiskLevel: "high", + wantReqsNil: false, + }, + { + name: "safety-relevant overrides critical type", + project: &Project{MachineName: "MixedMachine"}, + components: []Component{ + {Name: "PLC", ComponentType: ComponentTypeController, IsNetworked: true, IsSafetyRelevant: true}, + }, + wantResult: "class_ii", + wantRiskLevel: "high", + wantReqsNil: false, + }, + { + name: "non-networked critical type is not_applicable", + project: &Project{MachineName: "OfflineControllerMachine"}, + components: []Component{ + {Name: "PLC", ComponentType: ComponentTypeController, IsNetworked: false}, + }, + wantResult: "not_applicable", + wantRiskLevel: "none", + wantReqsNil: true, + }, + { + name: "HMI networked but not critical type returns default", + project: &Project{MachineName: "HMIMachine"}, + components: []Component{ + {Name: "Panel", ComponentType: ComponentTypeHMI, IsNetworked: true}, + }, + wantResult: "default", + wantRiskLevel: "low", + wantReqsNil: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := c.ClassifyCRA(tt.project, tt.components) + + if result.Regulation != RegulationCRA { + t.Errorf("Regulation = %q, want %q", result.Regulation, RegulationCRA) + } + if result.ClassificationResult != tt.wantResult { + t.Errorf("ClassificationResult = %q, want %q", result.ClassificationResult, tt.wantResult) + } + if result.RiskLevel != tt.wantRiskLevel { + t.Errorf("RiskLevel = %q, want %q", result.RiskLevel, tt.wantRiskLevel) + } + if (result.Requirements == nil) != tt.wantReqsNil { + t.Errorf("Requirements nil = %v, want %v", result.Requirements == nil, tt.wantReqsNil) + } + if result.Reasoning == "" { + t.Error("Reasoning should not be empty") + } + }) + } +} + +func TestClassifyNIS2(t *testing.T) { + c := NewClassifier() + + tests := []struct { + name string + metadata json.RawMessage + wantResult string + }{ + { + name: "nil metadata returns not_applicable", + metadata: nil, + wantResult: "not_applicable", + }, + { + name: "empty JSON object returns not_applicable", + metadata: json.RawMessage(`{}`), + wantResult: "not_applicable", + }, + { + name: "invalid JSON returns not_applicable", + metadata: json.RawMessage(`not-json`), + wantResult: "not_applicable", + }, + { + name: "kritis_supplier true returns indirect_obligation", + metadata: json.RawMessage(`{"kritis_supplier": true}`), + wantResult: "indirect_obligation", + }, + { + name: "kritis_supplier false returns not_applicable", + metadata: json.RawMessage(`{"kritis_supplier": false}`), + wantResult: "not_applicable", + }, + { + name: "critical_sector_clients non-empty array returns indirect_obligation", + metadata: json.RawMessage(`{"critical_sector_clients": ["energy"]}`), + wantResult: "indirect_obligation", + }, + { + name: "critical_sector_clients empty array returns not_applicable", + metadata: json.RawMessage(`{"critical_sector_clients": []}`), + wantResult: "not_applicable", + }, + { + name: "critical_sector_clients bool true returns indirect_obligation", + metadata: json.RawMessage(`{"critical_sector_clients": true}`), + wantResult: "indirect_obligation", + }, + { + name: "critical_sector_clients bool false returns not_applicable", + metadata: json.RawMessage(`{"critical_sector_clients": false}`), + wantResult: "not_applicable", + }, + { + name: "target_sectors with critical sector returns indirect_obligation", + metadata: json.RawMessage(`{"target_sectors": ["health"]}`), + wantResult: "indirect_obligation", + }, + { + name: "target_sectors energy returns indirect_obligation", + metadata: json.RawMessage(`{"target_sectors": ["energy"]}`), + wantResult: "indirect_obligation", + }, + { + name: "target_sectors transport returns indirect_obligation", + metadata: json.RawMessage(`{"target_sectors": ["transport"]}`), + wantResult: "indirect_obligation", + }, + { + name: "target_sectors banking returns indirect_obligation", + metadata: json.RawMessage(`{"target_sectors": ["banking"]}`), + wantResult: "indirect_obligation", + }, + { + name: "target_sectors water returns indirect_obligation", + metadata: json.RawMessage(`{"target_sectors": ["water"]}`), + wantResult: "indirect_obligation", + }, + { + name: "target_sectors digital_infra returns indirect_obligation", + metadata: json.RawMessage(`{"target_sectors": ["digital_infra"]}`), + wantResult: "indirect_obligation", + }, + { + name: "target_sectors non-critical sector returns not_applicable", + metadata: json.RawMessage(`{"target_sectors": ["retail"]}`), + wantResult: "not_applicable", + }, + { + name: "target_sectors empty array returns not_applicable", + metadata: json.RawMessage(`{"target_sectors": []}`), + wantResult: "not_applicable", + }, + { + name: "target_sectors case insensitive match", + metadata: json.RawMessage(`{"target_sectors": ["Health"]}`), + wantResult: "indirect_obligation", + }, + { + name: "kritis_supplier takes precedence over target_sectors", + metadata: json.RawMessage(`{"kritis_supplier": true, "target_sectors": ["retail"]}`), + wantResult: "indirect_obligation", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + project := &Project{ + MachineName: "TestMachine", + Metadata: tt.metadata, + } + + result := c.ClassifyNIS2(project, nil) + + if result.Regulation != RegulationNIS2 { + t.Errorf("Regulation = %q, want %q", result.Regulation, RegulationNIS2) + } + if result.ClassificationResult != tt.wantResult { + t.Errorf("ClassificationResult = %q, want %q", result.ClassificationResult, tt.wantResult) + } + if result.Reasoning == "" { + t.Error("Reasoning should not be empty") + } + if tt.wantResult == "indirect_obligation" { + if result.RiskLevel != "medium" { + t.Errorf("RiskLevel = %q, want %q", result.RiskLevel, "medium") + } + if result.Requirements == nil || len(result.Requirements) == 0 { + t.Error("Requirements should not be empty for indirect_obligation") + } + } else { + if result.RiskLevel != "none" { + t.Errorf("RiskLevel = %q, want %q", result.RiskLevel, "none") + } + if result.Requirements != nil { + t.Errorf("Requirements should be nil for not_applicable, got %v", result.Requirements) + } + } + }) + } +} + +func TestClassifyAll(t *testing.T) { + c := NewClassifier() + + tests := []struct { + name string + project *Project + components []Component + }{ + { + name: "returns exactly 4 results for empty project", + project: &Project{MachineName: "TestMachine"}, + components: []Component{}, + }, + { + name: "returns exactly 4 results for complex project", + project: &Project{MachineName: "ComplexMachine", CEMarkingTarget: "2023/1230", Metadata: json.RawMessage(`{"kritis_supplier": true}`)}, + components: []Component{ + {Name: "SafetyAI", ComponentType: ComponentTypeAIModel, IsSafetyRelevant: true, IsNetworked: true}, + {Name: "PLC", ComponentType: ComponentTypeController, IsNetworked: true}, + {Name: "SafetyFW", ComponentType: ComponentTypeFirmware, IsSafetyRelevant: true}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + results := c.ClassifyAll(tt.project, tt.components) + + if len(results) != 4 { + t.Fatalf("ClassifyAll returned %d results, want 4", len(results)) + } + + expectedRegulations := map[RegulationType]bool{ + RegulationAIAct: false, + RegulationMachineryRegulation: false, + RegulationCRA: false, + RegulationNIS2: false, + } + + for _, r := range results { + if _, ok := expectedRegulations[r.Regulation]; !ok { + t.Errorf("unexpected regulation %q in results", r.Regulation) + } + expectedRegulations[r.Regulation] = true + } + + for reg, found := range expectedRegulations { + if !found { + t.Errorf("missing regulation %q in results", reg) + } + } + + // Verify order: AI Act, Machinery, CRA, NIS2 + if results[0].Regulation != RegulationAIAct { + t.Errorf("results[0].Regulation = %q, want %q", results[0].Regulation, RegulationAIAct) + } + if results[1].Regulation != RegulationMachineryRegulation { + t.Errorf("results[1].Regulation = %q, want %q", results[1].Regulation, RegulationMachineryRegulation) + } + if results[2].Regulation != RegulationCRA { + t.Errorf("results[2].Regulation = %q, want %q", results[2].Regulation, RegulationCRA) + } + if results[3].Regulation != RegulationNIS2 { + t.Errorf("results[3].Regulation = %q, want %q", results[3].Regulation, RegulationNIS2) + } + }) + } +} diff --git a/ai-compliance-sdk/internal/iace/completeness.go b/ai-compliance-sdk/internal/iace/completeness.go new file mode 100644 index 0000000..0fa5463 --- /dev/null +++ b/ai-compliance-sdk/internal/iace/completeness.go @@ -0,0 +1,485 @@ +package iace + +import ( + "encoding/json" + "fmt" +) + +// ============================================================================ +// Completeness Types +// ============================================================================ + +// GateDefinition describes a single completeness gate with its check function. +type GateDefinition struct { + ID string + Category string // onboarding, classification, hazard_risk, evidence, tech_file + Label string + Required bool + Recommended bool + CheckFunc func(ctx *CompletenessContext) bool +} + +// CompletenessContext provides all project data needed to evaluate completeness gates. +type CompletenessContext struct { + Project *Project + Components []Component + Classifications []RegulatoryClassification + Hazards []Hazard + Assessments []RiskAssessment + Mitigations []Mitigation + Evidence []Evidence + TechFileSections []TechFileSection + HasAI bool +} + +// CompletenessResult contains the aggregated result of all gate checks. +type CompletenessResult struct { + Score float64 `json:"score"` + Gates []CompletenessGate `json:"gates"` + PassedRequired int `json:"passed_required"` + TotalRequired int `json:"total_required"` + PassedRecommended int `json:"passed_recommended"` + TotalRecommended int `json:"total_recommended"` + CanExport bool `json:"can_export"` +} + +// ============================================================================ +// Gate Definitions (25 CE Completeness Gates) +// ============================================================================ + +// buildGateDefinitions returns the full set of 25 CE completeness gate definitions. +func buildGateDefinitions() []GateDefinition { + return []GateDefinition{ + // ===================================================================== + // Onboarding Gates (G01-G08) - Required + // ===================================================================== + { + ID: "G01", + Category: "onboarding", + Label: "Machine identity set", + Required: true, + CheckFunc: func(ctx *CompletenessContext) bool { + return ctx.Project != nil && ctx.Project.MachineName != "" + }, + }, + { + ID: "G02", + Category: "onboarding", + Label: "Intended use described", + Required: true, + CheckFunc: func(ctx *CompletenessContext) bool { + return ctx.Project != nil && ctx.Project.Description != "" + }, + }, + { + ID: "G03", + Category: "onboarding", + Label: "Operating limits defined", + Required: true, + CheckFunc: func(ctx *CompletenessContext) bool { + return ctx.Project != nil && hasMetadataKey(ctx.Project.Metadata, "operating_limits") + }, + }, + { + ID: "G04", + Category: "onboarding", + Label: "Foreseeable misuse documented", + Required: true, + CheckFunc: func(ctx *CompletenessContext) bool { + return ctx.Project != nil && hasMetadataKey(ctx.Project.Metadata, "foreseeable_misuse") + }, + }, + { + ID: "G05", + Category: "onboarding", + Label: "Component tree exists", + Required: true, + CheckFunc: func(ctx *CompletenessContext) bool { + return len(ctx.Components) > 0 + }, + }, + { + ID: "G06", + Category: "onboarding", + Label: "AI classification done (if applicable)", + Required: true, + CheckFunc: func(ctx *CompletenessContext) bool { + // If no AI present, this gate passes automatically + if !ctx.HasAI { + return true + } + return hasClassificationFor(ctx.Classifications, RegulationAIAct) + }, + }, + { + ID: "G07", + Category: "onboarding", + Label: "Safety relevance marked", + Required: true, + CheckFunc: func(ctx *CompletenessContext) bool { + for _, comp := range ctx.Components { + if comp.IsSafetyRelevant { + return true + } + } + return false + }, + }, + { + ID: "G08", + Category: "onboarding", + Label: "Manufacturer info present", + Required: true, + CheckFunc: func(ctx *CompletenessContext) bool { + return ctx.Project != nil && ctx.Project.Manufacturer != "" + }, + }, + + // ===================================================================== + // Classification Gates (G10-G13) - Required + // ===================================================================== + { + ID: "G10", + Category: "classification", + Label: "AI Act classification complete", + Required: true, + CheckFunc: func(ctx *CompletenessContext) bool { + return hasClassificationFor(ctx.Classifications, RegulationAIAct) + }, + }, + { + ID: "G11", + Category: "classification", + Label: "Machinery Regulation check done", + Required: true, + CheckFunc: func(ctx *CompletenessContext) bool { + return hasClassificationFor(ctx.Classifications, RegulationMachineryRegulation) + }, + }, + { + ID: "G12", + Category: "classification", + Label: "NIS2 check done", + Required: true, + CheckFunc: func(ctx *CompletenessContext) bool { + return hasClassificationFor(ctx.Classifications, RegulationNIS2) + }, + }, + { + ID: "G13", + Category: "classification", + Label: "CRA check done", + Required: true, + CheckFunc: func(ctx *CompletenessContext) bool { + return hasClassificationFor(ctx.Classifications, RegulationCRA) + }, + }, + + // ===================================================================== + // Hazard & Risk Gates (G20-G24) - Required + // ===================================================================== + { + ID: "G20", + Category: "hazard_risk", + Label: "Hazards identified", + Required: true, + CheckFunc: func(ctx *CompletenessContext) bool { + return len(ctx.Hazards) > 0 + }, + }, + { + ID: "G21", + Category: "hazard_risk", + Label: "All hazards assessed", + Required: true, + CheckFunc: func(ctx *CompletenessContext) bool { + if len(ctx.Hazards) == 0 { + return false + } + // Build a set of hazard IDs that have at least one assessment + assessedHazards := make(map[string]bool) + for _, a := range ctx.Assessments { + assessedHazards[a.HazardID.String()] = true + } + for _, h := range ctx.Hazards { + if !assessedHazards[h.ID.String()] { + return false + } + } + return true + }, + }, + { + ID: "G22", + Category: "hazard_risk", + Label: "Critical/High risks mitigated", + Required: true, + CheckFunc: func(ctx *CompletenessContext) bool { + // Find all hazards that have a critical or high assessment + criticalHighHazards := make(map[string]bool) + for _, a := range ctx.Assessments { + if a.RiskLevel == RiskLevelCritical || a.RiskLevel == RiskLevelHigh { + criticalHighHazards[a.HazardID.String()] = true + } + } + + // If no critical/high hazards, gate passes + if len(criticalHighHazards) == 0 { + return true + } + + // Check that every critical/high hazard has at least one mitigation + mitigatedHazards := make(map[string]bool) + for _, m := range ctx.Mitigations { + mitigatedHazards[m.HazardID.String()] = true + } + + for hazardID := range criticalHighHazards { + if !mitigatedHazards[hazardID] { + return false + } + } + return true + }, + }, + { + ID: "G23", + Category: "hazard_risk", + Label: "Mitigations verified", + Required: true, + CheckFunc: func(ctx *CompletenessContext) bool { + // All mitigations with status "implemented" must also be verified + for _, m := range ctx.Mitigations { + if m.Status == MitigationStatusImplemented { + // Implemented but not yet verified -> gate fails + return false + } + } + // All mitigations are either planned, verified, or rejected + return true + }, + }, + { + ID: "G24", + Category: "hazard_risk", + Label: "Residual risk accepted", + Required: true, + CheckFunc: func(ctx *CompletenessContext) bool { + if len(ctx.Assessments) == 0 { + return false + } + for _, a := range ctx.Assessments { + if !a.IsAcceptable && a.RiskLevel != RiskLevelLow && a.RiskLevel != RiskLevelNegligible { + return false + } + } + return true + }, + }, + + // ===================================================================== + // Evidence Gate (G30) - Recommended + // ===================================================================== + { + ID: "G30", + Category: "evidence", + Label: "Test evidence linked", + Required: false, + Recommended: true, + CheckFunc: func(ctx *CompletenessContext) bool { + return len(ctx.Evidence) > 0 + }, + }, + + // ===================================================================== + // Tech File Gates (G40-G42) - Required for completion + // ===================================================================== + { + ID: "G40", + Category: "tech_file", + Label: "Risk assessment report generated", + Required: true, + CheckFunc: func(ctx *CompletenessContext) bool { + return hasTechFileSection(ctx.TechFileSections, "risk_assessment_report") + }, + }, + { + ID: "G41", + Category: "tech_file", + Label: "Hazard log generated", + Required: true, + CheckFunc: func(ctx *CompletenessContext) bool { + return hasTechFileSection(ctx.TechFileSections, "hazard_log_combined") + }, + }, + { + ID: "G42", + Category: "tech_file", + Label: "AI documents present (if applicable)", + Required: true, + CheckFunc: func(ctx *CompletenessContext) bool { + // If no AI present, this gate passes automatically + if !ctx.HasAI { + return true + } + hasIntendedPurpose := hasTechFileSection(ctx.TechFileSections, "ai_intended_purpose") + hasModelDescription := hasTechFileSection(ctx.TechFileSections, "ai_model_description") + return hasIntendedPurpose && hasModelDescription + }, + }, + } +} + +// ============================================================================ +// CompletenessChecker +// ============================================================================ + +// CompletenessChecker evaluates the 25 CE completeness gates for an IACE project. +type CompletenessChecker struct{} + +// NewCompletenessChecker creates a new CompletenessChecker instance. +func NewCompletenessChecker() *CompletenessChecker { return &CompletenessChecker{} } + +// Check evaluates all 25 completeness gates against the provided context and +// returns an aggregated result with a weighted score. +// +// Scoring formula: +// +// score = (passed_required / total_required) * 80 +// + (passed_recommended / total_recommended) * 15 +// + (passed_optional / total_optional) * 5 +// +// Optional gates are those that are neither required nor recommended. +// CanExport is true only when all required gates have passed. +func (c *CompletenessChecker) Check(ctx *CompletenessContext) CompletenessResult { + gates := buildGateDefinitions() + + var result CompletenessResult + var passedOptional, totalOptional int + + for _, gate := range gates { + passed := gate.CheckFunc(ctx) + + details := "" + if !passed { + details = fmt.Sprintf("Gate %s not satisfied: %s", gate.ID, gate.Label) + } + + result.Gates = append(result.Gates, CompletenessGate{ + ID: gate.ID, + Category: gate.Category, + Label: gate.Label, + Required: gate.Required, + Passed: passed, + Details: details, + }) + + switch { + case gate.Required: + result.TotalRequired++ + if passed { + result.PassedRequired++ + } + case gate.Recommended: + result.TotalRecommended++ + if passed { + result.PassedRecommended++ + } + default: + // Optional gate (neither required nor recommended) + totalOptional++ + if passed { + passedOptional++ + } + } + } + + // Calculate weighted score + result.Score = calculateWeightedScore( + result.PassedRequired, result.TotalRequired, + result.PassedRecommended, result.TotalRecommended, + passedOptional, totalOptional, + ) + + // CanExport is true only when ALL required gates pass + result.CanExport = result.PassedRequired == result.TotalRequired + + return result +} + +// ============================================================================ +// Helper Functions +// ============================================================================ + +// hasMetadataKey checks whether a JSON metadata blob contains a non-empty value +// for the given key. +func hasMetadataKey(metadata json.RawMessage, key string) bool { + if metadata == nil { + return false + } + + var m map[string]interface{} + if err := json.Unmarshal(metadata, &m); err != nil { + return false + } + + val, exists := m[key] + if !exists { + return false + } + + // Check that the value is not empty/nil + switch v := val.(type) { + case string: + return v != "" + case nil: + return false + default: + return true + } +} + +// hasClassificationFor checks whether a classification exists for the given regulation type. +func hasClassificationFor(classifications []RegulatoryClassification, regulation RegulationType) bool { + for _, c := range classifications { + if c.Regulation == regulation { + return true + } + } + return false +} + +// hasTechFileSection checks whether a tech file section of the given type exists. +func hasTechFileSection(sections []TechFileSection, sectionType string) bool { + for _, s := range sections { + if s.SectionType == sectionType { + return true + } + } + return false +} + +// calculateWeightedScore computes the weighted completeness score (0-100). +// +// Formula: +// +// score = (passedRequired/totalRequired) * 80 +// + (passedRecommended/totalRecommended) * 15 +// + (passedOptional/totalOptional) * 5 +// +// If any denominator is 0, that component contributes 0 to the score. +func calculateWeightedScore(passedRequired, totalRequired, passedRecommended, totalRecommended, passedOptional, totalOptional int) float64 { + var score float64 + + if totalRequired > 0 { + score += (float64(passedRequired) / float64(totalRequired)) * 80 + } + if totalRecommended > 0 { + score += (float64(passedRecommended) / float64(totalRecommended)) * 15 + } + if totalOptional > 0 { + score += (float64(passedOptional) / float64(totalOptional)) * 5 + } + + return score +} diff --git a/ai-compliance-sdk/internal/iace/completeness_test.go b/ai-compliance-sdk/internal/iace/completeness_test.go new file mode 100644 index 0000000..5b974e4 --- /dev/null +++ b/ai-compliance-sdk/internal/iace/completeness_test.go @@ -0,0 +1,678 @@ +package iace + +import ( + "encoding/json" + "math" + "testing" + + "github.com/google/uuid" +) + +// helper to build metadata with the given keys set to non-empty string values. +func metadataWith(keys ...string) json.RawMessage { + m := make(map[string]interface{}) + for _, k := range keys { + m[k] = "defined" + } + data, _ := json.Marshal(m) + return data +} + +func TestCompletenessCheck_EmptyContext(t *testing.T) { + checker := NewCompletenessChecker() + + ctx := &CompletenessContext{ + Project: nil, + } + + result := checker.Check(ctx) + + if result.CanExport { + t.Error("CanExport should be false for empty context") + } + // With nil project, most gates fail. However, some auto-pass: + // G06 (AI classification): auto-passes when HasAI=false + // G22 (critical/high mitigated): auto-passes when no critical/high assessments exist + // G23 (mitigations verified): auto-passes when no mitigations with status "implemented" + // G42 (AI documents): auto-passes when HasAI=false + // That gives 4 required gates passing even with empty context. + if result.PassedRequired != 4 { + t.Errorf("PassedRequired = %d, want 4 (G06, G22, G23, G42 auto-pass)", result.PassedRequired) + } + // Score should be low: 4/20 * 80 = 16 + if result.Score > 20 { + t.Errorf("Score = %f, expected <= 20 for empty context", result.Score) + } + if len(result.Gates) == 0 { + t.Error("Gates should not be empty") + } +} + +func TestCompletenessCheck_MinimalValidProject(t *testing.T) { + checker := NewCompletenessChecker() + + projectID := uuid.New() + hazardID := uuid.New() + componentID := uuid.New() + + ctx := &CompletenessContext{ + Project: &Project{ + ID: projectID, + MachineName: "TestMachine", + Description: "A test machine for unit testing", + Manufacturer: "TestCorp", + CEMarkingTarget: "2023/1230", + Metadata: metadataWith("operating_limits", "foreseeable_misuse"), + }, + Components: []Component{ + {ID: componentID, Name: "SafetyPLC", ComponentType: ComponentTypeSoftware, IsSafetyRelevant: true}, + }, + Classifications: []RegulatoryClassification{ + {Regulation: RegulationAIAct}, + {Regulation: RegulationMachineryRegulation}, + {Regulation: RegulationNIS2}, + {Regulation: RegulationCRA}, + }, + Hazards: []Hazard{ + {ID: hazardID, ProjectID: projectID, ComponentID: componentID, Name: "TestHazard", Category: "test"}, + }, + Assessments: []RiskAssessment{ + {ID: uuid.New(), HazardID: hazardID, RiskLevel: RiskLevelLow, IsAcceptable: true}, + }, + Mitigations: []Mitigation{ + {ID: uuid.New(), HazardID: hazardID, Status: MitigationStatusVerified}, + }, + Evidence: []Evidence{ + {ID: uuid.New(), ProjectID: projectID, FileName: "test.pdf"}, + }, + TechFileSections: []TechFileSection{ + {ID: uuid.New(), ProjectID: projectID, SectionType: "risk_assessment_report"}, + {ID: uuid.New(), ProjectID: projectID, SectionType: "hazard_log_combined"}, + }, + HasAI: false, + } + + result := checker.Check(ctx) + + if !result.CanExport { + t.Error("CanExport should be true for fully valid project") + for _, g := range result.Gates { + if g.Required && !g.Passed { + t.Errorf(" Required gate %s (%s) not passed: %s", g.ID, g.Label, g.Details) + } + } + } + if result.PassedRequired != result.TotalRequired { + t.Errorf("PassedRequired = %d, TotalRequired = %d, want all passed", result.PassedRequired, result.TotalRequired) + } + // Score should be at least 80 (all required) + 15 (evidence recommended) = 95 + if result.Score < 80 { + t.Errorf("Score = %f, expected >= 80 for fully valid project", result.Score) + } +} + +func TestCompletenessCheck_PartialRequiredGates(t *testing.T) { + checker := NewCompletenessChecker() + + // Provide only some required data: machine name, manufacturer, description, one component, one safety-relevant. + // Missing: operating_limits, foreseeable_misuse, classifications, hazards, assessments, tech files. + ctx := &CompletenessContext{ + Project: &Project{ + MachineName: "PartialMachine", + Description: "Some description", + Manufacturer: "TestCorp", + }, + Components: []Component{ + {Name: "Sensor", ComponentType: ComponentTypeSensor, IsSafetyRelevant: true}, + }, + HasAI: false, + } + + result := checker.Check(ctx) + + if result.CanExport { + t.Error("CanExport should be false when not all required gates pass") + } + if result.PassedRequired == 0 { + t.Error("Some required gates should pass (G01, G02, G05, G06, G07, G08)") + } + if result.PassedRequired >= result.TotalRequired { + t.Errorf("PassedRequired (%d) should be less than TotalRequired (%d)", result.PassedRequired, result.TotalRequired) + } + // Score should be partial + if result.Score <= 0 || result.Score >= 95 { + t.Errorf("Score = %f, expected partial score between 0 and 95", result.Score) + } +} + +func TestCompletenessCheck_G06_AIClassificationGate(t *testing.T) { + checker := NewCompletenessChecker() + + tests := []struct { + name string + hasAI bool + classifications []RegulatoryClassification + wantG06Passed bool + }{ + { + name: "no AI present auto-passes G06", + hasAI: false, + classifications: nil, + wantG06Passed: true, + }, + { + name: "AI present without classification fails G06", + hasAI: true, + classifications: nil, + wantG06Passed: false, + }, + { + name: "AI present with AI Act classification passes G06", + hasAI: true, + classifications: []RegulatoryClassification{ + {Regulation: RegulationAIAct}, + }, + wantG06Passed: true, + }, + { + name: "AI present with non-AI classification fails G06", + hasAI: true, + classifications: []RegulatoryClassification{ + {Regulation: RegulationCRA}, + }, + wantG06Passed: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &CompletenessContext{ + Project: &Project{MachineName: "Test"}, + Classifications: tt.classifications, + HasAI: tt.hasAI, + } + + result := checker.Check(ctx) + + for _, g := range result.Gates { + if g.ID == "G06" { + if g.Passed != tt.wantG06Passed { + t.Errorf("G06 Passed = %v, want %v", g.Passed, tt.wantG06Passed) + } + return + } + } + t.Error("G06 gate not found in results") + }) + } +} + +func TestCompletenessCheck_G42_AIDocumentsGate(t *testing.T) { + checker := NewCompletenessChecker() + + tests := []struct { + name string + hasAI bool + techFileSections []TechFileSection + wantG42Passed bool + }{ + { + name: "no AI auto-passes G42", + hasAI: false, + techFileSections: nil, + wantG42Passed: true, + }, + { + name: "AI present without tech files fails G42", + hasAI: true, + techFileSections: nil, + wantG42Passed: false, + }, + { + name: "AI present with only intended_purpose fails G42", + hasAI: true, + techFileSections: []TechFileSection{ + {SectionType: "ai_intended_purpose"}, + }, + wantG42Passed: false, + }, + { + name: "AI present with only model_description fails G42", + hasAI: true, + techFileSections: []TechFileSection{ + {SectionType: "ai_model_description"}, + }, + wantG42Passed: false, + }, + { + name: "AI present with both AI sections passes G42", + hasAI: true, + techFileSections: []TechFileSection{ + {SectionType: "ai_intended_purpose"}, + {SectionType: "ai_model_description"}, + }, + wantG42Passed: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &CompletenessContext{ + Project: &Project{MachineName: "Test"}, + TechFileSections: tt.techFileSections, + HasAI: tt.hasAI, + } + + result := checker.Check(ctx) + + for _, g := range result.Gates { + if g.ID == "G42" { + if g.Passed != tt.wantG42Passed { + t.Errorf("G42 Passed = %v, want %v", g.Passed, tt.wantG42Passed) + } + return + } + } + t.Error("G42 gate not found in results") + }) + } +} + +func TestCompletenessCheck_G22_CriticalHighMitigated(t *testing.T) { + checker := NewCompletenessChecker() + + hazardID := uuid.New() + + tests := []struct { + name string + assessments []RiskAssessment + mitigations []Mitigation + wantG22Passed bool + }{ + { + name: "no critical/high hazards auto-passes G22", + assessments: []RiskAssessment{{HazardID: hazardID, RiskLevel: RiskLevelLow}}, + mitigations: nil, + wantG22Passed: true, + }, + { + name: "no assessments at all auto-passes G22 (no critical/high found)", + assessments: nil, + mitigations: nil, + wantG22Passed: true, + }, + { + name: "critical hazard without mitigation fails G22", + assessments: []RiskAssessment{{HazardID: hazardID, RiskLevel: RiskLevelCritical}}, + mitigations: nil, + wantG22Passed: false, + }, + { + name: "high hazard without mitigation fails G22", + assessments: []RiskAssessment{{HazardID: hazardID, RiskLevel: RiskLevelHigh}}, + mitigations: nil, + wantG22Passed: false, + }, + { + name: "critical hazard with mitigation passes G22", + assessments: []RiskAssessment{{HazardID: hazardID, RiskLevel: RiskLevelCritical}}, + mitigations: []Mitigation{{HazardID: hazardID, Status: MitigationStatusVerified}}, + wantG22Passed: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &CompletenessContext{ + Project: &Project{MachineName: "Test"}, + Assessments: tt.assessments, + Mitigations: tt.mitigations, + } + + result := checker.Check(ctx) + + for _, g := range result.Gates { + if g.ID == "G22" { + if g.Passed != tt.wantG22Passed { + t.Errorf("G22 Passed = %v, want %v", g.Passed, tt.wantG22Passed) + } + return + } + } + t.Error("G22 gate not found in results") + }) + } +} + +func TestCompletenessCheck_G23_MitigationsVerified(t *testing.T) { + checker := NewCompletenessChecker() + + hazardID := uuid.New() + + tests := []struct { + name string + mitigations []Mitigation + wantG23Passed bool + }{ + { + name: "no mitigations passes G23", + mitigations: nil, + wantG23Passed: true, + }, + { + name: "all mitigations verified passes G23", + mitigations: []Mitigation{ + {HazardID: hazardID, Status: MitigationStatusVerified}, + {HazardID: hazardID, Status: MitigationStatusVerified}, + }, + wantG23Passed: true, + }, + { + name: "one mitigation still implemented fails G23", + mitigations: []Mitigation{ + {HazardID: hazardID, Status: MitigationStatusVerified}, + {HazardID: hazardID, Status: MitigationStatusImplemented}, + }, + wantG23Passed: false, + }, + { + name: "planned mitigations pass G23 (not yet implemented)", + mitigations: []Mitigation{ + {HazardID: hazardID, Status: MitigationStatusPlanned}, + }, + wantG23Passed: true, + }, + { + name: "rejected mitigations pass G23", + mitigations: []Mitigation{ + {HazardID: hazardID, Status: MitigationStatusRejected}, + }, + wantG23Passed: true, + }, + { + name: "mix of verified planned rejected passes G23", + mitigations: []Mitigation{ + {HazardID: hazardID, Status: MitigationStatusVerified}, + {HazardID: hazardID, Status: MitigationStatusPlanned}, + {HazardID: hazardID, Status: MitigationStatusRejected}, + }, + wantG23Passed: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &CompletenessContext{ + Project: &Project{MachineName: "Test"}, + Mitigations: tt.mitigations, + } + + result := checker.Check(ctx) + + for _, g := range result.Gates { + if g.ID == "G23" { + if g.Passed != tt.wantG23Passed { + t.Errorf("G23 Passed = %v, want %v", g.Passed, tt.wantG23Passed) + } + return + } + } + t.Error("G23 gate not found in results") + }) + } +} + +func TestCompletenessCheck_G24_ResidualRiskAccepted(t *testing.T) { + checker := NewCompletenessChecker() + + hazardID := uuid.New() + + tests := []struct { + name string + assessments []RiskAssessment + wantG24Passed bool + }{ + { + name: "no assessments fails G24", + assessments: nil, + wantG24Passed: false, + }, + { + name: "all assessments acceptable passes G24", + assessments: []RiskAssessment{ + {HazardID: hazardID, IsAcceptable: true, RiskLevel: RiskLevelMedium}, + {HazardID: hazardID, IsAcceptable: true, RiskLevel: RiskLevelHigh}, + }, + wantG24Passed: true, + }, + { + name: "not acceptable but low risk passes G24", + assessments: []RiskAssessment{ + {HazardID: hazardID, IsAcceptable: false, RiskLevel: RiskLevelLow}, + }, + wantG24Passed: true, + }, + { + name: "not acceptable but negligible risk passes G24", + assessments: []RiskAssessment{ + {HazardID: hazardID, IsAcceptable: false, RiskLevel: RiskLevelNegligible}, + }, + wantG24Passed: true, + }, + { + name: "not acceptable with high risk fails G24", + assessments: []RiskAssessment{ + {HazardID: hazardID, IsAcceptable: false, RiskLevel: RiskLevelHigh}, + }, + wantG24Passed: false, + }, + { + name: "not acceptable with critical risk fails G24", + assessments: []RiskAssessment{ + {HazardID: hazardID, IsAcceptable: false, RiskLevel: RiskLevelCritical}, + }, + wantG24Passed: false, + }, + { + name: "not acceptable with medium risk fails G24", + assessments: []RiskAssessment{ + {HazardID: hazardID, IsAcceptable: false, RiskLevel: RiskLevelMedium}, + }, + wantG24Passed: false, + }, + { + name: "mix acceptable and unacceptable high fails G24", + assessments: []RiskAssessment{ + {HazardID: hazardID, IsAcceptable: true, RiskLevel: RiskLevelHigh}, + {HazardID: hazardID, IsAcceptable: false, RiskLevel: RiskLevelHigh}, + }, + wantG24Passed: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &CompletenessContext{ + Project: &Project{MachineName: "Test"}, + Assessments: tt.assessments, + } + + result := checker.Check(ctx) + + for _, g := range result.Gates { + if g.ID == "G24" { + if g.Passed != tt.wantG24Passed { + t.Errorf("G24 Passed = %v, want %v", g.Passed, tt.wantG24Passed) + } + return + } + } + t.Error("G24 gate not found in results") + }) + } +} + +func TestCompletenessCheck_ScoringFormula(t *testing.T) { + tests := []struct { + name string + passedRequired int + totalRequired int + passedRecommended int + totalRecommended int + passedOptional int + totalOptional int + wantScore float64 + }{ + { + name: "all zeros produces zero score", + passedRequired: 0, + totalRequired: 0, + passedRecommended: 0, + totalRecommended: 0, + passedOptional: 0, + totalOptional: 0, + wantScore: 0, + }, + { + name: "all required passed gives 80", + passedRequired: 20, + totalRequired: 20, + passedRecommended: 0, + totalRecommended: 1, + passedOptional: 0, + totalOptional: 0, + wantScore: 80, + }, + { + name: "half required passed gives 40", + passedRequired: 10, + totalRequired: 20, + passedRecommended: 0, + totalRecommended: 1, + passedOptional: 0, + totalOptional: 0, + wantScore: 40, + }, + { + name: "all required and all recommended gives 95", + passedRequired: 20, + totalRequired: 20, + passedRecommended: 1, + totalRecommended: 1, + passedOptional: 0, + totalOptional: 0, + wantScore: 95, + }, + { + name: "all categories full gives 100", + passedRequired: 20, + totalRequired: 20, + passedRecommended: 1, + totalRecommended: 1, + passedOptional: 1, + totalOptional: 1, + wantScore: 100, + }, + { + name: "only recommended passed", + passedRequired: 0, + totalRequired: 20, + passedRecommended: 1, + totalRecommended: 1, + passedOptional: 0, + totalOptional: 0, + wantScore: 15, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + score := calculateWeightedScore( + tt.passedRequired, tt.totalRequired, + tt.passedRecommended, tt.totalRecommended, + tt.passedOptional, tt.totalOptional, + ) + + if math.Abs(score-tt.wantScore) > 0.01 { + t.Errorf("calculateWeightedScore = %f, want %f", score, tt.wantScore) + } + }) + } +} + +func TestCompletenessCheck_GateCountsAndCategories(t *testing.T) { + checker := NewCompletenessChecker() + + ctx := &CompletenessContext{ + Project: &Project{MachineName: "Test"}, + } + result := checker.Check(ctx) + + // The buildGateDefinitions function returns exactly 21 gates + // (G01-G08: 8, G10-G13: 4, G20-G24: 5, G30: 1, G40-G42: 3 = 21 total) + if len(result.Gates) != 21 { + t.Errorf("Total gates = %d, want 21", len(result.Gates)) + } + + // Count required vs recommended + requiredCount := 0 + recommendedCount := 0 + for _, g := range result.Gates { + if g.Required { + requiredCount++ + } + } + // G30 is the only recommended gate (Required=false, Recommended=true) + // All others are required (20 required, 1 recommended) + if requiredCount != 20 { + t.Errorf("Required gates count = %d, want 20", requiredCount) + } + + if result.TotalRequired != 20 { + t.Errorf("TotalRequired = %d, want 20", result.TotalRequired) + } + + // TotalRecommended should be 1 (G30) + if result.TotalRecommended != 1 { + t.Errorf("TotalRecommended = %d, want 1", result.TotalRecommended) + } + _ = recommendedCount + + // Verify expected categories exist + categories := make(map[string]int) + for _, g := range result.Gates { + categories[g.Category]++ + } + + expectedCategories := map[string]int{ + "onboarding": 8, + "classification": 4, + "hazard_risk": 5, + "evidence": 1, + "tech_file": 3, + } + for cat, expectedCount := range expectedCategories { + if categories[cat] != expectedCount { + t.Errorf("Category %q count = %d, want %d", cat, categories[cat], expectedCount) + } + } +} + +func TestCompletenessCheck_FailedGateHasDetails(t *testing.T) { + checker := NewCompletenessChecker() + + ctx := &CompletenessContext{ + Project: &Project{}, // empty project, many gates will fail + } + + result := checker.Check(ctx) + + for _, g := range result.Gates { + if !g.Passed && g.Details == "" { + t.Errorf("Gate %s failed but has empty Details", g.ID) + } + if g.Passed && g.Details != "" { + t.Errorf("Gate %s passed but has non-empty Details: %s", g.ID, g.Details) + } + } +} diff --git a/ai-compliance-sdk/internal/iace/engine.go b/ai-compliance-sdk/internal/iace/engine.go new file mode 100644 index 0000000..af915e8 --- /dev/null +++ b/ai-compliance-sdk/internal/iace/engine.go @@ -0,0 +1,202 @@ +package iace + +import ( + "fmt" + "math" +) + +// RiskLevel, AssessRiskRequest, and RiskAssessment types are defined in models.go. +// This file only contains calculation methods. + +// RiskComputeInput contains the input parameters for the engine's risk computation. +type RiskComputeInput struct { + Severity int `json:"severity"` // 1-5 + Exposure int `json:"exposure"` // 1-5 + Probability int `json:"probability"` // 1-5 + ControlMaturity int `json:"control_maturity"` // 0-4 + ControlCoverage float64 `json:"control_coverage"` // 0-1 + TestEvidence float64 `json:"test_evidence"` // 0-1 + HasJustification bool `json:"has_justification"` +} + +// RiskComputeResult contains the output of the engine's risk computation. +type RiskComputeResult struct { + InherentRisk float64 `json:"inherent_risk"` + ControlEffectiveness float64 `json:"control_effectiveness"` + ResidualRisk float64 `json:"residual_risk"` + RiskLevel RiskLevel `json:"risk_level"` + IsAcceptable bool `json:"is_acceptable"` + AcceptanceReason string `json:"acceptance_reason"` +} + +// ============================================================================ +// RiskEngine +// ============================================================================ + +// RiskEngine provides methods for mathematical risk calculations +// according to the IACE (Inherent-risk Adjusted Control Effectiveness) model. +type RiskEngine struct{} + +// NewRiskEngine creates a new RiskEngine instance. +func NewRiskEngine() *RiskEngine { + return &RiskEngine{} +} + +// ============================================================================ +// Calculations +// ============================================================================ + +// clamp restricts v to the range [lo, hi]. +func clamp(v, lo, hi int) int { + if v < lo { + return lo + } + if v > hi { + return hi + } + return v +} + +// clampFloat restricts v to the range [lo, hi]. +func clampFloat(v, lo, hi float64) float64 { + if v < lo { + return lo + } + if v > hi { + return hi + } + return v +} + +// CalculateInherentRisk computes the inherent risk score as S * E * P. +// Each factor is expected in the range 1-5 and will be clamped if out of range. +func (e *RiskEngine) CalculateInherentRisk(severity, exposure, probability int) float64 { + s := clamp(severity, 1, 5) + ex := clamp(exposure, 1, 5) + p := clamp(probability, 1, 5) + return float64(s) * float64(ex) * float64(p) +} + +// CalculateControlEffectiveness computes the control effectiveness score. +// +// Formula: C_eff = min(1, 0.2*(maturity/4.0) + 0.5*coverage + 0.3*testEvidence) +// +// Parameters: +// - maturity: 0-4, clamped if out of range +// - coverage: 0-1, clamped if out of range +// - testEvidence: 0-1, clamped if out of range +// +// Returns a value between 0 and 1. +func (e *RiskEngine) CalculateControlEffectiveness(maturity int, coverage, testEvidence float64) float64 { + m := clamp(maturity, 0, 4) + cov := clampFloat(coverage, 0, 1) + te := clampFloat(testEvidence, 0, 1) + + cEff := 0.2*(float64(m)/4.0) + 0.5*cov + 0.3*te + return math.Min(1, cEff) +} + +// CalculateResidualRisk computes the residual risk after applying controls. +// +// Formula: R_residual = S * E * P * (1 - cEff) +// +// Parameters: +// - severity, exposure, probability: 1-5, clamped if out of range +// - cEff: control effectiveness, 0-1 +func (e *RiskEngine) CalculateResidualRisk(severity, exposure, probability int, cEff float64) float64 { + inherent := e.CalculateInherentRisk(severity, exposure, probability) + return inherent * (1 - cEff) +} + +// DetermineRiskLevel classifies the residual risk into a RiskLevel category. +// +// Thresholds: +// - >= 75: critical +// - >= 40: high +// - >= 15: medium +// - >= 5: low +// - < 5: negligible +func (e *RiskEngine) DetermineRiskLevel(residualRisk float64) RiskLevel { + switch { + case residualRisk >= 75: + return RiskLevelCritical + case residualRisk >= 40: + return RiskLevelHigh + case residualRisk >= 15: + return RiskLevelMedium + case residualRisk >= 5: + return RiskLevelLow + default: + return RiskLevelNegligible + } +} + +// IsAcceptable determines whether the residual risk is acceptable based on +// the ALARP (As Low As Reasonably Practicable) principle and EU AI Act thresholds. +// +// Decision logic: +// - residualRisk < 15: acceptable ("Restrisiko unter Schwellwert") +// - residualRisk < 40 AND allReductionStepsApplied AND hasJustification: +// acceptable under ALARP ("ALARP-Prinzip: Restrisiko akzeptabel mit vollstaendiger Risikominderung") +// - residualRisk >= 40: not acceptable ("Restrisiko zu hoch - blockiert CE-Export") +func (e *RiskEngine) IsAcceptable(residualRisk float64, allReductionStepsApplied bool, hasJustification bool) (bool, string) { + if residualRisk < 15 { + return true, "Restrisiko unter Schwellwert" + } + if residualRisk < 40 && allReductionStepsApplied && hasJustification { + return true, "ALARP-Prinzip: Restrisiko akzeptabel mit vollstaendiger Risikominderung" + } + return false, "Restrisiko zu hoch - blockiert CE-Export" +} + +// CalculateCompletenessScore computes a weighted completeness score (0-100). +// +// Formula: +// +// score = (passedRequired/totalRequired)*80 +// + (passedRecommended/totalRecommended)*15 +// + (passedOptional/totalOptional)*5 +// +// If any totalX is 0, that component contributes 0 to the score. +func (e *RiskEngine) CalculateCompletenessScore(passedRequired, totalRequired, passedRecommended, totalRecommended, passedOptional, totalOptional int) float64 { + var score float64 + + if totalRequired > 0 { + score += (float64(passedRequired) / float64(totalRequired)) * 80 + } + if totalRecommended > 0 { + score += (float64(passedRecommended) / float64(totalRecommended)) * 15 + } + if totalOptional > 0 { + score += (float64(passedOptional) / float64(totalOptional)) * 5 + } + + return score +} + +// ComputeRisk performs a complete risk computation using all calculation methods. +// It returns a RiskComputeResult with inherent risk, control effectiveness, residual risk, +// risk level, and acceptability. +// +// The allReductionStepsApplied parameter for IsAcceptable is set to false; +// the caller is responsible for updating acceptance status after reduction steps are applied. +func (e *RiskEngine) ComputeRisk(req RiskComputeInput) (*RiskComputeResult, error) { + if req.Severity < 1 || req.Exposure < 1 || req.Probability < 1 { + return nil, fmt.Errorf("severity, exposure, and probability must be >= 1") + } + + inherentRisk := e.CalculateInherentRisk(req.Severity, req.Exposure, req.Probability) + controlEff := e.CalculateControlEffectiveness(req.ControlMaturity, req.ControlCoverage, req.TestEvidence) + residualRisk := e.CalculateResidualRisk(req.Severity, req.Exposure, req.Probability, controlEff) + riskLevel := e.DetermineRiskLevel(residualRisk) + acceptable, reason := e.IsAcceptable(residualRisk, false, req.HasJustification) + + return &RiskComputeResult{ + InherentRisk: inherentRisk, + ControlEffectiveness: controlEff, + ResidualRisk: residualRisk, + RiskLevel: riskLevel, + IsAcceptable: acceptable, + AcceptanceReason: reason, + }, nil +} diff --git a/ai-compliance-sdk/internal/iace/engine_test.go b/ai-compliance-sdk/internal/iace/engine_test.go new file mode 100644 index 0000000..c7af96c --- /dev/null +++ b/ai-compliance-sdk/internal/iace/engine_test.go @@ -0,0 +1,936 @@ +package iace + +import ( + "math" + "testing" +) + +// ============================================================================ +// Helper +// ============================================================================ + +const floatTolerance = 1e-9 + +func almostEqual(a, b float64) bool { + return math.Abs(a-b) < floatTolerance +} + +// ============================================================================ +// 1. CalculateInherentRisk — S × E × P +// ============================================================================ + +func TestCalculateInherentRisk_BasicCases(t *testing.T) { + e := NewRiskEngine() + + tests := []struct { + name string + s, ex, p int + expected float64 + }{ + // Minimum + {"min 1×1×1", 1, 1, 1, 1}, + // Maximum + {"max 5×5×5", 5, 5, 5, 125}, + // Single factor high + {"5×1×1", 5, 1, 1, 5}, + {"1×5×1", 1, 5, 1, 5}, + {"1×1×5", 1, 1, 5, 5}, + // Typical mid-range + {"3×3×3", 3, 3, 3, 27}, + {"2×4×3", 2, 4, 3, 24}, + {"4×2×5", 4, 2, 5, 40}, + // Boundary at thresholds + {"3×5×5 = 75 (critical threshold)", 3, 5, 5, 75}, + {"2×4×5 = 40 (high threshold)", 2, 4, 5, 40}, + {"3×5×1 = 15 (medium threshold)", 3, 5, 1, 15}, + {"5×1×1 = 5 (low threshold)", 5, 1, 1, 5}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := e.CalculateInherentRisk(tt.s, tt.ex, tt.p) + if !almostEqual(result, tt.expected) { + t.Errorf("CalculateInherentRisk(%d, %d, %d) = %v, want %v", tt.s, tt.ex, tt.p, result, tt.expected) + } + }) + } +} + +func TestCalculateInherentRisk_Clamping(t *testing.T) { + e := NewRiskEngine() + + tests := []struct { + name string + s, ex, p int + expected float64 + }{ + {"below min clamped to 1", 0, 0, 0, 1}, + {"negative clamped to 1", -5, -3, -1, 1}, + {"above max clamped to 5", 10, 8, 6, 125}, + {"mixed out-of-range", 0, 10, 3, 15}, // clamp(0,1,5)=1, clamp(10,1,5)=5, clamp(3,1,5)=3 → 1*5*3=15 + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := e.CalculateInherentRisk(tt.s, tt.ex, tt.p) + if !almostEqual(result, tt.expected) { + t.Errorf("CalculateInherentRisk(%d, %d, %d) = %v, want %v", tt.s, tt.ex, tt.p, result, tt.expected) + } + }) + } +} + +// Full S×E×P coverage: verify all 125 combinations produce correct multiplication. +func TestCalculateInherentRisk_FullCoverage(t *testing.T) { + e := NewRiskEngine() + + for s := 1; s <= 5; s++ { + for ex := 1; ex <= 5; ex++ { + for p := 1; p <= 5; p++ { + expected := float64(s * ex * p) + result := e.CalculateInherentRisk(s, ex, p) + if !almostEqual(result, expected) { + t.Errorf("CalculateInherentRisk(%d, %d, %d) = %v, want %v", s, ex, p, result, expected) + } + } + } + } +} + +// ============================================================================ +// 2. CalculateControlEffectiveness +// C_eff = min(1, 0.2*(maturity/4.0) + 0.5*coverage + 0.3*testEvidence) +// ============================================================================ + +func TestCalculateControlEffectiveness(t *testing.T) { + e := NewRiskEngine() + + tests := []struct { + name string + maturity int + coverage float64 + testEvidence float64 + expected float64 + }{ + // All zeros → 0 + {"all zero", 0, 0.0, 0.0, 0.0}, + // All max → min(1, 0.2*1 + 0.5*1 + 0.3*1) = min(1, 1.0) = 1.0 + {"all max", 4, 1.0, 1.0, 1.0}, + // Only maturity max → 0.2 * (4/4) = 0.2 + {"maturity only", 4, 0.0, 0.0, 0.2}, + // Only coverage max → 0.5 + {"coverage only", 0, 1.0, 0.0, 0.5}, + // Only test evidence max → 0.3 + {"evidence only", 0, 0.0, 1.0, 0.3}, + // Half maturity → 0.2 * (2/4) = 0.1 + {"half maturity", 2, 0.0, 0.0, 0.1}, + // Typical mid-range: maturity=2, coverage=0.6, evidence=0.4 + // 0.2*(2/4) + 0.5*0.6 + 0.3*0.4 = 0.1 + 0.3 + 0.12 = 0.52 + {"typical mid", 2, 0.6, 0.4, 0.52}, + // High values exceeding 1.0 should be capped + // maturity=4, coverage=1.0, evidence=1.0 → 0.2+0.5+0.3 = 1.0 + {"capped at 1.0", 4, 1.0, 1.0, 1.0}, + // maturity=3, coverage=0.8, evidence=0.9 + // 0.2*(3/4) + 0.5*0.8 + 0.3*0.9 = 0.15 + 0.4 + 0.27 = 0.82 + {"high controls", 3, 0.8, 0.9, 0.82}, + // maturity=1, coverage=0.2, evidence=0.1 + // 0.2*(1/4) + 0.5*0.2 + 0.3*0.1 = 0.05 + 0.1 + 0.03 = 0.18 + {"low controls", 1, 0.2, 0.1, 0.18}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := e.CalculateControlEffectiveness(tt.maturity, tt.coverage, tt.testEvidence) + if !almostEqual(result, tt.expected) { + t.Errorf("CalculateControlEffectiveness(%d, %v, %v) = %v, want %v", + tt.maturity, tt.coverage, tt.testEvidence, result, tt.expected) + } + }) + } +} + +func TestCalculateControlEffectiveness_Clamping(t *testing.T) { + e := NewRiskEngine() + + tests := []struct { + name string + maturity int + coverage float64 + testEvidence float64 + expected float64 + }{ + // Maturity below 0 → clamped to 0 + {"maturity below zero", -1, 0.5, 0.5, 0.5*0.5 + 0.3*0.5}, + // Maturity above 4 → clamped to 4 + {"maturity above max", 10, 0.0, 0.0, 0.2}, + // Coverage below 0 → clamped to 0 + {"coverage below zero", 0, -0.5, 0.0, 0.0}, + // Coverage above 1 → clamped to 1 + {"coverage above max", 0, 2.0, 0.0, 0.5}, + // Evidence below 0 → clamped to 0 + {"evidence below zero", 0, 0.0, -1.0, 0.0}, + // Evidence above 1 → clamped to 1 + {"evidence above max", 0, 0.0, 5.0, 0.3}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := e.CalculateControlEffectiveness(tt.maturity, tt.coverage, tt.testEvidence) + if !almostEqual(result, tt.expected) { + t.Errorf("CalculateControlEffectiveness(%d, %v, %v) = %v, want %v", + tt.maturity, tt.coverage, tt.testEvidence, result, tt.expected) + } + }) + } +} + +// ============================================================================ +// 3. CalculateResidualRisk — R_residual = S × E × P × (1 - C_eff) +// ============================================================================ + +func TestCalculateResidualRisk(t *testing.T) { + e := NewRiskEngine() + + tests := []struct { + name string + s, ex, p int + cEff float64 + expected float64 + }{ + // No controls → residual = inherent + {"no controls", 5, 5, 5, 0.0, 125.0}, + // Perfect controls → residual = 0 + {"perfect controls", 5, 5, 5, 1.0, 0.0}, + // Half effectiveness + {"half controls 3×3×3", 3, 3, 3, 0.5, 13.5}, + // Typical scenario: inherent=40, cEff=0.6 → residual=16 + {"typical 2×4×5 cEff=0.6", 2, 4, 5, 0.6, 16.0}, + // Low risk with some controls + {"low 1×2×3 cEff=0.3", 1, 2, 3, 0.3, 4.2}, + // High risk with strong controls + {"high 5×4×4 cEff=0.82", 5, 4, 4, 0.82, 14.4}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := e.CalculateResidualRisk(tt.s, tt.ex, tt.p, tt.cEff) + if !almostEqual(result, tt.expected) { + t.Errorf("CalculateResidualRisk(%d, %d, %d, %v) = %v, want %v", + tt.s, tt.ex, tt.p, tt.cEff, result, tt.expected) + } + }) + } +} + +// ============================================================================ +// 4. DetermineRiskLevel — threshold classification +// ============================================================================ + +func TestDetermineRiskLevel(t *testing.T) { + e := NewRiskEngine() + + tests := []struct { + name string + residual float64 + expected RiskLevel + }{ + // Critical: >= 75 + {"critical at 75", 75.0, RiskLevelCritical}, + {"critical at 125", 125.0, RiskLevelCritical}, + {"critical at 100", 100.0, RiskLevelCritical}, + // High: >= 40 + {"high at 40", 40.0, RiskLevelHigh}, + {"high at 74.9", 74.9, RiskLevelHigh}, + {"high at 50", 50.0, RiskLevelHigh}, + // Medium: >= 15 + {"medium at 15", 15.0, RiskLevelMedium}, + {"medium at 39.9", 39.9, RiskLevelMedium}, + {"medium at 27", 27.0, RiskLevelMedium}, + // Low: >= 5 + {"low at 5", 5.0, RiskLevelLow}, + {"low at 14.9", 14.9, RiskLevelLow}, + {"low at 10", 10.0, RiskLevelLow}, + // Negligible: < 5 + {"negligible at 4.9", 4.9, RiskLevelNegligible}, + {"negligible at 0", 0.0, RiskLevelNegligible}, + {"negligible at 1", 1.0, RiskLevelNegligible}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := e.DetermineRiskLevel(tt.residual) + if result != tt.expected { + t.Errorf("DetermineRiskLevel(%v) = %v, want %v", tt.residual, result, tt.expected) + } + }) + } +} + +// ============================================================================ +// 5. IsAcceptable — ALARP principle +// ============================================================================ + +func TestIsAcceptable(t *testing.T) { + e := NewRiskEngine() + + tests := []struct { + name string + residual float64 + allReduction bool + justification bool + wantAcceptable bool + wantReason string + }{ + // Below 15 → always acceptable + {"residual 14.9 always ok", 14.9, false, false, true, "Restrisiko unter Schwellwert"}, + {"residual 0 always ok", 0.0, false, false, true, "Restrisiko unter Schwellwert"}, + {"residual 10 always ok", 10.0, false, false, true, "Restrisiko unter Schwellwert"}, + // 15-39.9 with all reduction + justification → ALARP + {"ALARP 20 all+just", 20.0, true, true, true, "ALARP-Prinzip: Restrisiko akzeptabel mit vollstaendiger Risikominderung"}, + {"ALARP 39.9 all+just", 39.9, true, true, true, "ALARP-Prinzip: Restrisiko akzeptabel mit vollstaendiger Risikominderung"}, + {"ALARP 15 all+just", 15.0, true, true, true, "ALARP-Prinzip: Restrisiko akzeptabel mit vollstaendiger Risikominderung"}, + // 15-39.9 without all reduction → NOT acceptable + {"no reduction 20", 20.0, false, true, false, "Restrisiko zu hoch - blockiert CE-Export"}, + // 15-39.9 without justification → NOT acceptable + {"no justification 20", 20.0, true, false, false, "Restrisiko zu hoch - blockiert CE-Export"}, + // 15-39.9 without either → NOT acceptable + {"neither 30", 30.0, false, false, false, "Restrisiko zu hoch - blockiert CE-Export"}, + // >= 40 → NEVER acceptable + {"residual 40 blocked", 40.0, true, true, false, "Restrisiko zu hoch - blockiert CE-Export"}, + {"residual 75 blocked", 75.0, true, true, false, "Restrisiko zu hoch - blockiert CE-Export"}, + {"residual 125 blocked", 125.0, true, true, false, "Restrisiko zu hoch - blockiert CE-Export"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + acceptable, reason := e.IsAcceptable(tt.residual, tt.allReduction, tt.justification) + if acceptable != tt.wantAcceptable { + t.Errorf("IsAcceptable(%v, %v, %v) acceptable = %v, want %v", + tt.residual, tt.allReduction, tt.justification, acceptable, tt.wantAcceptable) + } + if reason != tt.wantReason { + t.Errorf("IsAcceptable(%v, %v, %v) reason = %q, want %q", + tt.residual, tt.allReduction, tt.justification, reason, tt.wantReason) + } + }) + } +} + +// ============================================================================ +// 6. CalculateCompletenessScore +// ============================================================================ + +func TestCalculateCompletenessScore(t *testing.T) { + e := NewRiskEngine() + + tests := []struct { + name string + passedReq, totalReq, passedRec, totalRec, passedOpt, totalOpt int + expected float64 + }{ + // All passed + {"all passed", 20, 20, 5, 5, 3, 3, 100.0}, + // Nothing passed + {"nothing passed", 0, 20, 0, 5, 0, 3, 0.0}, + // Only required fully passed + {"only required", 20, 20, 0, 5, 0, 3, 80.0}, + // Only recommended fully passed + {"only recommended", 0, 20, 5, 5, 0, 3, 15.0}, + // Only optional fully passed + {"only optional", 0, 20, 0, 5, 3, 3, 5.0}, + // Half required, no others + {"half required", 10, 20, 0, 5, 0, 3, 40.0}, + // All zero totals → 0 (division by zero safety) + {"all zero totals", 0, 0, 0, 0, 0, 0, 0.0}, + // Typical: 18/20 req + 3/5 rec + 1/3 opt + // (18/20)*80 + (3/5)*15 + (1/3)*5 = 72 + 9 + 1.6667 = 82.6667 + {"typical", 18, 20, 3, 5, 1, 3, 72.0 + 9.0 + 5.0/3.0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := e.CalculateCompletenessScore( + tt.passedReq, tt.totalReq, tt.passedRec, tt.totalRec, tt.passedOpt, tt.totalOpt) + if !almostEqual(result, tt.expected) { + t.Errorf("CalculateCompletenessScore(%d/%d, %d/%d, %d/%d) = %v, want %v", + tt.passedReq, tt.totalReq, tt.passedRec, tt.totalRec, tt.passedOpt, tt.totalOpt, + result, tt.expected) + } + }) + } +} + +// ============================================================================ +// 7. ComputeRisk — integration test +// ============================================================================ + +func TestComputeRisk_ValidInput(t *testing.T) { + e := NewRiskEngine() + + tests := []struct { + name string + input RiskComputeInput + wantInherent float64 + wantCEff float64 + wantResidual float64 + wantLevel RiskLevel + wantAcceptable bool + }{ + { + name: "no controls high risk", + input: RiskComputeInput{ + Severity: 5, Exposure: 5, Probability: 5, + ControlMaturity: 0, ControlCoverage: 0, TestEvidence: 0, + }, + wantInherent: 125, + wantCEff: 0, + wantResidual: 125, + wantLevel: RiskLevelCritical, + wantAcceptable: false, + }, + { + name: "perfect controls zero residual", + input: RiskComputeInput{ + Severity: 5, Exposure: 5, Probability: 5, + ControlMaturity: 4, ControlCoverage: 1.0, TestEvidence: 1.0, + }, + wantInherent: 125, + wantCEff: 1.0, + wantResidual: 0, + wantLevel: RiskLevelNegligible, + wantAcceptable: true, + }, + { + name: "medium risk acceptable", + input: RiskComputeInput{ + Severity: 2, Exposure: 2, Probability: 2, + ControlMaturity: 2, ControlCoverage: 0.5, TestEvidence: 0.5, + // C_eff = 0.2*(2/4) + 0.5*0.5 + 0.3*0.5 = 0.1 + 0.25 + 0.15 = 0.5 + // inherent = 8, residual = 8 * 0.5 = 4 → negligible → acceptable + }, + wantInherent: 8, + wantCEff: 0.5, + wantResidual: 4, + wantLevel: RiskLevelNegligible, + wantAcceptable: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := e.ComputeRisk(tt.input) + if err != nil { + t.Fatalf("ComputeRisk returned error: %v", err) + } + if !almostEqual(result.InherentRisk, tt.wantInherent) { + t.Errorf("InherentRisk = %v, want %v", result.InherentRisk, tt.wantInherent) + } + if !almostEqual(result.ControlEffectiveness, tt.wantCEff) { + t.Errorf("ControlEffectiveness = %v, want %v", result.ControlEffectiveness, tt.wantCEff) + } + if !almostEqual(result.ResidualRisk, tt.wantResidual) { + t.Errorf("ResidualRisk = %v, want %v", result.ResidualRisk, tt.wantResidual) + } + if result.RiskLevel != tt.wantLevel { + t.Errorf("RiskLevel = %v, want %v", result.RiskLevel, tt.wantLevel) + } + if result.IsAcceptable != tt.wantAcceptable { + t.Errorf("IsAcceptable = %v, want %v", result.IsAcceptable, tt.wantAcceptable) + } + }) + } +} + +func TestComputeRisk_InvalidInput(t *testing.T) { + e := NewRiskEngine() + + tests := []struct { + name string + input RiskComputeInput + }{ + {"severity zero", RiskComputeInput{Severity: 0, Exposure: 3, Probability: 3}}, + {"exposure zero", RiskComputeInput{Severity: 3, Exposure: 0, Probability: 3}}, + {"probability zero", RiskComputeInput{Severity: 3, Exposure: 3, Probability: 0}}, + {"all zero", RiskComputeInput{Severity: 0, Exposure: 0, Probability: 0}}, + {"negative values", RiskComputeInput{Severity: -1, Exposure: -2, Probability: -3}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := e.ComputeRisk(tt.input) + if err == nil { + t.Errorf("ComputeRisk expected error for input %+v, got result %+v", tt.input, result) + } + }) + } +} + +// ============================================================================ +// 8. Golden Test Suite — 10 Referenzmaschinen (Industrial Machine Scenarios) +// +// Each machine scenario tests the full pipeline: +// Inherent Risk → Control Effectiveness → Residual Risk → Risk Level → Acceptability +// +// The control parameters reflect realistic mitigation states for each machine type. +// These serve as regression tests: if any threshold or formula changes, +// these tests will catch the impact immediately. +// ============================================================================ + +// referenceMachine defines a complete end-to-end test scenario for a real machine type. +type referenceMachine struct { + name string + description string + + // Inherent risk factors (pre-mitigation) + severity int // 1-5 + exposure int // 1-5 + probability int // 1-5 + + // Control parameters (mitigation state) + controlMaturity int // 0-4 + controlCoverage float64 // 0-1 + testEvidence float64 // 0-1 + hasJustification bool + + // Expected outputs + expectedInherentRisk float64 + expectedCEff float64 + expectedResidualRisk float64 + expectedRiskLevel RiskLevel + expectedAcceptable bool +} + +// computeExpectedCEff calculates the expected control effectiveness for documentation/verification. +func computeExpectedCEff(maturity int, coverage, testEvidence float64) float64 { + cEff := 0.2*(float64(maturity)/4.0) + 0.5*coverage + 0.3*testEvidence + if cEff > 1.0 { + return 1.0 + } + return cEff +} + +func getReferenceMachines() []referenceMachine { + return []referenceMachine{ + // --------------------------------------------------------------- + // 1. Industrieroboter-Zelle mit Schutzzaun + // 6-Achs-Roboter, Materialhandling, Schutzzaun + Lichtschranke. + // Hauptgefaehrdung: Quetschung/Kollision bei Betreten. + // Hohe Massnahmen: Sicherheits-SPS, zweikanalige Tuerueberwachung. + // --------------------------------------------------------------- + { + name: "Industrieroboter-Zelle", + description: "6-Achs-Roboter mit Schutzzaun, Quetsch-/Kollisionsgefahr", + severity: 5, // Lebensgefaehrlich + exposure: 3, // Regelmaessiger Zugang (Wartung) + probability: 4, // Wahrscheinlich bei offenem Zugang + controlMaturity: 3, + controlCoverage: 0.8, + testEvidence: 0.7, + hasJustification: false, + // Inherent: 5*3*4 = 60 + // C_eff: 0.2*(3/4) + 0.5*0.8 + 0.3*0.7 = 0.15 + 0.4 + 0.21 = 0.76 + // Residual: 60 * (1-0.76) = 60 * 0.24 = 14.4 + // Level: medium (>=5, <15) → actually 14.4 < 15 → low + // Acceptable: 14.4 < 15 → yes + expectedInherentRisk: 60, + expectedCEff: 0.76, + expectedResidualRisk: 60 * (1 - 0.76), + expectedRiskLevel: RiskLevelLow, + expectedAcceptable: true, + }, + // --------------------------------------------------------------- + // 2. CNC-Fraesmaschine mit automatischem Werkzeugwechsel + // Werkzeugbruch, Spaeneflug. Vollschutzkabine + Drehzahlueberwachung. + // --------------------------------------------------------------- + { + name: "CNC-Fraesmaschine", + description: "CNC mit Werkzeugwechsel, Werkzeugbruch/Spaeneflug", + severity: 4, // Schwere Verletzung + exposure: 3, // Bediener steht regelmaessig davor + probability: 3, // Moeglich + controlMaturity: 3, + controlCoverage: 0.9, + testEvidence: 0.8, + hasJustification: true, + // Inherent: 4*3*3 = 36 + // C_eff: 0.2*(3/4) + 0.5*0.9 + 0.3*0.8 = 0.15 + 0.45 + 0.24 = 0.84 + // Residual: 36 * 0.16 = 5.76 + // Level: low (>=5, <15) + // Acceptable: 5.76 < 15 → yes + expectedInherentRisk: 36, + expectedCEff: 0.84, + expectedResidualRisk: 36 * (1 - 0.84), + expectedRiskLevel: RiskLevelLow, + expectedAcceptable: true, + }, + // --------------------------------------------------------------- + // 3. Verpackungsmaschine mit Schneideeinheit + // Foerderband + Klinge. Schnittverletzung. + // Zweihandbedienung, Sicherheitsrelais, Abdeckung. + // --------------------------------------------------------------- + { + name: "Verpackungsmaschine", + description: "Foerderband + Schneideeinheit, Schnittverletzungsgefahr", + severity: 4, // Schwere Schnittverletzung + exposure: 4, // Dauerbetrieb mit Bediener + probability: 3, // Moeglich + controlMaturity: 2, + controlCoverage: 0.7, + testEvidence: 0.5, + hasJustification: true, + // Inherent: 4*4*3 = 48 + // C_eff: 0.2*(2/4) + 0.5*0.7 + 0.3*0.5 = 0.1 + 0.35 + 0.15 = 0.6 + // Residual: 48 * 0.4 = 19.2 + // Level: medium (>=15, <40) + // Acceptable: 19.2 >= 15, allReduction=false (ComputeRisk default) → NOT acceptable + expectedInherentRisk: 48, + expectedCEff: 0.6, + expectedResidualRisk: 48 * (1 - 0.6), + expectedRiskLevel: RiskLevelMedium, + expectedAcceptable: false, // ComputeRisk sets allReductionStepsApplied=false + }, + // --------------------------------------------------------------- + // 4. Automatisierte Pressanlage + // Quetschung im Pressbereich. Hoechste Gefaehrdung. + // Lichtvorhang, Kat-4-Steuerung, mechanische Verriegelung. + // --------------------------------------------------------------- + { + name: "Pressanlage", + description: "Automatische Presse, Quetschgefahr im Pressbereich", + severity: 5, // Toedlich + exposure: 4, // Bediener staendig im Bereich + probability: 4, // Wahrscheinlich ohne Schutz + controlMaturity: 4, + controlCoverage: 0.9, + testEvidence: 0.9, + hasJustification: true, + // Inherent: 5*4*4 = 80 + // C_eff: 0.2*(4/4) + 0.5*0.9 + 0.3*0.9 = 0.2 + 0.45 + 0.27 = 0.92 + // Residual: 80 * 0.08 = 6.4 + // Level: low (>=5, <15) + // Acceptable: 6.4 < 15 → yes + expectedInherentRisk: 80, + expectedCEff: 0.92, + expectedResidualRisk: 80 * (1 - 0.92), + expectedRiskLevel: RiskLevelLow, + expectedAcceptable: true, + }, + // --------------------------------------------------------------- + // 5. Lasergravur-Anlage (Klasse 4) + // Augenverletzung durch Laserstrahl. + // Geschlossene Kabine, Interlock. + // --------------------------------------------------------------- + { + name: "Lasergravur-Anlage", + description: "Klasse-4-Laser, Augenverletzungsgefahr", + severity: 5, // Irreversible Augenschaeden + exposure: 2, // Selten direkter Zugang + probability: 2, // Selten bei geschlossener Kabine + controlMaturity: 3, + controlCoverage: 0.95, + testEvidence: 0.8, + hasJustification: false, + // Inherent: 5*2*2 = 20 + // C_eff: 0.2*(3/4) + 0.5*0.95 + 0.3*0.8 = 0.15 + 0.475 + 0.24 = 0.865 + // Residual: 20 * 0.135 = 2.7 + // Level: negligible (<5) + // Acceptable: 2.7 < 15 → yes + expectedInherentRisk: 20, + expectedCEff: 0.865, + expectedResidualRisk: 20 * (1 - 0.865), + expectedRiskLevel: RiskLevelNegligible, + expectedAcceptable: true, + }, + // --------------------------------------------------------------- + // 6. Fahrerloses Transportsystem (AGV) + // Kollision mit Personen in Produktionshalle. + // Laserscanner, Not-Aus, Geschwindigkeitsbegrenzung. + // --------------------------------------------------------------- + { + name: "AGV (Fahrerloses Transportsystem)", + description: "Autonomes Fahrzeug, Kollisionsgefahr mit Personen", + severity: 4, // Schwere Verletzung + exposure: 4, // Dauerhaft Personen in der Naehe + probability: 3, // Moeglich in offener Umgebung + controlMaturity: 3, + controlCoverage: 0.7, + testEvidence: 0.6, + hasJustification: true, + // Inherent: 4*4*3 = 48 + // C_eff: 0.2*(3/4) + 0.5*0.7 + 0.3*0.6 = 0.15 + 0.35 + 0.18 = 0.68 + // Residual: 48 * 0.32 = 15.36 + // Level: medium (>=15, <40) + // Acceptable: 15.36 >= 15, allReduction=false → NOT acceptable + expectedInherentRisk: 48, + expectedCEff: 0.68, + expectedResidualRisk: 48 * (1 - 0.68), + expectedRiskLevel: RiskLevelMedium, + expectedAcceptable: false, // ComputeRisk sets allReductionStepsApplied=false + }, + // --------------------------------------------------------------- + // 7. Abfuellanlage fuer Chemikalien + // Kontakt mit gefaehrlichem Medium. + // Geschlossene Leitungen, Leckageerkennung. + // --------------------------------------------------------------- + { + name: "Abfuellanlage Chemikalien", + description: "Chemikalien-Abfuellung, Kontaktgefahr", + severity: 4, // Schwere Veraetzung + exposure: 2, // Gelegentlicher Zugang + probability: 2, // Selten bei geschlossenen Leitungen + controlMaturity: 3, + controlCoverage: 0.85, + testEvidence: 0.7, + hasJustification: false, + // Inherent: 4*2*2 = 16 + // C_eff: 0.2*(3/4) + 0.5*0.85 + 0.3*0.7 = 0.15 + 0.425 + 0.21 = 0.785 + // Residual: 16 * 0.215 = 3.44 + // Level: negligible (<5) + // Acceptable: 3.44 < 15 → yes + expectedInherentRisk: 16, + expectedCEff: 0.785, + expectedResidualRisk: 16 * (1 - 0.785), + expectedRiskLevel: RiskLevelNegligible, + expectedAcceptable: true, + }, + // --------------------------------------------------------------- + // 8. Industrie-3D-Drucker (Metallpulver) + // Feinstaub-Inhalation, Explosionsgefahr. + // Absauganlage, Explosionsschutz. + // --------------------------------------------------------------- + { + name: "Industrie-3D-Drucker Metallpulver", + description: "Metallpulver-Drucker, Feinstaub und ATEX", + severity: 4, // Schwere Lungenschaeden / Explosion + exposure: 3, // Regelmaessig (Druckjob-Wechsel) + probability: 3, // Moeglich + controlMaturity: 2, + controlCoverage: 0.6, + testEvidence: 0.5, + hasJustification: true, + // Inherent: 4*3*3 = 36 + // C_eff: 0.2*(2/4) + 0.5*0.6 + 0.3*0.5 = 0.1 + 0.3 + 0.15 = 0.55 + // Residual: 36 * 0.45 = 16.2 + // Level: medium (>=15, <40) + // Acceptable: 16.2 >= 15, allReduction=false → NOT acceptable + expectedInherentRisk: 36, + expectedCEff: 0.55, + expectedResidualRisk: 36 * (1 - 0.55), + expectedRiskLevel: RiskLevelMedium, + expectedAcceptable: false, + }, + // --------------------------------------------------------------- + // 9. Automatisches Hochregallager + // Absturz von Lasten, Regalbediengeraet. + // Lastsicherung, Sensoren, regelmaessige Wartung. + // --------------------------------------------------------------- + { + name: "Hochregallager", + description: "Automatisches Regallager, Lastabsturzgefahr", + severity: 5, // Toedlich bei Absturz schwerer Lasten + exposure: 2, // Selten (automatisiert, Wartungszugang) + probability: 2, // Selten bei ordnungsgemaessem Betrieb + controlMaturity: 3, + controlCoverage: 0.8, + testEvidence: 0.7, + hasJustification: false, + // Inherent: 5*2*2 = 20 + // C_eff: 0.2*(3/4) + 0.5*0.8 + 0.3*0.7 = 0.15 + 0.4 + 0.21 = 0.76 + // Residual: 20 * 0.24 = 4.8 + // Level: negligible (<5) + // Acceptable: 4.8 < 15 → yes + expectedInherentRisk: 20, + expectedCEff: 0.76, + expectedResidualRisk: 20 * (1 - 0.76), + expectedRiskLevel: RiskLevelNegligible, + expectedAcceptable: true, + }, + // --------------------------------------------------------------- + // 10. KI-Bildverarbeitung Qualitaetskontrolle + // Fehlklassifikation → sicherheitsrelevantes Bauteil wird freigegeben. + // Redundante Pruefung, Validierungsdatensatz, KI-Risikobeurteilung. + // AI Act relevant. + // --------------------------------------------------------------- + { + name: "KI-Qualitaetskontrolle", + description: "KI-Vision fuer Bauteilpruefung, Fehlklassifikationsgefahr (AI Act relevant)", + severity: 4, // Fehlerhaftes Sicherheitsbauteil im Feld + exposure: 5, // Kontinuierlich (jedes Bauteil) + probability: 3, // Moeglich (ML nie 100% korrekt) + controlMaturity: 2, + controlCoverage: 0.5, + testEvidence: 0.6, + hasJustification: true, + // Inherent: 4*5*3 = 60 + // C_eff: 0.2*(2/4) + 0.5*0.5 + 0.3*0.6 = 0.1 + 0.25 + 0.18 = 0.53 + // Residual: 60 * 0.47 = 28.2 + // Level: medium (>=15, <40) + // Acceptable: 28.2 >= 15, allReduction=false → NOT acceptable + expectedInherentRisk: 60, + expectedCEff: 0.53, + expectedResidualRisk: 60 * (1 - 0.53), + expectedRiskLevel: RiskLevelMedium, + expectedAcceptable: false, + }, + } +} + +func TestReferenceMachines_ComputeRisk(t *testing.T) { + engine := NewRiskEngine() + machines := getReferenceMachines() + + for _, m := range machines { + t.Run(m.name, func(t *testing.T) { + input := RiskComputeInput{ + Severity: m.severity, + Exposure: m.exposure, + Probability: m.probability, + ControlMaturity: m.controlMaturity, + ControlCoverage: m.controlCoverage, + TestEvidence: m.testEvidence, + HasJustification: m.hasJustification, + } + + result, err := engine.ComputeRisk(input) + if err != nil { + t.Fatalf("ComputeRisk returned error: %v", err) + } + + // Verify inherent risk + if !almostEqual(result.InherentRisk, m.expectedInherentRisk) { + t.Errorf("InherentRisk = %v, want %v", result.InherentRisk, m.expectedInherentRisk) + } + + // Verify control effectiveness + if !almostEqual(result.ControlEffectiveness, m.expectedCEff) { + t.Errorf("ControlEffectiveness = %v, want %v", result.ControlEffectiveness, m.expectedCEff) + } + + // Verify residual risk + if !almostEqual(result.ResidualRisk, m.expectedResidualRisk) { + t.Errorf("ResidualRisk = %v, want %v (diff: %v)", + result.ResidualRisk, m.expectedResidualRisk, + math.Abs(result.ResidualRisk-m.expectedResidualRisk)) + } + + // Verify risk level + if result.RiskLevel != m.expectedRiskLevel { + t.Errorf("RiskLevel = %v, want %v (residual=%v)", + result.RiskLevel, m.expectedRiskLevel, result.ResidualRisk) + } + + // Verify acceptability + if result.IsAcceptable != m.expectedAcceptable { + t.Errorf("IsAcceptable = %v, want %v (residual=%v, level=%v)", + result.IsAcceptable, m.expectedAcceptable, result.ResidualRisk, result.RiskLevel) + } + }) + } +} + +// TestReferenceMachines_InherentRiskDistribution verifies that the 10 machines +// cover a meaningful range of inherent risk values (not all clustered). +func TestReferenceMachines_InherentRiskDistribution(t *testing.T) { + machines := getReferenceMachines() + + var minRisk, maxRisk float64 + minRisk = 999 + levelCounts := map[RiskLevel]int{} + + for _, m := range machines { + if m.expectedInherentRisk < minRisk { + minRisk = m.expectedInherentRisk + } + if m.expectedInherentRisk > maxRisk { + maxRisk = m.expectedInherentRisk + } + levelCounts[m.expectedRiskLevel]++ + } + + // Should span a meaningful range + if maxRisk-minRisk < 40 { + t.Errorf("Inherent risk range too narrow: [%v, %v], want spread >= 40", minRisk, maxRisk) + } + + // Should cover at least 3 different risk levels + if len(levelCounts) < 3 { + t.Errorf("Only %d risk levels covered, want at least 3: %v", len(levelCounts), levelCounts) + } +} + +// TestReferenceMachines_AcceptabilityMix verifies that the test suite has both +// acceptable and unacceptable outcomes. +func TestReferenceMachines_AcceptabilityMix(t *testing.T) { + machines := getReferenceMachines() + + acceptableCount := 0 + unacceptableCount := 0 + for _, m := range machines { + if m.expectedAcceptable { + acceptableCount++ + } else { + unacceptableCount++ + } + } + + if acceptableCount == 0 { + t.Error("No acceptable machines in test suite — need at least one") + } + if unacceptableCount == 0 { + t.Error("No unacceptable machines in test suite — need at least one") + } + + t.Logf("Acceptability mix: %d acceptable, %d unacceptable out of %d machines", + acceptableCount, unacceptableCount, len(machines)) +} + +// ============================================================================ +// 9. Edge Cases +// ============================================================================ + +func TestClamp(t *testing.T) { + tests := []struct { + name string + v, lo, hi int + expected int + }{ + {"in range", 3, 1, 5, 3}, + {"at low", 1, 1, 5, 1}, + {"at high", 5, 1, 5, 5}, + {"below low", 0, 1, 5, 1}, + {"above high", 10, 1, 5, 5}, + {"negative", -100, 0, 4, 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := clamp(tt.v, tt.lo, tt.hi) + if result != tt.expected { + t.Errorf("clamp(%d, %d, %d) = %d, want %d", tt.v, tt.lo, tt.hi, result, tt.expected) + } + }) + } +} + +func TestClampFloat(t *testing.T) { + tests := []struct { + name string + v, lo, hi float64 + expected float64 + }{ + {"in range", 0.5, 0, 1, 0.5}, + {"at low", 0, 0, 1, 0}, + {"at high", 1.0, 0, 1, 1.0}, + {"below low", -0.5, 0, 1, 0}, + {"above high", 2.5, 0, 1, 1.0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := clampFloat(tt.v, tt.lo, tt.hi) + if !almostEqual(result, tt.expected) { + t.Errorf("clampFloat(%v, %v, %v) = %v, want %v", tt.v, tt.lo, tt.hi, result, tt.expected) + } + }) + } +} diff --git a/ai-compliance-sdk/internal/iace/hazard_library.go b/ai-compliance-sdk/internal/iace/hazard_library.go new file mode 100644 index 0000000..ac3a5e6 --- /dev/null +++ b/ai-compliance-sdk/internal/iace/hazard_library.go @@ -0,0 +1,606 @@ +package iace + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/google/uuid" +) + +// hazardUUID generates a deterministic UUID for a hazard library entry +// based on category and a 1-based index within that category. +func hazardUUID(category string, index int) uuid.UUID { + name := fmt.Sprintf("iace.hazard.%s.%d", category, index) + return uuid.NewSHA1(uuid.NameSpaceDNS, []byte(name)) +} + +// mustMarshalJSON marshals the given value to json.RawMessage, panicking on error. +// This is safe to use for static data known at compile time. +func mustMarshalJSON(v interface{}) json.RawMessage { + data, err := json.Marshal(v) + if err != nil { + panic(fmt.Sprintf("hazard_library: failed to marshal JSON: %v", err)) + } + return data +} + +// GetBuiltinHazardLibrary returns the complete built-in hazard library with 40+ +// template entries for SW/FW/KI hazards in industrial machines. These entries are +// intended to be seeded into the iace_hazard_library table during initial setup. +// +// All entries have IsBuiltin=true and TenantID=nil (system-level templates). +// UUIDs are deterministic, generated via uuid.NewSHA1 based on category and index. +func GetBuiltinHazardLibrary() []HazardLibraryEntry { + now := time.Now() + + entries := []HazardLibraryEntry{ + // ==================================================================== + // Category: false_classification (4 entries) + // ==================================================================== + { + ID: hazardUUID("false_classification", 1), + Category: "false_classification", + Name: "Falsche Bauteil-Klassifikation durch KI", + Description: "Das KI-Modell klassifiziert ein Bauteil fehlerhaft, was zu falscher Weiterverarbeitung oder Montage fuehren kann.", + DefaultSeverity: 4, + DefaultProbability: 3, + ApplicableComponentTypes: []string{"ai_model", "sensor"}, + RegulationReferences: []string{"EU AI Act Art. 9", "Maschinenverordnung 2023/1230"}, + SuggestedMitigations: mustMarshalJSON([]string{"Redundante Pruefung", "Konfidenz-Schwellwert"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("false_classification", 2), + Category: "false_classification", + Name: "Falsche Qualitaetsentscheidung (IO/NIO)", + Description: "Fehlerhafte IO/NIO-Entscheidung durch das KI-System fuehrt dazu, dass defekte Teile als gut bewertet oder gute Teile verworfen werden.", + DefaultSeverity: 4, + DefaultProbability: 3, + ApplicableComponentTypes: []string{"ai_model", "software"}, + RegulationReferences: []string{"EU AI Act Art. 9", "Maschinenverordnung 2023/1230"}, + SuggestedMitigations: mustMarshalJSON([]string{"Human-in-the-Loop", "Stichproben-Gegenpruefung"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("false_classification", 3), + Category: "false_classification", + Name: "Fehlklassifikation bei Grenzwertfaellen", + Description: "Bauteile nahe an Toleranzgrenzen werden systematisch falsch klassifiziert, da das Modell in Grenzwertbereichen unsicher agiert.", + DefaultSeverity: 3, + DefaultProbability: 4, + ApplicableComponentTypes: []string{"ai_model"}, + RegulationReferences: []string{"EU AI Act Art. 9", "ISO 13849"}, + SuggestedMitigations: mustMarshalJSON([]string{"Erweitertes Training", "Grauzone-Eskalation"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("false_classification", 4), + Category: "false_classification", + Name: "Verwechslung von Bauteiltypen", + Description: "Unterschiedliche Bauteiltypen werden vom KI-Modell verwechselt, was zu falscher Montage oder Verarbeitung fuehrt.", + DefaultSeverity: 4, + DefaultProbability: 2, + ApplicableComponentTypes: []string{"ai_model", "sensor"}, + RegulationReferences: []string{"EU AI Act Art. 9", "Maschinenverordnung 2023/1230"}, + SuggestedMitigations: mustMarshalJSON([]string{"Barcode-Gegenpruefung", "Doppelte Sensorik"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + + // ==================================================================== + // Category: timing_error (3 entries) + // ==================================================================== + { + ID: hazardUUID("timing_error", 1), + Category: "timing_error", + Name: "Verzoegerte KI-Reaktion in Echtzeitsystem", + Description: "Die KI-Inferenz dauert laenger als die zulaessige Echtzeitfrist, was zu verspaeteten Sicherheitsreaktionen fuehrt.", + DefaultSeverity: 5, + DefaultProbability: 2, + ApplicableComponentTypes: []string{"software", "ai_model"}, + RegulationReferences: []string{"Maschinenverordnung 2023/1230", "ISO 13849", "IEC 62443"}, + SuggestedMitigations: mustMarshalJSON([]string{"Watchdog-Timer", "Fallback-Steuerung"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("timing_error", 2), + Category: "timing_error", + Name: "Echtzeit-Verletzung Safety-Loop", + Description: "Der sicherheitsgerichtete Regelkreis kann die geforderten Zykluszeiten nicht einhalten, wodurch Sicherheitsfunktionen versagen koennen.", + DefaultSeverity: 5, + DefaultProbability: 2, + ApplicableComponentTypes: []string{"software", "firmware"}, + RegulationReferences: []string{"ISO 13849", "IEC 61508", "Maschinenverordnung 2023/1230"}, + SuggestedMitigations: mustMarshalJSON([]string{"Deterministische Ausfuehrung", "WCET-Analyse"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("timing_error", 3), + Category: "timing_error", + Name: "Timing-Jitter bei Netzwerkkommunikation", + Description: "Schwankende Netzwerklatenzen fuehren zu unvorhersehbaren Verzoegerungen in der Datenuebertragung sicherheitsrelevanter Signale.", + DefaultSeverity: 3, + DefaultProbability: 3, + ApplicableComponentTypes: []string{"network", "software"}, + RegulationReferences: []string{"IEC 62443", "Maschinenverordnung 2023/1230"}, + SuggestedMitigations: mustMarshalJSON([]string{"TSN-Netzwerk", "Pufferung"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + + // ==================================================================== + // Category: data_poisoning (2 entries) + // ==================================================================== + { + ID: hazardUUID("data_poisoning", 1), + Category: "data_poisoning", + Name: "Manipulierte Trainingsdaten", + Description: "Trainingsdaten werden absichtlich oder unbeabsichtigt manipuliert, wodurch das Modell systematisch fehlerhafte Entscheidungen trifft.", + DefaultSeverity: 4, + DefaultProbability: 2, + ApplicableComponentTypes: []string{"ai_model"}, + RegulationReferences: []string{"EU AI Act Art. 10", "CRA"}, + SuggestedMitigations: mustMarshalJSON([]string{"Daten-Validierung", "Anomalie-Erkennung"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("data_poisoning", 2), + Category: "data_poisoning", + Name: "Adversarial Input Angriff", + Description: "Gezielte Manipulation von Eingabedaten (z.B. Bilder, Sensorsignale), um das KI-Modell zu taeuschen und Fehlentscheidungen auszuloesen.", + DefaultSeverity: 4, + DefaultProbability: 2, + ApplicableComponentTypes: []string{"ai_model", "sensor"}, + RegulationReferences: []string{"EU AI Act Art. 15", "CRA", "IEC 62443"}, + SuggestedMitigations: mustMarshalJSON([]string{"Input-Validation", "Adversarial Training"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + + // ==================================================================== + // Category: model_drift (3 entries) + // ==================================================================== + { + ID: hazardUUID("model_drift", 1), + Category: "model_drift", + Name: "Performance-Degradation durch Concept Drift", + Description: "Die statistische Verteilung der Eingabedaten aendert sich ueber die Zeit, wodurch die Modellgenauigkeit schleichend abnimmt.", + DefaultSeverity: 3, + DefaultProbability: 4, + ApplicableComponentTypes: []string{"ai_model"}, + RegulationReferences: []string{"EU AI Act Art. 9", "EU AI Act Art. 72"}, + SuggestedMitigations: mustMarshalJSON([]string{"Monitoring-Dashboard", "Automatisches Retraining"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("model_drift", 2), + Category: "model_drift", + Name: "Data Drift durch veraenderte Umgebung", + Description: "Aenderungen in der physischen Umgebung (Beleuchtung, Temperatur, Material) fuehren zu veraenderten Sensordaten und Modellfehlern.", + DefaultSeverity: 3, + DefaultProbability: 4, + ApplicableComponentTypes: []string{"ai_model", "sensor"}, + RegulationReferences: []string{"EU AI Act Art. 9", "Maschinenverordnung 2023/1230"}, + SuggestedMitigations: mustMarshalJSON([]string{"Statistische Ueberwachung", "Sensor-Kalibrierung"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("model_drift", 3), + Category: "model_drift", + Name: "Schleichende Modell-Verschlechterung", + Description: "Ohne aktives Monitoring verschlechtert sich die Modellqualitaet ueber Wochen oder Monate unbemerkt.", + DefaultSeverity: 3, + DefaultProbability: 3, + ApplicableComponentTypes: []string{"ai_model"}, + RegulationReferences: []string{"EU AI Act Art. 9", "EU AI Act Art. 72"}, + SuggestedMitigations: mustMarshalJSON([]string{"Regelmaessige Evaluierung", "A/B-Testing"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + + // ==================================================================== + // Category: sensor_spoofing (3 entries) + // ==================================================================== + { + ID: hazardUUID("sensor_spoofing", 1), + Category: "sensor_spoofing", + Name: "Kamera-Manipulation / Abdeckung", + Description: "Kamerasensoren werden absichtlich oder unbeabsichtigt abgedeckt oder manipuliert, sodass das System auf Basis falscher Bilddaten agiert.", + DefaultSeverity: 4, + DefaultProbability: 2, + ApplicableComponentTypes: []string{"sensor"}, + RegulationReferences: []string{"IEC 62443", "Maschinenverordnung 2023/1230"}, + SuggestedMitigations: mustMarshalJSON([]string{"Plausibilitaetspruefung", "Mehrfach-Sensorik"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("sensor_spoofing", 2), + Category: "sensor_spoofing", + Name: "Sensor-Signal-Injection", + Description: "Einspeisung gefaelschter Signale in die Sensorleitungen oder Schnittstellen, um das System gezielt zu manipulieren.", + DefaultSeverity: 5, + DefaultProbability: 1, + ApplicableComponentTypes: []string{"sensor", "network"}, + RegulationReferences: []string{"IEC 62443", "CRA"}, + SuggestedMitigations: mustMarshalJSON([]string{"Signalverschluesselung", "Anomalie-Erkennung"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("sensor_spoofing", 3), + Category: "sensor_spoofing", + Name: "Umgebungsbasierte Sensor-Taeuschung", + Description: "Natuerliche oder kuenstliche Umgebungsveraenderungen (Licht, Staub, Vibration) fuehren zu fehlerhaften Sensorwerten.", + DefaultSeverity: 3, + DefaultProbability: 3, + ApplicableComponentTypes: []string{"sensor"}, + RegulationReferences: []string{"Maschinenverordnung 2023/1230", "ISO 13849"}, + SuggestedMitigations: mustMarshalJSON([]string{"Sensor-Fusion", "Redundanz"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + + // ==================================================================== + // Category: communication_failure (3 entries) + // ==================================================================== + { + ID: hazardUUID("communication_failure", 1), + Category: "communication_failure", + Name: "Feldbus-Ausfall", + Description: "Ausfall des industriellen Feldbusses (z.B. PROFINET, EtherCAT) fuehrt zum Verlust der Kommunikation zwischen Steuerung und Aktorik.", + DefaultSeverity: 4, + DefaultProbability: 3, + ApplicableComponentTypes: []string{"network", "controller"}, + RegulationReferences: []string{"Maschinenverordnung 2023/1230", "ISO 13849", "IEC 62443"}, + SuggestedMitigations: mustMarshalJSON([]string{"Redundanter Bus", "Safe-State-Transition"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("communication_failure", 2), + Category: "communication_failure", + Name: "Cloud-Verbindungsverlust", + Description: "Die Verbindung zur Cloud-Infrastruktur bricht ab, wodurch cloud-abhaengige Funktionen (z.B. Modell-Updates, Monitoring) nicht verfuegbar sind.", + DefaultSeverity: 3, + DefaultProbability: 4, + ApplicableComponentTypes: []string{"network", "software"}, + RegulationReferences: []string{"CRA", "EU AI Act Art. 15"}, + SuggestedMitigations: mustMarshalJSON([]string{"Offline-Faehigkeit", "Edge-Computing"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("communication_failure", 3), + Category: "communication_failure", + Name: "Netzwerk-Latenz-Spitzen", + Description: "Unkontrollierte Latenzspitzen im Netzwerk fuehren zu Timeouts und verspaeteter Datenlieferung an sicherheitsrelevante Systeme.", + DefaultSeverity: 3, + DefaultProbability: 3, + ApplicableComponentTypes: []string{"network"}, + RegulationReferences: []string{"IEC 62443", "Maschinenverordnung 2023/1230"}, + SuggestedMitigations: mustMarshalJSON([]string{"QoS-Konfiguration", "Timeout-Handling"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + + // ==================================================================== + // Category: unauthorized_access (4 entries) + // ==================================================================== + { + ID: hazardUUID("unauthorized_access", 1), + Category: "unauthorized_access", + Name: "Unautorisierter Remote-Zugriff", + Description: "Ein Angreifer erlangt ueber das Netzwerk Zugriff auf die Maschinensteuerung und kann sicherheitsrelevante Parameter aendern.", + DefaultSeverity: 5, + DefaultProbability: 2, + ApplicableComponentTypes: []string{"network", "software"}, + RegulationReferences: []string{"IEC 62443", "CRA", "EU AI Act Art. 15"}, + SuggestedMitigations: mustMarshalJSON([]string{"VPN", "MFA", "Netzwerksegmentierung"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("unauthorized_access", 2), + Category: "unauthorized_access", + Name: "Konfigurations-Manipulation", + Description: "Sicherheitsrelevante Konfigurationsparameter werden unautorisiert geaendert, z.B. Grenzwerte, Schwellwerte oder Betriebsmodi.", + DefaultSeverity: 5, + DefaultProbability: 2, + ApplicableComponentTypes: []string{"software", "firmware"}, + RegulationReferences: []string{"IEC 62443", "CRA", "Maschinenverordnung 2023/1230"}, + SuggestedMitigations: mustMarshalJSON([]string{"Zugriffskontrolle", "Audit-Log"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("unauthorized_access", 3), + Category: "unauthorized_access", + Name: "Privilege Escalation", + Description: "Ein Benutzer oder Prozess erlangt hoehere Berechtigungen als vorgesehen und kann sicherheitskritische Aktionen ausfuehren.", + DefaultSeverity: 5, + DefaultProbability: 1, + ApplicableComponentTypes: []string{"software"}, + RegulationReferences: []string{"IEC 62443", "CRA"}, + SuggestedMitigations: mustMarshalJSON([]string{"RBAC", "Least Privilege"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("unauthorized_access", 4), + Category: "unauthorized_access", + Name: "Supply-Chain-Angriff auf Komponente", + Description: "Eine kompromittierte Softwarekomponente oder Firmware wird ueber die Lieferkette eingeschleust und enthaelt Schadcode oder Backdoors.", + DefaultSeverity: 5, + DefaultProbability: 1, + ApplicableComponentTypes: []string{"software", "firmware"}, + RegulationReferences: []string{"CRA", "IEC 62443", "EU AI Act Art. 15"}, + SuggestedMitigations: mustMarshalJSON([]string{"SBOM", "Signaturpruefung"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + + // ==================================================================== + // Category: firmware_corruption (3 entries) + // ==================================================================== + { + ID: hazardUUID("firmware_corruption", 1), + Category: "firmware_corruption", + Name: "Update-Abbruch mit inkonsistentem Zustand", + Description: "Ein Firmware-Update wird unterbrochen (z.B. Stromausfall), wodurch das System in einem inkonsistenten und potenziell unsicheren Zustand verbleibt.", + DefaultSeverity: 5, + DefaultProbability: 2, + ApplicableComponentTypes: []string{"firmware"}, + RegulationReferences: []string{"CRA", "Maschinenverordnung 2023/1230"}, + SuggestedMitigations: mustMarshalJSON([]string{"A/B-Partitioning", "Rollback"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("firmware_corruption", 2), + Category: "firmware_corruption", + Name: "Rollback-Fehler auf alte Version", + Description: "Ein Rollback auf eine aeltere Firmware-Version schlaegt fehl oder fuehrt zu Inkompatibilitaeten mit der aktuellen Hardware-/Softwarekonfiguration.", + DefaultSeverity: 4, + DefaultProbability: 2, + ApplicableComponentTypes: []string{"firmware"}, + RegulationReferences: []string{"CRA", "Maschinenverordnung 2023/1230"}, + SuggestedMitigations: mustMarshalJSON([]string{"Versionsmanagement", "Kompatibilitaetspruefung"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("firmware_corruption", 3), + Category: "firmware_corruption", + Name: "Boot-Chain-Angriff", + Description: "Die Bootsequenz wird manipuliert, um unsignierte oder kompromittierte Firmware auszufuehren, was die gesamte Sicherheitsarchitektur untergaebt.", + DefaultSeverity: 5, + DefaultProbability: 1, + ApplicableComponentTypes: []string{"firmware"}, + RegulationReferences: []string{"CRA", "IEC 62443"}, + SuggestedMitigations: mustMarshalJSON([]string{"Secure Boot", "TPM"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + + // ==================================================================== + // Category: safety_boundary_violation (4 entries) + // ==================================================================== + { + ID: hazardUUID("safety_boundary_violation", 1), + Category: "safety_boundary_violation", + Name: "Kraft-/Drehmoment-Ueberschreitung", + Description: "Aktorische Systeme ueberschreiten die zulaessigen Kraft- oder Drehmomentwerte, was zu Verletzungen oder Maschinenschaeden fuehren kann.", + DefaultSeverity: 5, + DefaultProbability: 2, + ApplicableComponentTypes: []string{"controller", "actuator"}, + RegulationReferences: []string{"Maschinenverordnung 2023/1230", "ISO 13849", "IEC 62061"}, + SuggestedMitigations: mustMarshalJSON([]string{"Hardware-Limiter", "SIL-Ueberwachung"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("safety_boundary_violation", 2), + Category: "safety_boundary_violation", + Name: "Geschwindigkeitsueberschreitung Roboter", + Description: "Ein Industrieroboter ueberschreitet die zulaessige Geschwindigkeit, insbesondere bei Mensch-Roboter-Kollaboration.", + DefaultSeverity: 5, + DefaultProbability: 2, + ApplicableComponentTypes: []string{"controller", "software"}, + RegulationReferences: []string{"Maschinenverordnung 2023/1230", "ISO 13849", "ISO 10218"}, + SuggestedMitigations: mustMarshalJSON([]string{"Safe Speed Monitoring", "Lichtgitter"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("safety_boundary_violation", 3), + Category: "safety_boundary_violation", + Name: "Versagen des Safe-State", + Description: "Das System kann im Fehlerfall keinen sicheren Zustand einnehmen, da die Sicherheitssteuerung selbst versagt.", + DefaultSeverity: 5, + DefaultProbability: 1, + ApplicableComponentTypes: []string{"controller", "software", "firmware"}, + RegulationReferences: []string{"Maschinenverordnung 2023/1230", "ISO 13849", "IEC 62061"}, + SuggestedMitigations: mustMarshalJSON([]string{"Redundante Sicherheitssteuerung", "Diverse Programmierung"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("safety_boundary_violation", 4), + Category: "safety_boundary_violation", + Name: "Arbeitsraum-Verletzung", + Description: "Ein Roboter oder Aktor verlaesst seinen definierten Arbeitsraum und dringt in den Schutzbereich von Personen ein.", + DefaultSeverity: 5, + DefaultProbability: 2, + ApplicableComponentTypes: []string{"controller", "sensor"}, + RegulationReferences: []string{"Maschinenverordnung 2023/1230", "ISO 13849", "ISO 10218"}, + SuggestedMitigations: mustMarshalJSON([]string{"Sichere Achsueberwachung", "Schutzzaun-Sensorik"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + + // ==================================================================== + // Category: mode_confusion (3 entries) + // ==================================================================== + { + ID: hazardUUID("mode_confusion", 1), + Category: "mode_confusion", + Name: "Falsche Betriebsart aktiv", + Description: "Das System befindet sich in einer unbeabsichtigten Betriebsart (z.B. Automatik statt Einrichtbetrieb), was zu unerwarteten Maschinenbewegungen fuehrt.", + DefaultSeverity: 4, + DefaultProbability: 3, + ApplicableComponentTypes: []string{"hmi", "software"}, + RegulationReferences: []string{"Maschinenverordnung 2023/1230", "ISO 13849"}, + SuggestedMitigations: mustMarshalJSON([]string{"Betriebsart-Anzeige", "Schluesselschalter"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("mode_confusion", 2), + Category: "mode_confusion", + Name: "Wartung/Normal-Verwechslung", + Description: "Das System wird im Normalbetrieb gewartet oder der Wartungsmodus wird nicht korrekt verlassen, was zu gefaehrlichen Situationen fuehrt.", + DefaultSeverity: 5, + DefaultProbability: 2, + ApplicableComponentTypes: []string{"hmi", "software"}, + RegulationReferences: []string{"Maschinenverordnung 2023/1230", "ISO 13849"}, + SuggestedMitigations: mustMarshalJSON([]string{"Zugangskontrolle", "Sicherheitsverriegelung"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("mode_confusion", 3), + Category: "mode_confusion", + Name: "Automatik-Eingriff waehrend Handbetrieb", + Description: "Das System wechselt waehrend des Handbetriebs unerwartet in den Automatikbetrieb, wodurch eine Person im Gefahrenbereich verletzt werden kann.", + DefaultSeverity: 5, + DefaultProbability: 2, + ApplicableComponentTypes: []string{"software", "controller"}, + RegulationReferences: []string{"Maschinenverordnung 2023/1230", "ISO 13849"}, + SuggestedMitigations: mustMarshalJSON([]string{"Exklusive Betriebsarten", "Zustimmtaster"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + + // ==================================================================== + // Category: unintended_bias (2 entries) + // ==================================================================== + { + ID: hazardUUID("unintended_bias", 1), + Category: "unintended_bias", + Name: "Diskriminierende KI-Entscheidung", + Description: "Das KI-Modell trifft systematisch diskriminierende Entscheidungen, z.B. bei der Qualitaetsbewertung bestimmter Produktchargen oder Lieferanten.", + DefaultSeverity: 3, + DefaultProbability: 2, + ApplicableComponentTypes: []string{"ai_model"}, + RegulationReferences: []string{"EU AI Act Art. 10", "EU AI Act Art. 71"}, + SuggestedMitigations: mustMarshalJSON([]string{"Bias-Testing", "Fairness-Metriken"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("unintended_bias", 2), + Category: "unintended_bias", + Name: "Verzerrte Trainingsdaten", + Description: "Die Trainingsdaten sind nicht repraesentativ und enthalten systematische Verzerrungen, die zu unfairen oder fehlerhaften Modellergebnissen fuehren.", + DefaultSeverity: 3, + DefaultProbability: 3, + ApplicableComponentTypes: []string{"ai_model"}, + RegulationReferences: []string{"EU AI Act Art. 10", "EU AI Act Art. 71"}, + SuggestedMitigations: mustMarshalJSON([]string{"Datensatz-Audit", "Ausgewogenes Sampling"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + + // ==================================================================== + // Category: update_failure (3 entries) + // ==================================================================== + { + ID: hazardUUID("update_failure", 1), + Category: "update_failure", + Name: "Unvollstaendiges OTA-Update", + Description: "Ein Over-the-Air-Update wird nur teilweise uebertragen oder angewendet, wodurch das System in einem inkonsistenten Zustand verbleibt.", + DefaultSeverity: 4, + DefaultProbability: 3, + ApplicableComponentTypes: []string{"firmware", "software"}, + RegulationReferences: []string{"CRA", "Maschinenverordnung 2023/1230"}, + SuggestedMitigations: mustMarshalJSON([]string{"Atomare Updates", "Integritaetspruefung"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("update_failure", 2), + Category: "update_failure", + Name: "Versionskonflikt nach Update", + Description: "Nach einem Update sind Software- und Firmware-Versionen inkompatibel, was zu Fehlfunktionen oder Ausfaellen fuehrt.", + DefaultSeverity: 3, + DefaultProbability: 3, + ApplicableComponentTypes: []string{"software", "firmware"}, + RegulationReferences: []string{"CRA", "Maschinenverordnung 2023/1230"}, + SuggestedMitigations: mustMarshalJSON([]string{"Kompatibilitaetsmatrix", "Staging-Tests"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + { + ID: hazardUUID("update_failure", 3), + Category: "update_failure", + Name: "Unkontrollierter Auto-Update", + Description: "Ein automatisches Update wird ohne Genehmigung oder ausserhalb eines Wartungsfensters eingespielt und stoert den laufenden Betrieb.", + DefaultSeverity: 4, + DefaultProbability: 2, + ApplicableComponentTypes: []string{"software"}, + RegulationReferences: []string{"CRA", "Maschinenverordnung 2023/1230", "IEC 62443"}, + SuggestedMitigations: mustMarshalJSON([]string{"Update-Genehmigung", "Wartungsfenster"}), + IsBuiltin: true, + TenantID: nil, + CreatedAt: now, + }, + } + + return entries +} diff --git a/ai-compliance-sdk/internal/iace/hazard_library_test.go b/ai-compliance-sdk/internal/iace/hazard_library_test.go new file mode 100644 index 0000000..bf10d60 --- /dev/null +++ b/ai-compliance-sdk/internal/iace/hazard_library_test.go @@ -0,0 +1,293 @@ +package iace + +import ( + "testing" + + "github.com/google/uuid" +) + +func TestGetBuiltinHazardLibrary_EntryCount(t *testing.T) { + entries := GetBuiltinHazardLibrary() + + // Expected: 4+3+2+3+3+3+4+3+4+3+2+3 = 37 + if len(entries) != 37 { + t.Fatalf("GetBuiltinHazardLibrary returned %d entries, want 37", len(entries)) + } +} + +func TestGetBuiltinHazardLibrary_AllBuiltinAndNoTenant(t *testing.T) { + entries := GetBuiltinHazardLibrary() + + for i, e := range entries { + if !e.IsBuiltin { + t.Errorf("entries[%d] (%s): IsBuiltin = false, want true", i, e.Name) + } + if e.TenantID != nil { + t.Errorf("entries[%d] (%s): TenantID = %v, want nil", i, e.Name, e.TenantID) + } + } +} + +func TestGetBuiltinHazardLibrary_UniqueNonZeroUUIDs(t *testing.T) { + entries := GetBuiltinHazardLibrary() + seen := make(map[uuid.UUID]string) + + for i, e := range entries { + if e.ID == uuid.Nil { + t.Errorf("entries[%d] (%s): ID is zero UUID", i, e.Name) + } + if prev, exists := seen[e.ID]; exists { + t.Errorf("entries[%d] (%s): duplicate UUID %s, same as %q", i, e.Name, e.ID, prev) + } + seen[e.ID] = e.Name + } +} + +func TestGetBuiltinHazardLibrary_AllCategoriesPresent(t *testing.T) { + entries := GetBuiltinHazardLibrary() + + expectedCategories := map[string]bool{ + "false_classification": false, + "timing_error": false, + "data_poisoning": false, + "model_drift": false, + "sensor_spoofing": false, + "communication_failure": false, + "unauthorized_access": false, + "firmware_corruption": false, + "safety_boundary_violation": false, + "mode_confusion": false, + "unintended_bias": false, + "update_failure": false, + } + + for _, e := range entries { + if _, ok := expectedCategories[e.Category]; !ok { + t.Errorf("unexpected category %q in entry %q", e.Category, e.Name) + } + expectedCategories[e.Category] = true + } + + for cat, found := range expectedCategories { + if !found { + t.Errorf("expected category %q not found in any entry", cat) + } + } +} + +func TestGetBuiltinHazardLibrary_CategoryCounts(t *testing.T) { + entries := GetBuiltinHazardLibrary() + + expectedCounts := map[string]int{ + "false_classification": 4, + "timing_error": 3, + "data_poisoning": 2, + "model_drift": 3, + "sensor_spoofing": 3, + "communication_failure": 3, + "unauthorized_access": 4, + "firmware_corruption": 3, + "safety_boundary_violation": 4, + "mode_confusion": 3, + "unintended_bias": 2, + "update_failure": 3, + } + + actualCounts := make(map[string]int) + for _, e := range entries { + actualCounts[e.Category]++ + } + + for cat, expected := range expectedCounts { + if actualCounts[cat] != expected { + t.Errorf("category %q: count = %d, want %d", cat, actualCounts[cat], expected) + } + } +} + +func TestGetBuiltinHazardLibrary_SeverityRange(t *testing.T) { + entries := GetBuiltinHazardLibrary() + + for i, e := range entries { + if e.DefaultSeverity < 1 || e.DefaultSeverity > 5 { + t.Errorf("entries[%d] (%s): DefaultSeverity = %d, want 1-5", i, e.Name, e.DefaultSeverity) + } + } +} + +func TestGetBuiltinHazardLibrary_ProbabilityRange(t *testing.T) { + entries := GetBuiltinHazardLibrary() + + for i, e := range entries { + if e.DefaultProbability < 1 || e.DefaultProbability > 5 { + t.Errorf("entries[%d] (%s): DefaultProbability = %d, want 1-5", i, e.Name, e.DefaultProbability) + } + } +} + +func TestGetBuiltinHazardLibrary_NonEmptyFields(t *testing.T) { + entries := GetBuiltinHazardLibrary() + + for i, e := range entries { + if e.Name == "" { + t.Errorf("entries[%d]: Name is empty", i) + } + if e.Category == "" { + t.Errorf("entries[%d] (%s): Category is empty", i, e.Name) + } + if len(e.ApplicableComponentTypes) == 0 { + t.Errorf("entries[%d] (%s): ApplicableComponentTypes is empty", i, e.Name) + } + if len(e.RegulationReferences) == 0 { + t.Errorf("entries[%d] (%s): RegulationReferences is empty", i, e.Name) + } + } +} + +func TestHazardUUID_Deterministic(t *testing.T) { + tests := []struct { + category string + index int + }{ + {"false_classification", 1}, + {"timing_error", 2}, + {"data_poisoning", 1}, + {"update_failure", 3}, + {"mode_confusion", 1}, + } + + for _, tt := range tests { + t.Run(tt.category, func(t *testing.T) { + id1 := hazardUUID(tt.category, tt.index) + id2 := hazardUUID(tt.category, tt.index) + + if id1 != id2 { + t.Errorf("hazardUUID(%q, %d) not deterministic: %s != %s", tt.category, tt.index, id1, id2) + } + if id1 == uuid.Nil { + t.Errorf("hazardUUID(%q, %d) returned zero UUID", tt.category, tt.index) + } + }) + } +} + +func TestHazardUUID_DifferentInputsProduceDifferentUUIDs(t *testing.T) { + tests := []struct { + cat1 string + idx1 int + cat2 string + idx2 int + }{ + {"false_classification", 1, "false_classification", 2}, + {"false_classification", 1, "timing_error", 1}, + {"data_poisoning", 1, "data_poisoning", 2}, + {"mode_confusion", 1, "mode_confusion", 3}, + } + + for _, tt := range tests { + name := tt.cat1 + "_" + tt.cat2 + t.Run(name, func(t *testing.T) { + id1 := hazardUUID(tt.cat1, tt.idx1) + id2 := hazardUUID(tt.cat2, tt.idx2) + + if id1 == id2 { + t.Errorf("hazardUUID(%q,%d) == hazardUUID(%q,%d): %s", tt.cat1, tt.idx1, tt.cat2, tt.idx2, id1) + } + }) + } +} + +func TestGetBuiltinHazardLibrary_CreatedAtSet(t *testing.T) { + entries := GetBuiltinHazardLibrary() + + for i, e := range entries { + if e.CreatedAt.IsZero() { + t.Errorf("entries[%d] (%s): CreatedAt is zero", i, e.Name) + } + } +} + +func TestGetBuiltinHazardLibrary_DescriptionPresent(t *testing.T) { + entries := GetBuiltinHazardLibrary() + + for i, e := range entries { + if e.Description == "" { + t.Errorf("entries[%d] (%s): Description is empty", i, e.Name) + } + } +} + +func TestGetBuiltinHazardLibrary_SuggestedMitigationsPresent(t *testing.T) { + entries := GetBuiltinHazardLibrary() + + for i, e := range entries { + if e.SuggestedMitigations == nil || len(e.SuggestedMitigations) == 0 { + t.Errorf("entries[%d] (%s): SuggestedMitigations is nil/empty", i, e.Name) + } + } +} + +func TestGetBuiltinHazardLibrary_ApplicableComponentTypesAreValid(t *testing.T) { + entries := GetBuiltinHazardLibrary() + + validTypes := map[string]bool{ + string(ComponentTypeSoftware): true, + string(ComponentTypeFirmware): true, + string(ComponentTypeAIModel): true, + string(ComponentTypeHMI): true, + string(ComponentTypeSensor): true, + string(ComponentTypeActuator): true, + string(ComponentTypeController): true, + string(ComponentTypeNetwork): true, + string(ComponentTypeOther): true, + } + + for i, e := range entries { + for _, ct := range e.ApplicableComponentTypes { + if !validTypes[ct] { + t.Errorf("entries[%d] (%s): invalid component type %q in ApplicableComponentTypes", i, e.Name, ct) + } + } + } +} + +func TestGetBuiltinHazardLibrary_UUIDsMatchExpected(t *testing.T) { + // Verify the first entry of each category has the expected UUID + // based on the deterministic hazardUUID function. + entries := GetBuiltinHazardLibrary() + + categoryFirstSeen := make(map[string]uuid.UUID) + for _, e := range entries { + if _, exists := categoryFirstSeen[e.Category]; !exists { + categoryFirstSeen[e.Category] = e.ID + } + } + + for cat, actualID := range categoryFirstSeen { + expectedID := hazardUUID(cat, 1) + if actualID != expectedID { + t.Errorf("first entry of category %q: ID = %s, want %s", cat, actualID, expectedID) + } + } +} + +func TestGetBuiltinHazardLibrary_ConsistentAcrossCalls(t *testing.T) { + entries1 := GetBuiltinHazardLibrary() + entries2 := GetBuiltinHazardLibrary() + + if len(entries1) != len(entries2) { + t.Fatalf("inconsistent lengths: %d vs %d", len(entries1), len(entries2)) + } + + for i := range entries1 { + if entries1[i].ID != entries2[i].ID { + t.Errorf("entries[%d]: ID mismatch across calls: %s vs %s", i, entries1[i].ID, entries2[i].ID) + } + if entries1[i].Name != entries2[i].Name { + t.Errorf("entries[%d]: Name mismatch across calls: %q vs %q", i, entries1[i].Name, entries2[i].Name) + } + if entries1[i].Category != entries2[i].Category { + t.Errorf("entries[%d]: Category mismatch across calls: %q vs %q", i, entries1[i].Category, entries2[i].Category) + } + } +} diff --git a/ai-compliance-sdk/internal/iace/models.go b/ai-compliance-sdk/internal/iace/models.go new file mode 100644 index 0000000..56e5b7a --- /dev/null +++ b/ai-compliance-sdk/internal/iace/models.go @@ -0,0 +1,485 @@ +package iace + +import ( + "encoding/json" + "time" + + "github.com/google/uuid" +) + +// ============================================================================ +// Constants / Enums +// ============================================================================ + +// ProjectStatus represents the lifecycle status of an IACE project +type ProjectStatus string + +const ( + ProjectStatusDraft ProjectStatus = "draft" + ProjectStatusOnboarding ProjectStatus = "onboarding" + ProjectStatusClassification ProjectStatus = "classification" + ProjectStatusHazardAnalysis ProjectStatus = "hazard_analysis" + ProjectStatusMitigation ProjectStatus = "mitigation" + ProjectStatusVerification ProjectStatus = "verification" + ProjectStatusTechFile ProjectStatus = "tech_file" + ProjectStatusCompleted ProjectStatus = "completed" + ProjectStatusArchived ProjectStatus = "archived" +) + +// ComponentType represents the type of a system component +type ComponentType string + +const ( + ComponentTypeSoftware ComponentType = "software" + ComponentTypeFirmware ComponentType = "firmware" + ComponentTypeAIModel ComponentType = "ai_model" + ComponentTypeHMI ComponentType = "hmi" + ComponentTypeSensor ComponentType = "sensor" + ComponentTypeActuator ComponentType = "actuator" + ComponentTypeController ComponentType = "controller" + ComponentTypeNetwork ComponentType = "network" + ComponentTypeOther ComponentType = "other" +) + +// RegulationType represents the applicable EU regulation +type RegulationType string + +const ( + RegulationNIS2 RegulationType = "nis2" + RegulationAIAct RegulationType = "ai_act" + RegulationCRA RegulationType = "cra" + RegulationMachineryRegulation RegulationType = "machinery_regulation" +) + +// HazardStatus represents the lifecycle status of a hazard +type HazardStatus string + +const ( + HazardStatusIdentified HazardStatus = "identified" + HazardStatusAssessed HazardStatus = "assessed" + HazardStatusMitigated HazardStatus = "mitigated" + HazardStatusAccepted HazardStatus = "accepted" + HazardStatusClosed HazardStatus = "closed" +) + +// AssessmentType represents the type of risk assessment +type AssessmentType string + +const ( + AssessmentTypeInitial AssessmentType = "initial" + AssessmentTypePostMitigation AssessmentType = "post_mitigation" + AssessmentTypeReassessment AssessmentType = "reassessment" +) + +// RiskLevel represents the severity level of a risk +type RiskLevel string + +const ( + RiskLevelCritical RiskLevel = "critical" + RiskLevelHigh RiskLevel = "high" + RiskLevelMedium RiskLevel = "medium" + RiskLevelLow RiskLevel = "low" + RiskLevelNegligible RiskLevel = "negligible" +) + +// ReductionType represents the type of risk reduction measure +type ReductionType string + +const ( + ReductionTypeDesign ReductionType = "design" + ReductionTypeProtective ReductionType = "protective" + ReductionTypeInformation ReductionType = "information" +) + +// MitigationStatus represents the lifecycle status of a mitigation measure +type MitigationStatus string + +const ( + MitigationStatusPlanned MitigationStatus = "planned" + MitigationStatusImplemented MitigationStatus = "implemented" + MitigationStatusVerified MitigationStatus = "verified" + MitigationStatusRejected MitigationStatus = "rejected" +) + +// VerificationMethod represents the method used for verification +type VerificationMethod string + +const ( + VerificationMethodTest VerificationMethod = "test" + VerificationMethodAnalysis VerificationMethod = "analysis" + VerificationMethodInspection VerificationMethod = "inspection" + VerificationMethodReview VerificationMethod = "review" +) + +// TechFileSectionStatus represents the status of a technical file section +type TechFileSectionStatus string + +const ( + TechFileSectionStatusDraft TechFileSectionStatus = "draft" + TechFileSectionStatusGenerated TechFileSectionStatus = "generated" + TechFileSectionStatusReviewed TechFileSectionStatus = "reviewed" + TechFileSectionStatusApproved TechFileSectionStatus = "approved" +) + +// MonitoringEventType represents the type of monitoring event +type MonitoringEventType string + +const ( + MonitoringEventTypeIncident MonitoringEventType = "incident" + MonitoringEventTypeUpdate MonitoringEventType = "update" + MonitoringEventTypeDriftAlert MonitoringEventType = "drift_alert" + MonitoringEventTypeRegulationChange MonitoringEventType = "regulation_change" + MonitoringEventTypeAudit MonitoringEventType = "audit" +) + +// AuditAction represents the type of action recorded in the audit trail +type AuditAction string + +const ( + AuditActionCreate AuditAction = "create" + AuditActionUpdate AuditAction = "update" + AuditActionDelete AuditAction = "delete" + AuditActionApprove AuditAction = "approve" + AuditActionVerify AuditAction = "verify" +) + +// ============================================================================ +// Main Entities +// ============================================================================ + +// Project represents an IACE compliance project for a machine or system +type Project struct { + ID uuid.UUID `json:"id"` + TenantID uuid.UUID `json:"tenant_id"` + MachineName string `json:"machine_name"` + MachineType string `json:"machine_type"` + Manufacturer string `json:"manufacturer"` + Description string `json:"description,omitempty"` + NarrativeText string `json:"narrative_text,omitempty"` + Status ProjectStatus `json:"status"` + CEMarkingTarget string `json:"ce_marking_target,omitempty"` + CompletenessScore float64 `json:"completeness_score"` + RiskSummary map[string]int `json:"risk_summary,omitempty"` + TriggeredRegulations json.RawMessage `json:"triggered_regulations,omitempty"` + Metadata json.RawMessage `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + ArchivedAt *time.Time `json:"archived_at,omitempty"` +} + +// Component represents a system component within a project +type Component struct { + ID uuid.UUID `json:"id"` + ProjectID uuid.UUID `json:"project_id"` + ParentID *uuid.UUID `json:"parent_id,omitempty"` + Name string `json:"name"` + ComponentType ComponentType `json:"component_type"` + Version string `json:"version,omitempty"` + Description string `json:"description,omitempty"` + IsSafetyRelevant bool `json:"is_safety_relevant"` + IsNetworked bool `json:"is_networked"` + Metadata json.RawMessage `json:"metadata,omitempty"` + SortOrder int `json:"sort_order"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// RegulatoryClassification represents the classification result for a regulation +type RegulatoryClassification struct { + ID uuid.UUID `json:"id"` + ProjectID uuid.UUID `json:"project_id"` + Regulation RegulationType `json:"regulation"` + ClassificationResult string `json:"classification_result"` + RiskLevel RiskLevel `json:"risk_level"` + Confidence float64 `json:"confidence"` + Reasoning string `json:"reasoning,omitempty"` + RAGSources json.RawMessage `json:"rag_sources,omitempty"` + Requirements json.RawMessage `json:"requirements,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// HazardLibraryEntry represents a reusable hazard template from the library +type HazardLibraryEntry struct { + ID uuid.UUID `json:"id"` + Category string `json:"category"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + DefaultSeverity int `json:"default_severity"` + DefaultProbability int `json:"default_probability"` + ApplicableComponentTypes []string `json:"applicable_component_types"` + RegulationReferences []string `json:"regulation_references"` + SuggestedMitigations json.RawMessage `json:"suggested_mitigations,omitempty"` + IsBuiltin bool `json:"is_builtin"` + TenantID *uuid.UUID `json:"tenant_id,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +// Hazard represents a specific hazard identified within a project +type Hazard struct { + ID uuid.UUID `json:"id"` + ProjectID uuid.UUID `json:"project_id"` + ComponentID uuid.UUID `json:"component_id"` + LibraryHazardID *uuid.UUID `json:"library_hazard_id,omitempty"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + Scenario string `json:"scenario,omitempty"` + Category string `json:"category"` + Status HazardStatus `json:"status"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// RiskAssessment represents a quantitative risk assessment for a hazard +type RiskAssessment struct { + ID uuid.UUID `json:"id"` + HazardID uuid.UUID `json:"hazard_id"` + Version int `json:"version"` + AssessmentType AssessmentType `json:"assessment_type"` + Severity int `json:"severity"` + Exposure int `json:"exposure"` + Probability int `json:"probability"` + InherentRisk float64 `json:"inherent_risk"` + ControlMaturity int `json:"control_maturity"` + ControlCoverage float64 `json:"control_coverage"` + TestEvidenceStrength float64 `json:"test_evidence_strength"` + CEff float64 `json:"c_eff"` + ResidualRisk float64 `json:"residual_risk"` + RiskLevel RiskLevel `json:"risk_level"` + IsAcceptable bool `json:"is_acceptable"` + AcceptanceJustification string `json:"acceptance_justification,omitempty"` + AssessedBy uuid.UUID `json:"assessed_by"` + CreatedAt time.Time `json:"created_at"` +} + +// Mitigation represents a risk reduction measure applied to a hazard +type Mitigation struct { + ID uuid.UUID `json:"id"` + HazardID uuid.UUID `json:"hazard_id"` + ReductionType ReductionType `json:"reduction_type"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + Status MitigationStatus `json:"status"` + VerificationMethod VerificationMethod `json:"verification_method,omitempty"` + VerificationResult string `json:"verification_result,omitempty"` + VerifiedAt *time.Time `json:"verified_at,omitempty"` + VerifiedBy uuid.UUID `json:"verified_by,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// Evidence represents an uploaded file that serves as evidence for compliance +type Evidence struct { + ID uuid.UUID `json:"id"` + ProjectID uuid.UUID `json:"project_id"` + MitigationID *uuid.UUID `json:"mitigation_id,omitempty"` + VerificationPlanID *uuid.UUID `json:"verification_plan_id,omitempty"` + FileName string `json:"file_name"` + FilePath string `json:"file_path"` + FileHash string `json:"file_hash"` + FileSize int64 `json:"file_size"` + MimeType string `json:"mime_type"` + Description string `json:"description,omitempty"` + UploadedBy uuid.UUID `json:"uploaded_by"` + CreatedAt time.Time `json:"created_at"` +} + +// VerificationPlan represents a plan for verifying compliance measures +type VerificationPlan struct { + ID uuid.UUID `json:"id"` + ProjectID uuid.UUID `json:"project_id"` + HazardID *uuid.UUID `json:"hazard_id,omitempty"` + MitigationID *uuid.UUID `json:"mitigation_id,omitempty"` + Title string `json:"title"` + Description string `json:"description,omitempty"` + AcceptanceCriteria string `json:"acceptance_criteria,omitempty"` + Method VerificationMethod `json:"method"` + Status string `json:"status"` + Result string `json:"result,omitempty"` + CompletedAt *time.Time `json:"completed_at,omitempty"` + CompletedBy uuid.UUID `json:"completed_by,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// TechFileSection represents a section of the technical documentation file +type TechFileSection struct { + ID uuid.UUID `json:"id"` + ProjectID uuid.UUID `json:"project_id"` + SectionType string `json:"section_type"` + Title string `json:"title"` + Content string `json:"content,omitempty"` + Version int `json:"version"` + Status TechFileSectionStatus `json:"status"` + ApprovedBy uuid.UUID `json:"approved_by,omitempty"` + ApprovedAt *time.Time `json:"approved_at,omitempty"` + Metadata json.RawMessage `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// MonitoringEvent represents a post-market monitoring event +type MonitoringEvent struct { + ID uuid.UUID `json:"id"` + ProjectID uuid.UUID `json:"project_id"` + EventType MonitoringEventType `json:"event_type"` + Title string `json:"title"` + Description string `json:"description,omitempty"` + Severity string `json:"severity"` + ImpactAssessment string `json:"impact_assessment,omitempty"` + Status string `json:"status"` + ResolvedAt *time.Time `json:"resolved_at,omitempty"` + ResolvedBy uuid.UUID `json:"resolved_by,omitempty"` + Metadata json.RawMessage `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// AuditTrailEntry represents an immutable audit log entry for compliance traceability +type AuditTrailEntry struct { + ID uuid.UUID `json:"id"` + ProjectID uuid.UUID `json:"project_id"` + EntityType string `json:"entity_type"` + EntityID uuid.UUID `json:"entity_id"` + Action AuditAction `json:"action"` + UserID uuid.UUID `json:"user_id"` + OldValues json.RawMessage `json:"old_values,omitempty"` + NewValues json.RawMessage `json:"new_values,omitempty"` + Hash string `json:"hash"` + CreatedAt time.Time `json:"created_at"` +} + +// ============================================================================ +// API Request Types +// ============================================================================ + +// CreateProjectRequest is the API request for creating a new IACE project +type CreateProjectRequest struct { + MachineName string `json:"machine_name" binding:"required"` + MachineType string `json:"machine_type" binding:"required"` + Manufacturer string `json:"manufacturer" binding:"required"` + Description string `json:"description,omitempty"` + NarrativeText string `json:"narrative_text,omitempty"` + CEMarkingTarget string `json:"ce_marking_target,omitempty"` + Metadata json.RawMessage `json:"metadata,omitempty"` +} + +// UpdateProjectRequest is the API request for updating an existing project +type UpdateProjectRequest struct { + MachineName *string `json:"machine_name,omitempty"` + MachineType *string `json:"machine_type,omitempty"` + Manufacturer *string `json:"manufacturer,omitempty"` + Description *string `json:"description,omitempty"` + NarrativeText *string `json:"narrative_text,omitempty"` + CEMarkingTarget *string `json:"ce_marking_target,omitempty"` + Metadata *json.RawMessage `json:"metadata,omitempty"` +} + +// CreateComponentRequest is the API request for adding a component to a project +type CreateComponentRequest struct { + ProjectID uuid.UUID `json:"project_id" binding:"required"` + ParentID *uuid.UUID `json:"parent_id,omitempty"` + Name string `json:"name" binding:"required"` + ComponentType ComponentType `json:"component_type" binding:"required"` + Version string `json:"version,omitempty"` + Description string `json:"description,omitempty"` + IsSafetyRelevant bool `json:"is_safety_relevant"` + IsNetworked bool `json:"is_networked"` +} + +// CreateHazardRequest is the API request for creating a new hazard +type CreateHazardRequest struct { + ProjectID uuid.UUID `json:"project_id" binding:"required"` + ComponentID uuid.UUID `json:"component_id" binding:"required"` + LibraryHazardID *uuid.UUID `json:"library_hazard_id,omitempty"` + Name string `json:"name" binding:"required"` + Description string `json:"description,omitempty"` + Scenario string `json:"scenario,omitempty"` + Category string `json:"category" binding:"required"` +} + +// AssessRiskRequest is the API request for performing a risk assessment +type AssessRiskRequest struct { + HazardID uuid.UUID `json:"hazard_id" binding:"required"` + Severity int `json:"severity" binding:"required"` + Exposure int `json:"exposure" binding:"required"` + Probability int `json:"probability" binding:"required"` + ControlMaturity int `json:"control_maturity" binding:"required"` + ControlCoverage float64 `json:"control_coverage" binding:"required"` + TestEvidenceStrength float64 `json:"test_evidence_strength" binding:"required"` + AcceptanceJustification string `json:"acceptance_justification,omitempty"` +} + +// CreateMitigationRequest is the API request for creating a mitigation measure +type CreateMitigationRequest struct { + HazardID uuid.UUID `json:"hazard_id" binding:"required"` + ReductionType ReductionType `json:"reduction_type" binding:"required"` + Name string `json:"name" binding:"required"` + Description string `json:"description,omitempty"` +} + +// CreateVerificationPlanRequest is the API request for creating a verification plan +type CreateVerificationPlanRequest struct { + ProjectID uuid.UUID `json:"project_id" binding:"required"` + HazardID *uuid.UUID `json:"hazard_id,omitempty"` + MitigationID *uuid.UUID `json:"mitigation_id,omitempty"` + Title string `json:"title" binding:"required"` + Description string `json:"description,omitempty"` + AcceptanceCriteria string `json:"acceptance_criteria,omitempty"` + Method VerificationMethod `json:"method" binding:"required"` +} + +// CreateMonitoringEventRequest is the API request for logging a monitoring event +type CreateMonitoringEventRequest struct { + ProjectID uuid.UUID `json:"project_id" binding:"required"` + EventType MonitoringEventType `json:"event_type" binding:"required"` + Title string `json:"title" binding:"required"` + Description string `json:"description,omitempty"` + Severity string `json:"severity" binding:"required"` +} + +// InitFromProfileRequest is the API request for initializing a project from a company profile +type InitFromProfileRequest struct { + CompanyProfile json.RawMessage `json:"company_profile" binding:"required"` + ComplianceScope json.RawMessage `json:"compliance_scope" binding:"required"` +} + +// ============================================================================ +// API Response Types +// ============================================================================ + +// ProjectListResponse is the API response for listing projects +type ProjectListResponse struct { + Projects []Project `json:"projects"` + Total int `json:"total"` +} + +// ProjectDetailResponse is the API response for a single project with related entities +type ProjectDetailResponse struct { + Project + Components []Component `json:"components"` + Classifications []RegulatoryClassification `json:"classifications"` + CompletenessGates []CompletenessGate `json:"completeness_gates"` +} + +// RiskSummaryResponse is the API response for an aggregated risk overview +type RiskSummaryResponse struct { + TotalHazards int `json:"total_hazards"` + Critical int `json:"critical"` + High int `json:"high"` + Medium int `json:"medium"` + Low int `json:"low"` + Negligible int `json:"negligible"` + OverallRiskLevel RiskLevel `json:"overall_risk_level"` + AllAcceptable bool `json:"all_acceptable"` +} + +// CompletenessGate represents a single gate in the project completeness checklist +type CompletenessGate struct { + ID string `json:"id"` + Category string `json:"category"` + Label string `json:"label"` + Required bool `json:"required"` + Passed bool `json:"passed"` + Details string `json:"details,omitempty"` +} diff --git a/ai-compliance-sdk/internal/iace/store.go b/ai-compliance-sdk/internal/iace/store.go new file mode 100644 index 0000000..390ff77 --- /dev/null +++ b/ai-compliance-sdk/internal/iace/store.go @@ -0,0 +1,1777 @@ +package iace + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +// Store handles IACE data persistence using PostgreSQL +type Store struct { + pool *pgxpool.Pool +} + +// NewStore creates a new IACE store +func NewStore(pool *pgxpool.Pool) *Store { + return &Store{pool: pool} +} + +// ============================================================================ +// Project CRUD Operations +// ============================================================================ + +// CreateProject creates a new IACE project +func (s *Store) CreateProject(ctx context.Context, tenantID uuid.UUID, req CreateProjectRequest) (*Project, error) { + project := &Project{ + ID: uuid.New(), + TenantID: tenantID, + MachineName: req.MachineName, + MachineType: req.MachineType, + Manufacturer: req.Manufacturer, + Description: req.Description, + NarrativeText: req.NarrativeText, + Status: ProjectStatusDraft, + CEMarkingTarget: req.CEMarkingTarget, + Metadata: req.Metadata, + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + } + + _, err := s.pool.Exec(ctx, ` + INSERT INTO iace_projects ( + id, tenant_id, machine_name, machine_type, manufacturer, + description, narrative_text, status, ce_marking_target, + completeness_score, risk_summary, triggered_regulations, metadata, + created_at, updated_at, archived_at + ) VALUES ( + $1, $2, $3, $4, $5, + $6, $7, $8, $9, + $10, $11, $12, $13, + $14, $15, $16 + ) + `, + project.ID, project.TenantID, project.MachineName, project.MachineType, project.Manufacturer, + project.Description, project.NarrativeText, string(project.Status), project.CEMarkingTarget, + project.CompletenessScore, nil, project.TriggeredRegulations, project.Metadata, + project.CreatedAt, project.UpdatedAt, project.ArchivedAt, + ) + if err != nil { + return nil, fmt.Errorf("create project: %w", err) + } + + return project, nil +} + +// GetProject retrieves a project by ID +func (s *Store) GetProject(ctx context.Context, id uuid.UUID) (*Project, error) { + var p Project + var status string + var riskSummary, triggeredRegulations, metadata []byte + + err := s.pool.QueryRow(ctx, ` + SELECT + id, tenant_id, machine_name, machine_type, manufacturer, + description, narrative_text, status, ce_marking_target, + completeness_score, risk_summary, triggered_regulations, metadata, + created_at, updated_at, archived_at + FROM iace_projects WHERE id = $1 + `, id).Scan( + &p.ID, &p.TenantID, &p.MachineName, &p.MachineType, &p.Manufacturer, + &p.Description, &p.NarrativeText, &status, &p.CEMarkingTarget, + &p.CompletenessScore, &riskSummary, &triggeredRegulations, &metadata, + &p.CreatedAt, &p.UpdatedAt, &p.ArchivedAt, + ) + if err == pgx.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("get project: %w", err) + } + + p.Status = ProjectStatus(status) + json.Unmarshal(riskSummary, &p.RiskSummary) + json.Unmarshal(triggeredRegulations, &p.TriggeredRegulations) + json.Unmarshal(metadata, &p.Metadata) + + return &p, nil +} + +// ListProjects lists all projects for a tenant +func (s *Store) ListProjects(ctx context.Context, tenantID uuid.UUID) ([]Project, error) { + rows, err := s.pool.Query(ctx, ` + SELECT + id, tenant_id, machine_name, machine_type, manufacturer, + description, narrative_text, status, ce_marking_target, + completeness_score, risk_summary, triggered_regulations, metadata, + created_at, updated_at, archived_at + FROM iace_projects WHERE tenant_id = $1 + ORDER BY created_at DESC + `, tenantID) + if err != nil { + return nil, fmt.Errorf("list projects: %w", err) + } + defer rows.Close() + + var projects []Project + for rows.Next() { + var p Project + var status string + var riskSummary, triggeredRegulations, metadata []byte + + err := rows.Scan( + &p.ID, &p.TenantID, &p.MachineName, &p.MachineType, &p.Manufacturer, + &p.Description, &p.NarrativeText, &status, &p.CEMarkingTarget, + &p.CompletenessScore, &riskSummary, &triggeredRegulations, &metadata, + &p.CreatedAt, &p.UpdatedAt, &p.ArchivedAt, + ) + if err != nil { + return nil, fmt.Errorf("list projects scan: %w", err) + } + + p.Status = ProjectStatus(status) + json.Unmarshal(riskSummary, &p.RiskSummary) + json.Unmarshal(triggeredRegulations, &p.TriggeredRegulations) + json.Unmarshal(metadata, &p.Metadata) + + projects = append(projects, p) + } + + return projects, nil +} + +// UpdateProject updates an existing project's mutable fields +func (s *Store) UpdateProject(ctx context.Context, id uuid.UUID, req UpdateProjectRequest) (*Project, error) { + // Fetch current project first + project, err := s.GetProject(ctx, id) + if err != nil { + return nil, err + } + if project == nil { + return nil, nil + } + + // Apply partial updates + if req.MachineName != nil { + project.MachineName = *req.MachineName + } + if req.MachineType != nil { + project.MachineType = *req.MachineType + } + if req.Manufacturer != nil { + project.Manufacturer = *req.Manufacturer + } + if req.Description != nil { + project.Description = *req.Description + } + if req.NarrativeText != nil { + project.NarrativeText = *req.NarrativeText + } + if req.CEMarkingTarget != nil { + project.CEMarkingTarget = *req.CEMarkingTarget + } + if req.Metadata != nil { + project.Metadata = *req.Metadata + } + + project.UpdatedAt = time.Now().UTC() + + _, err = s.pool.Exec(ctx, ` + UPDATE iace_projects SET + machine_name = $2, machine_type = $3, manufacturer = $4, + description = $5, narrative_text = $6, ce_marking_target = $7, + metadata = $8, updated_at = $9 + WHERE id = $1 + `, + id, project.MachineName, project.MachineType, project.Manufacturer, + project.Description, project.NarrativeText, project.CEMarkingTarget, + project.Metadata, project.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("update project: %w", err) + } + + return project, nil +} + +// ArchiveProject sets the archived_at timestamp and status for a project +func (s *Store) ArchiveProject(ctx context.Context, id uuid.UUID) error { + now := time.Now().UTC() + _, err := s.pool.Exec(ctx, ` + UPDATE iace_projects SET + status = $2, archived_at = $3, updated_at = $3 + WHERE id = $1 + `, id, string(ProjectStatusArchived), now) + if err != nil { + return fmt.Errorf("archive project: %w", err) + } + return nil +} + +// UpdateProjectStatus updates the lifecycle status of a project +func (s *Store) UpdateProjectStatus(ctx context.Context, id uuid.UUID, status ProjectStatus) error { + _, err := s.pool.Exec(ctx, ` + UPDATE iace_projects SET status = $2, updated_at = NOW() + WHERE id = $1 + `, id, string(status)) + if err != nil { + return fmt.Errorf("update project status: %w", err) + } + return nil +} + +// UpdateProjectCompleteness updates the completeness score and risk summary +func (s *Store) UpdateProjectCompleteness(ctx context.Context, id uuid.UUID, score float64, riskSummary map[string]int) error { + riskSummaryJSON, err := json.Marshal(riskSummary) + if err != nil { + return fmt.Errorf("marshal risk summary: %w", err) + } + + _, err = s.pool.Exec(ctx, ` + UPDATE iace_projects SET + completeness_score = $2, risk_summary = $3, updated_at = NOW() + WHERE id = $1 + `, id, score, riskSummaryJSON) + if err != nil { + return fmt.Errorf("update project completeness: %w", err) + } + return nil +} + +// ============================================================================ +// Component CRUD Operations +// ============================================================================ + +// CreateComponent creates a new component within a project +func (s *Store) CreateComponent(ctx context.Context, req CreateComponentRequest) (*Component, error) { + comp := &Component{ + ID: uuid.New(), + ProjectID: req.ProjectID, + ParentID: req.ParentID, + Name: req.Name, + ComponentType: req.ComponentType, + Version: req.Version, + Description: req.Description, + IsSafetyRelevant: req.IsSafetyRelevant, + IsNetworked: req.IsNetworked, + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + } + + _, err := s.pool.Exec(ctx, ` + INSERT INTO iace_components ( + id, project_id, parent_id, name, component_type, + version, description, is_safety_relevant, is_networked, + metadata, sort_order, created_at, updated_at + ) VALUES ( + $1, $2, $3, $4, $5, + $6, $7, $8, $9, + $10, $11, $12, $13 + ) + `, + comp.ID, comp.ProjectID, comp.ParentID, comp.Name, string(comp.ComponentType), + comp.Version, comp.Description, comp.IsSafetyRelevant, comp.IsNetworked, + comp.Metadata, comp.SortOrder, comp.CreatedAt, comp.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("create component: %w", err) + } + + return comp, nil +} + +// GetComponent retrieves a component by ID +func (s *Store) GetComponent(ctx context.Context, id uuid.UUID) (*Component, error) { + var c Component + var compType string + var metadata []byte + + err := s.pool.QueryRow(ctx, ` + SELECT + id, project_id, parent_id, name, component_type, + version, description, is_safety_relevant, is_networked, + metadata, sort_order, created_at, updated_at + FROM iace_components WHERE id = $1 + `, id).Scan( + &c.ID, &c.ProjectID, &c.ParentID, &c.Name, &compType, + &c.Version, &c.Description, &c.IsSafetyRelevant, &c.IsNetworked, + &metadata, &c.SortOrder, &c.CreatedAt, &c.UpdatedAt, + ) + if err == pgx.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("get component: %w", err) + } + + c.ComponentType = ComponentType(compType) + json.Unmarshal(metadata, &c.Metadata) + + return &c, nil +} + +// ListComponents lists all components for a project +func (s *Store) ListComponents(ctx context.Context, projectID uuid.UUID) ([]Component, error) { + rows, err := s.pool.Query(ctx, ` + SELECT + id, project_id, parent_id, name, component_type, + version, description, is_safety_relevant, is_networked, + metadata, sort_order, created_at, updated_at + FROM iace_components WHERE project_id = $1 + ORDER BY sort_order ASC, created_at ASC + `, projectID) + if err != nil { + return nil, fmt.Errorf("list components: %w", err) + } + defer rows.Close() + + var components []Component + for rows.Next() { + var c Component + var compType string + var metadata []byte + + err := rows.Scan( + &c.ID, &c.ProjectID, &c.ParentID, &c.Name, &compType, + &c.Version, &c.Description, &c.IsSafetyRelevant, &c.IsNetworked, + &metadata, &c.SortOrder, &c.CreatedAt, &c.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("list components scan: %w", err) + } + + c.ComponentType = ComponentType(compType) + json.Unmarshal(metadata, &c.Metadata) + + components = append(components, c) + } + + return components, nil +} + +// UpdateComponent updates a component with a dynamic set of fields +func (s *Store) UpdateComponent(ctx context.Context, id uuid.UUID, updates map[string]interface{}) (*Component, error) { + if len(updates) == 0 { + return s.GetComponent(ctx, id) + } + + query := "UPDATE iace_components SET updated_at = NOW()" + args := []interface{}{id} + argIdx := 2 + + for key, val := range updates { + switch key { + case "name", "version", "description": + query += fmt.Sprintf(", %s = $%d", key, argIdx) + args = append(args, val) + argIdx++ + case "component_type": + query += fmt.Sprintf(", component_type = $%d", argIdx) + args = append(args, val) + argIdx++ + case "is_safety_relevant": + query += fmt.Sprintf(", is_safety_relevant = $%d", argIdx) + args = append(args, val) + argIdx++ + case "is_networked": + query += fmt.Sprintf(", is_networked = $%d", argIdx) + args = append(args, val) + argIdx++ + case "sort_order": + query += fmt.Sprintf(", sort_order = $%d", argIdx) + args = append(args, val) + argIdx++ + case "metadata": + metaJSON, _ := json.Marshal(val) + query += fmt.Sprintf(", metadata = $%d", argIdx) + args = append(args, metaJSON) + argIdx++ + case "parent_id": + query += fmt.Sprintf(", parent_id = $%d", argIdx) + args = append(args, val) + argIdx++ + } + } + + query += " WHERE id = $1" + + _, err := s.pool.Exec(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("update component: %w", err) + } + + return s.GetComponent(ctx, id) +} + +// DeleteComponent deletes a component by ID +func (s *Store) DeleteComponent(ctx context.Context, id uuid.UUID) error { + _, err := s.pool.Exec(ctx, "DELETE FROM iace_components WHERE id = $1", id) + if err != nil { + return fmt.Errorf("delete component: %w", err) + } + return nil +} + +// ============================================================================ +// Classification Operations +// ============================================================================ + +// UpsertClassification inserts or updates a regulatory classification for a project +func (s *Store) UpsertClassification(ctx context.Context, projectID uuid.UUID, regulation RegulationType, result string, riskLevel string, confidence float64, reasoning string, ragSources, requirements json.RawMessage) (*RegulatoryClassification, error) { + id := uuid.New() + now := time.Now().UTC() + + _, err := s.pool.Exec(ctx, ` + INSERT INTO iace_classifications ( + id, project_id, regulation, classification_result, + risk_level, confidence, reasoning, + rag_sources, requirements, + created_at, updated_at + ) VALUES ( + $1, $2, $3, $4, + $5, $6, $7, + $8, $9, + $10, $11 + ) + ON CONFLICT (project_id, regulation) + DO UPDATE SET + classification_result = EXCLUDED.classification_result, + risk_level = EXCLUDED.risk_level, + confidence = EXCLUDED.confidence, + reasoning = EXCLUDED.reasoning, + rag_sources = EXCLUDED.rag_sources, + requirements = EXCLUDED.requirements, + updated_at = EXCLUDED.updated_at + `, + id, projectID, string(regulation), result, + riskLevel, confidence, reasoning, + ragSources, requirements, + now, now, + ) + if err != nil { + return nil, fmt.Errorf("upsert classification: %w", err) + } + + // Retrieve the upserted row (may have kept the original ID on conflict) + return s.getClassificationByProjectAndRegulation(ctx, projectID, regulation) +} + +// getClassificationByProjectAndRegulation is a helper to fetch a single classification +func (s *Store) getClassificationByProjectAndRegulation(ctx context.Context, projectID uuid.UUID, regulation RegulationType) (*RegulatoryClassification, error) { + var c RegulatoryClassification + var reg, rl string + var ragSources, requirements []byte + + err := s.pool.QueryRow(ctx, ` + SELECT + id, project_id, regulation, classification_result, + risk_level, confidence, reasoning, + rag_sources, requirements, + created_at, updated_at + FROM iace_classifications + WHERE project_id = $1 AND regulation = $2 + `, projectID, string(regulation)).Scan( + &c.ID, &c.ProjectID, ®, &c.ClassificationResult, + &rl, &c.Confidence, &c.Reasoning, + &ragSources, &requirements, + &c.CreatedAt, &c.UpdatedAt, + ) + if err == pgx.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("get classification: %w", err) + } + + c.Regulation = RegulationType(reg) + c.RiskLevel = RiskLevel(rl) + json.Unmarshal(ragSources, &c.RAGSources) + json.Unmarshal(requirements, &c.Requirements) + + return &c, nil +} + +// GetClassifications retrieves all classifications for a project +func (s *Store) GetClassifications(ctx context.Context, projectID uuid.UUID) ([]RegulatoryClassification, error) { + rows, err := s.pool.Query(ctx, ` + SELECT + id, project_id, regulation, classification_result, + risk_level, confidence, reasoning, + rag_sources, requirements, + created_at, updated_at + FROM iace_classifications + WHERE project_id = $1 + ORDER BY regulation ASC + `, projectID) + if err != nil { + return nil, fmt.Errorf("get classifications: %w", err) + } + defer rows.Close() + + var classifications []RegulatoryClassification + for rows.Next() { + var c RegulatoryClassification + var reg, rl string + var ragSources, requirements []byte + + err := rows.Scan( + &c.ID, &c.ProjectID, ®, &c.ClassificationResult, + &rl, &c.Confidence, &c.Reasoning, + &ragSources, &requirements, + &c.CreatedAt, &c.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("get classifications scan: %w", err) + } + + c.Regulation = RegulationType(reg) + c.RiskLevel = RiskLevel(rl) + json.Unmarshal(ragSources, &c.RAGSources) + json.Unmarshal(requirements, &c.Requirements) + + classifications = append(classifications, c) + } + + return classifications, nil +} + +// ============================================================================ +// Hazard CRUD Operations +// ============================================================================ + +// CreateHazard creates a new hazard within a project +func (s *Store) CreateHazard(ctx context.Context, req CreateHazardRequest) (*Hazard, error) { + h := &Hazard{ + ID: uuid.New(), + ProjectID: req.ProjectID, + ComponentID: req.ComponentID, + LibraryHazardID: req.LibraryHazardID, + Name: req.Name, + Description: req.Description, + Scenario: req.Scenario, + Category: req.Category, + Status: HazardStatusIdentified, + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + } + + _, err := s.pool.Exec(ctx, ` + INSERT INTO iace_hazards ( + id, project_id, component_id, library_hazard_id, + name, description, scenario, category, status, + created_at, updated_at + ) VALUES ( + $1, $2, $3, $4, + $5, $6, $7, $8, $9, + $10, $11 + ) + `, + h.ID, h.ProjectID, h.ComponentID, h.LibraryHazardID, + h.Name, h.Description, h.Scenario, h.Category, string(h.Status), + h.CreatedAt, h.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("create hazard: %w", err) + } + + return h, nil +} + +// GetHazard retrieves a hazard by ID +func (s *Store) GetHazard(ctx context.Context, id uuid.UUID) (*Hazard, error) { + var h Hazard + var status string + + err := s.pool.QueryRow(ctx, ` + SELECT + id, project_id, component_id, library_hazard_id, + name, description, scenario, category, status, + created_at, updated_at + FROM iace_hazards WHERE id = $1 + `, id).Scan( + &h.ID, &h.ProjectID, &h.ComponentID, &h.LibraryHazardID, + &h.Name, &h.Description, &h.Scenario, &h.Category, &status, + &h.CreatedAt, &h.UpdatedAt, + ) + if err == pgx.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("get hazard: %w", err) + } + + h.Status = HazardStatus(status) + return &h, nil +} + +// ListHazards lists all hazards for a project +func (s *Store) ListHazards(ctx context.Context, projectID uuid.UUID) ([]Hazard, error) { + rows, err := s.pool.Query(ctx, ` + SELECT + id, project_id, component_id, library_hazard_id, + name, description, scenario, category, status, + created_at, updated_at + FROM iace_hazards WHERE project_id = $1 + ORDER BY created_at ASC + `, projectID) + if err != nil { + return nil, fmt.Errorf("list hazards: %w", err) + } + defer rows.Close() + + var hazards []Hazard + for rows.Next() { + var h Hazard + var status string + + err := rows.Scan( + &h.ID, &h.ProjectID, &h.ComponentID, &h.LibraryHazardID, + &h.Name, &h.Description, &h.Scenario, &h.Category, &status, + &h.CreatedAt, &h.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("list hazards scan: %w", err) + } + + h.Status = HazardStatus(status) + hazards = append(hazards, h) + } + + return hazards, nil +} + +// UpdateHazard updates a hazard with a dynamic set of fields +func (s *Store) UpdateHazard(ctx context.Context, id uuid.UUID, updates map[string]interface{}) (*Hazard, error) { + if len(updates) == 0 { + return s.GetHazard(ctx, id) + } + + query := "UPDATE iace_hazards SET updated_at = NOW()" + args := []interface{}{id} + argIdx := 2 + + for key, val := range updates { + switch key { + case "name", "description", "scenario", "category": + query += fmt.Sprintf(", %s = $%d", key, argIdx) + args = append(args, val) + argIdx++ + case "status": + query += fmt.Sprintf(", status = $%d", argIdx) + args = append(args, val) + argIdx++ + case "component_id": + query += fmt.Sprintf(", component_id = $%d", argIdx) + args = append(args, val) + argIdx++ + } + } + + query += " WHERE id = $1" + + _, err := s.pool.Exec(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("update hazard: %w", err) + } + + return s.GetHazard(ctx, id) +} + +// ============================================================================ +// Risk Assessment Operations +// ============================================================================ + +// CreateRiskAssessment creates a new risk assessment for a hazard +func (s *Store) CreateRiskAssessment(ctx context.Context, assessment *RiskAssessment) error { + if assessment.ID == uuid.Nil { + assessment.ID = uuid.New() + } + if assessment.CreatedAt.IsZero() { + assessment.CreatedAt = time.Now().UTC() + } + + _, err := s.pool.Exec(ctx, ` + INSERT INTO iace_risk_assessments ( + id, hazard_id, version, assessment_type, + severity, exposure, probability, + inherent_risk, control_maturity, control_coverage, + test_evidence_strength, c_eff, residual_risk, + risk_level, is_acceptable, acceptance_justification, + assessed_by, created_at + ) VALUES ( + $1, $2, $3, $4, + $5, $6, $7, + $8, $9, $10, + $11, $12, $13, + $14, $15, $16, + $17, $18 + ) + `, + assessment.ID, assessment.HazardID, assessment.Version, string(assessment.AssessmentType), + assessment.Severity, assessment.Exposure, assessment.Probability, + assessment.InherentRisk, assessment.ControlMaturity, assessment.ControlCoverage, + assessment.TestEvidenceStrength, assessment.CEff, assessment.ResidualRisk, + string(assessment.RiskLevel), assessment.IsAcceptable, assessment.AcceptanceJustification, + assessment.AssessedBy, assessment.CreatedAt, + ) + if err != nil { + return fmt.Errorf("create risk assessment: %w", err) + } + + return nil +} + +// GetLatestAssessment retrieves the most recent risk assessment for a hazard +func (s *Store) GetLatestAssessment(ctx context.Context, hazardID uuid.UUID) (*RiskAssessment, error) { + var a RiskAssessment + var assessmentType, riskLevel string + + err := s.pool.QueryRow(ctx, ` + SELECT + id, hazard_id, version, assessment_type, + severity, exposure, probability, + inherent_risk, control_maturity, control_coverage, + test_evidence_strength, c_eff, residual_risk, + risk_level, is_acceptable, acceptance_justification, + assessed_by, created_at + FROM iace_risk_assessments + WHERE hazard_id = $1 + ORDER BY version DESC, created_at DESC + LIMIT 1 + `, hazardID).Scan( + &a.ID, &a.HazardID, &a.Version, &assessmentType, + &a.Severity, &a.Exposure, &a.Probability, + &a.InherentRisk, &a.ControlMaturity, &a.ControlCoverage, + &a.TestEvidenceStrength, &a.CEff, &a.ResidualRisk, + &riskLevel, &a.IsAcceptable, &a.AcceptanceJustification, + &a.AssessedBy, &a.CreatedAt, + ) + if err == pgx.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("get latest assessment: %w", err) + } + + a.AssessmentType = AssessmentType(assessmentType) + a.RiskLevel = RiskLevel(riskLevel) + + return &a, nil +} + +// ListAssessments lists all risk assessments for a hazard, newest first +func (s *Store) ListAssessments(ctx context.Context, hazardID uuid.UUID) ([]RiskAssessment, error) { + rows, err := s.pool.Query(ctx, ` + SELECT + id, hazard_id, version, assessment_type, + severity, exposure, probability, + inherent_risk, control_maturity, control_coverage, + test_evidence_strength, c_eff, residual_risk, + risk_level, is_acceptable, acceptance_justification, + assessed_by, created_at + FROM iace_risk_assessments + WHERE hazard_id = $1 + ORDER BY version DESC, created_at DESC + `, hazardID) + if err != nil { + return nil, fmt.Errorf("list assessments: %w", err) + } + defer rows.Close() + + var assessments []RiskAssessment + for rows.Next() { + var a RiskAssessment + var assessmentType, riskLevel string + + err := rows.Scan( + &a.ID, &a.HazardID, &a.Version, &assessmentType, + &a.Severity, &a.Exposure, &a.Probability, + &a.InherentRisk, &a.ControlMaturity, &a.ControlCoverage, + &a.TestEvidenceStrength, &a.CEff, &a.ResidualRisk, + &riskLevel, &a.IsAcceptable, &a.AcceptanceJustification, + &a.AssessedBy, &a.CreatedAt, + ) + if err != nil { + return nil, fmt.Errorf("list assessments scan: %w", err) + } + + a.AssessmentType = AssessmentType(assessmentType) + a.RiskLevel = RiskLevel(riskLevel) + + assessments = append(assessments, a) + } + + return assessments, nil +} + +// ============================================================================ +// Mitigation CRUD Operations +// ============================================================================ + +// CreateMitigation creates a new mitigation measure for a hazard +func (s *Store) CreateMitigation(ctx context.Context, req CreateMitigationRequest) (*Mitigation, error) { + m := &Mitigation{ + ID: uuid.New(), + HazardID: req.HazardID, + ReductionType: req.ReductionType, + Name: req.Name, + Description: req.Description, + Status: MitigationStatusPlanned, + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + } + + _, err := s.pool.Exec(ctx, ` + INSERT INTO iace_mitigations ( + id, hazard_id, reduction_type, name, description, + status, verification_method, verification_result, + verified_at, verified_by, + created_at, updated_at + ) VALUES ( + $1, $2, $3, $4, $5, + $6, $7, $8, + $9, $10, + $11, $12 + ) + `, + m.ID, m.HazardID, string(m.ReductionType), m.Name, m.Description, + string(m.Status), "", "", + nil, uuid.Nil, + m.CreatedAt, m.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("create mitigation: %w", err) + } + + return m, nil +} + +// UpdateMitigation updates a mitigation with a dynamic set of fields +func (s *Store) UpdateMitigation(ctx context.Context, id uuid.UUID, updates map[string]interface{}) (*Mitigation, error) { + if len(updates) == 0 { + return s.getMitigation(ctx, id) + } + + query := "UPDATE iace_mitigations SET updated_at = NOW()" + args := []interface{}{id} + argIdx := 2 + + for key, val := range updates { + switch key { + case "name", "description", "verification_result": + query += fmt.Sprintf(", %s = $%d", key, argIdx) + args = append(args, val) + argIdx++ + case "status": + query += fmt.Sprintf(", status = $%d", argIdx) + args = append(args, val) + argIdx++ + case "reduction_type": + query += fmt.Sprintf(", reduction_type = $%d", argIdx) + args = append(args, val) + argIdx++ + case "verification_method": + query += fmt.Sprintf(", verification_method = $%d", argIdx) + args = append(args, val) + argIdx++ + } + } + + query += " WHERE id = $1" + + _, err := s.pool.Exec(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("update mitigation: %w", err) + } + + return s.getMitigation(ctx, id) +} + +// VerifyMitigation marks a mitigation as verified +func (s *Store) VerifyMitigation(ctx context.Context, id uuid.UUID, verificationResult string, verifiedBy string) error { + now := time.Now().UTC() + verifiedByUUID, err := uuid.Parse(verifiedBy) + if err != nil { + return fmt.Errorf("invalid verified_by UUID: %w", err) + } + + _, err = s.pool.Exec(ctx, ` + UPDATE iace_mitigations SET + status = $2, + verification_result = $3, + verified_at = $4, + verified_by = $5, + updated_at = $4 + WHERE id = $1 + `, id, string(MitigationStatusVerified), verificationResult, now, verifiedByUUID) + if err != nil { + return fmt.Errorf("verify mitigation: %w", err) + } + + return nil +} + +// ListMitigations lists all mitigations for a hazard +func (s *Store) ListMitigations(ctx context.Context, hazardID uuid.UUID) ([]Mitigation, error) { + rows, err := s.pool.Query(ctx, ` + SELECT + id, hazard_id, reduction_type, name, description, + status, verification_method, verification_result, + verified_at, verified_by, + created_at, updated_at + FROM iace_mitigations WHERE hazard_id = $1 + ORDER BY created_at ASC + `, hazardID) + if err != nil { + return nil, fmt.Errorf("list mitigations: %w", err) + } + defer rows.Close() + + var mitigations []Mitigation + for rows.Next() { + var m Mitigation + var reductionType, status, verificationMethod string + + err := rows.Scan( + &m.ID, &m.HazardID, &reductionType, &m.Name, &m.Description, + &status, &verificationMethod, &m.VerificationResult, + &m.VerifiedAt, &m.VerifiedBy, + &m.CreatedAt, &m.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("list mitigations scan: %w", err) + } + + m.ReductionType = ReductionType(reductionType) + m.Status = MitigationStatus(status) + m.VerificationMethod = VerificationMethod(verificationMethod) + + mitigations = append(mitigations, m) + } + + return mitigations, nil +} + +// getMitigation is a helper to fetch a single mitigation by ID +func (s *Store) getMitigation(ctx context.Context, id uuid.UUID) (*Mitigation, error) { + var m Mitigation + var reductionType, status, verificationMethod string + + err := s.pool.QueryRow(ctx, ` + SELECT + id, hazard_id, reduction_type, name, description, + status, verification_method, verification_result, + verified_at, verified_by, + created_at, updated_at + FROM iace_mitigations WHERE id = $1 + `, id).Scan( + &m.ID, &m.HazardID, &reductionType, &m.Name, &m.Description, + &status, &verificationMethod, &m.VerificationResult, + &m.VerifiedAt, &m.VerifiedBy, + &m.CreatedAt, &m.UpdatedAt, + ) + if err == pgx.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("get mitigation: %w", err) + } + + m.ReductionType = ReductionType(reductionType) + m.Status = MitigationStatus(status) + m.VerificationMethod = VerificationMethod(verificationMethod) + + return &m, nil +} + +// ============================================================================ +// Evidence Operations +// ============================================================================ + +// CreateEvidence creates a new evidence record +func (s *Store) CreateEvidence(ctx context.Context, evidence *Evidence) error { + if evidence.ID == uuid.Nil { + evidence.ID = uuid.New() + } + if evidence.CreatedAt.IsZero() { + evidence.CreatedAt = time.Now().UTC() + } + + _, err := s.pool.Exec(ctx, ` + INSERT INTO iace_evidence ( + id, project_id, mitigation_id, verification_plan_id, + file_name, file_path, file_hash, file_size, mime_type, + description, uploaded_by, created_at + ) VALUES ( + $1, $2, $3, $4, + $5, $6, $7, $8, $9, + $10, $11, $12 + ) + `, + evidence.ID, evidence.ProjectID, evidence.MitigationID, evidence.VerificationPlanID, + evidence.FileName, evidence.FilePath, evidence.FileHash, evidence.FileSize, evidence.MimeType, + evidence.Description, evidence.UploadedBy, evidence.CreatedAt, + ) + if err != nil { + return fmt.Errorf("create evidence: %w", err) + } + + return nil +} + +// ListEvidence lists all evidence for a project +func (s *Store) ListEvidence(ctx context.Context, projectID uuid.UUID) ([]Evidence, error) { + rows, err := s.pool.Query(ctx, ` + SELECT + id, project_id, mitigation_id, verification_plan_id, + file_name, file_path, file_hash, file_size, mime_type, + description, uploaded_by, created_at + FROM iace_evidence WHERE project_id = $1 + ORDER BY created_at DESC + `, projectID) + if err != nil { + return nil, fmt.Errorf("list evidence: %w", err) + } + defer rows.Close() + + var evidence []Evidence + for rows.Next() { + var e Evidence + + err := rows.Scan( + &e.ID, &e.ProjectID, &e.MitigationID, &e.VerificationPlanID, + &e.FileName, &e.FilePath, &e.FileHash, &e.FileSize, &e.MimeType, + &e.Description, &e.UploadedBy, &e.CreatedAt, + ) + if err != nil { + return nil, fmt.Errorf("list evidence scan: %w", err) + } + + evidence = append(evidence, e) + } + + return evidence, nil +} + +// ============================================================================ +// Verification Plan Operations +// ============================================================================ + +// CreateVerificationPlan creates a new verification plan +func (s *Store) CreateVerificationPlan(ctx context.Context, req CreateVerificationPlanRequest) (*VerificationPlan, error) { + vp := &VerificationPlan{ + ID: uuid.New(), + ProjectID: req.ProjectID, + HazardID: req.HazardID, + MitigationID: req.MitigationID, + Title: req.Title, + Description: req.Description, + AcceptanceCriteria: req.AcceptanceCriteria, + Method: req.Method, + Status: "planned", + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + } + + _, err := s.pool.Exec(ctx, ` + INSERT INTO iace_verification_plans ( + id, project_id, hazard_id, mitigation_id, + title, description, acceptance_criteria, method, + status, result, completed_at, completed_by, + created_at, updated_at + ) VALUES ( + $1, $2, $3, $4, + $5, $6, $7, $8, + $9, $10, $11, $12, + $13, $14 + ) + `, + vp.ID, vp.ProjectID, vp.HazardID, vp.MitigationID, + vp.Title, vp.Description, vp.AcceptanceCriteria, string(vp.Method), + vp.Status, "", nil, uuid.Nil, + vp.CreatedAt, vp.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("create verification plan: %w", err) + } + + return vp, nil +} + +// UpdateVerificationPlan updates a verification plan with a dynamic set of fields +func (s *Store) UpdateVerificationPlan(ctx context.Context, id uuid.UUID, updates map[string]interface{}) (*VerificationPlan, error) { + if len(updates) == 0 { + return s.getVerificationPlan(ctx, id) + } + + query := "UPDATE iace_verification_plans SET updated_at = NOW()" + args := []interface{}{id} + argIdx := 2 + + for key, val := range updates { + switch key { + case "title", "description", "acceptance_criteria", "result", "status": + query += fmt.Sprintf(", %s = $%d", key, argIdx) + args = append(args, val) + argIdx++ + case "method": + query += fmt.Sprintf(", method = $%d", argIdx) + args = append(args, val) + argIdx++ + } + } + + query += " WHERE id = $1" + + _, err := s.pool.Exec(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("update verification plan: %w", err) + } + + return s.getVerificationPlan(ctx, id) +} + +// CompleteVerification marks a verification plan as completed +func (s *Store) CompleteVerification(ctx context.Context, id uuid.UUID, result string, completedBy string) error { + now := time.Now().UTC() + completedByUUID, err := uuid.Parse(completedBy) + if err != nil { + return fmt.Errorf("invalid completed_by UUID: %w", err) + } + + _, err = s.pool.Exec(ctx, ` + UPDATE iace_verification_plans SET + status = 'completed', + result = $2, + completed_at = $3, + completed_by = $4, + updated_at = $3 + WHERE id = $1 + `, id, result, now, completedByUUID) + if err != nil { + return fmt.Errorf("complete verification: %w", err) + } + + return nil +} + +// ListVerificationPlans lists all verification plans for a project +func (s *Store) ListVerificationPlans(ctx context.Context, projectID uuid.UUID) ([]VerificationPlan, error) { + rows, err := s.pool.Query(ctx, ` + SELECT + id, project_id, hazard_id, mitigation_id, + title, description, acceptance_criteria, method, + status, result, completed_at, completed_by, + created_at, updated_at + FROM iace_verification_plans WHERE project_id = $1 + ORDER BY created_at ASC + `, projectID) + if err != nil { + return nil, fmt.Errorf("list verification plans: %w", err) + } + defer rows.Close() + + var plans []VerificationPlan + for rows.Next() { + var vp VerificationPlan + var method string + + err := rows.Scan( + &vp.ID, &vp.ProjectID, &vp.HazardID, &vp.MitigationID, + &vp.Title, &vp.Description, &vp.AcceptanceCriteria, &method, + &vp.Status, &vp.Result, &vp.CompletedAt, &vp.CompletedBy, + &vp.CreatedAt, &vp.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("list verification plans scan: %w", err) + } + + vp.Method = VerificationMethod(method) + plans = append(plans, vp) + } + + return plans, nil +} + +// getVerificationPlan is a helper to fetch a single verification plan by ID +func (s *Store) getVerificationPlan(ctx context.Context, id uuid.UUID) (*VerificationPlan, error) { + var vp VerificationPlan + var method string + + err := s.pool.QueryRow(ctx, ` + SELECT + id, project_id, hazard_id, mitigation_id, + title, description, acceptance_criteria, method, + status, result, completed_at, completed_by, + created_at, updated_at + FROM iace_verification_plans WHERE id = $1 + `, id).Scan( + &vp.ID, &vp.ProjectID, &vp.HazardID, &vp.MitigationID, + &vp.Title, &vp.Description, &vp.AcceptanceCriteria, &method, + &vp.Status, &vp.Result, &vp.CompletedAt, &vp.CompletedBy, + &vp.CreatedAt, &vp.UpdatedAt, + ) + if err == pgx.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("get verification plan: %w", err) + } + + vp.Method = VerificationMethod(method) + return &vp, nil +} + +// ============================================================================ +// Tech File Section Operations +// ============================================================================ + +// CreateTechFileSection creates a new section in the technical file +func (s *Store) CreateTechFileSection(ctx context.Context, projectID uuid.UUID, sectionType, title, content string) (*TechFileSection, error) { + tf := &TechFileSection{ + ID: uuid.New(), + ProjectID: projectID, + SectionType: sectionType, + Title: title, + Content: content, + Version: 1, + Status: TechFileSectionStatusDraft, + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + } + + _, err := s.pool.Exec(ctx, ` + INSERT INTO iace_tech_file_sections ( + id, project_id, section_type, title, content, + version, status, approved_by, approved_at, metadata, + created_at, updated_at + ) VALUES ( + $1, $2, $3, $4, $5, + $6, $7, $8, $9, $10, + $11, $12 + ) + `, + tf.ID, tf.ProjectID, tf.SectionType, tf.Title, tf.Content, + tf.Version, string(tf.Status), uuid.Nil, nil, nil, + tf.CreatedAt, tf.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("create tech file section: %w", err) + } + + return tf, nil +} + +// UpdateTechFileSection updates the content of a tech file section and bumps version +func (s *Store) UpdateTechFileSection(ctx context.Context, id uuid.UUID, content string) error { + _, err := s.pool.Exec(ctx, ` + UPDATE iace_tech_file_sections SET + content = $2, + version = version + 1, + status = $3, + updated_at = NOW() + WHERE id = $1 + `, id, content, string(TechFileSectionStatusDraft)) + if err != nil { + return fmt.Errorf("update tech file section: %w", err) + } + return nil +} + +// ApproveTechFileSection marks a tech file section as approved +func (s *Store) ApproveTechFileSection(ctx context.Context, id uuid.UUID, approvedBy string) error { + now := time.Now().UTC() + approvedByUUID, err := uuid.Parse(approvedBy) + if err != nil { + return fmt.Errorf("invalid approved_by UUID: %w", err) + } + + _, err = s.pool.Exec(ctx, ` + UPDATE iace_tech_file_sections SET + status = $2, + approved_by = $3, + approved_at = $4, + updated_at = $4 + WHERE id = $1 + `, id, string(TechFileSectionStatusApproved), approvedByUUID, now) + if err != nil { + return fmt.Errorf("approve tech file section: %w", err) + } + + return nil +} + +// ListTechFileSections lists all tech file sections for a project +func (s *Store) ListTechFileSections(ctx context.Context, projectID uuid.UUID) ([]TechFileSection, error) { + rows, err := s.pool.Query(ctx, ` + SELECT + id, project_id, section_type, title, content, + version, status, approved_by, approved_at, metadata, + created_at, updated_at + FROM iace_tech_file_sections WHERE project_id = $1 + ORDER BY section_type ASC, created_at ASC + `, projectID) + if err != nil { + return nil, fmt.Errorf("list tech file sections: %w", err) + } + defer rows.Close() + + var sections []TechFileSection + for rows.Next() { + var tf TechFileSection + var status string + var metadata []byte + + err := rows.Scan( + &tf.ID, &tf.ProjectID, &tf.SectionType, &tf.Title, &tf.Content, + &tf.Version, &status, &tf.ApprovedBy, &tf.ApprovedAt, &metadata, + &tf.CreatedAt, &tf.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("list tech file sections scan: %w", err) + } + + tf.Status = TechFileSectionStatus(status) + json.Unmarshal(metadata, &tf.Metadata) + + sections = append(sections, tf) + } + + return sections, nil +} + +// ============================================================================ +// Monitoring Event Operations +// ============================================================================ + +// CreateMonitoringEvent creates a new post-market monitoring event +func (s *Store) CreateMonitoringEvent(ctx context.Context, req CreateMonitoringEventRequest) (*MonitoringEvent, error) { + me := &MonitoringEvent{ + ID: uuid.New(), + ProjectID: req.ProjectID, + EventType: req.EventType, + Title: req.Title, + Description: req.Description, + Severity: req.Severity, + Status: "open", + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + } + + _, err := s.pool.Exec(ctx, ` + INSERT INTO iace_monitoring_events ( + id, project_id, event_type, title, description, + severity, impact_assessment, status, + resolved_at, resolved_by, metadata, + created_at, updated_at + ) VALUES ( + $1, $2, $3, $4, $5, + $6, $7, $8, + $9, $10, $11, + $12, $13 + ) + `, + me.ID, me.ProjectID, string(me.EventType), me.Title, me.Description, + me.Severity, "", me.Status, + nil, uuid.Nil, nil, + me.CreatedAt, me.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("create monitoring event: %w", err) + } + + return me, nil +} + +// ListMonitoringEvents lists all monitoring events for a project +func (s *Store) ListMonitoringEvents(ctx context.Context, projectID uuid.UUID) ([]MonitoringEvent, error) { + rows, err := s.pool.Query(ctx, ` + SELECT + id, project_id, event_type, title, description, + severity, impact_assessment, status, + resolved_at, resolved_by, metadata, + created_at, updated_at + FROM iace_monitoring_events WHERE project_id = $1 + ORDER BY created_at DESC + `, projectID) + if err != nil { + return nil, fmt.Errorf("list monitoring events: %w", err) + } + defer rows.Close() + + var events []MonitoringEvent + for rows.Next() { + var me MonitoringEvent + var eventType string + var metadata []byte + + err := rows.Scan( + &me.ID, &me.ProjectID, &eventType, &me.Title, &me.Description, + &me.Severity, &me.ImpactAssessment, &me.Status, + &me.ResolvedAt, &me.ResolvedBy, &metadata, + &me.CreatedAt, &me.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("list monitoring events scan: %w", err) + } + + me.EventType = MonitoringEventType(eventType) + json.Unmarshal(metadata, &me.Metadata) + + events = append(events, me) + } + + return events, nil +} + +// UpdateMonitoringEvent updates a monitoring event with a dynamic set of fields +func (s *Store) UpdateMonitoringEvent(ctx context.Context, id uuid.UUID, updates map[string]interface{}) (*MonitoringEvent, error) { + if len(updates) == 0 { + return s.getMonitoringEvent(ctx, id) + } + + query := "UPDATE iace_monitoring_events SET updated_at = NOW()" + args := []interface{}{id} + argIdx := 2 + + for key, val := range updates { + switch key { + case "title", "description", "severity", "impact_assessment", "status": + query += fmt.Sprintf(", %s = $%d", key, argIdx) + args = append(args, val) + argIdx++ + case "event_type": + query += fmt.Sprintf(", event_type = $%d", argIdx) + args = append(args, val) + argIdx++ + case "resolved_at": + query += fmt.Sprintf(", resolved_at = $%d", argIdx) + args = append(args, val) + argIdx++ + case "resolved_by": + query += fmt.Sprintf(", resolved_by = $%d", argIdx) + args = append(args, val) + argIdx++ + case "metadata": + metaJSON, _ := json.Marshal(val) + query += fmt.Sprintf(", metadata = $%d", argIdx) + args = append(args, metaJSON) + argIdx++ + } + } + + query += " WHERE id = $1" + + _, err := s.pool.Exec(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("update monitoring event: %w", err) + } + + return s.getMonitoringEvent(ctx, id) +} + +// getMonitoringEvent is a helper to fetch a single monitoring event by ID +func (s *Store) getMonitoringEvent(ctx context.Context, id uuid.UUID) (*MonitoringEvent, error) { + var me MonitoringEvent + var eventType string + var metadata []byte + + err := s.pool.QueryRow(ctx, ` + SELECT + id, project_id, event_type, title, description, + severity, impact_assessment, status, + resolved_at, resolved_by, metadata, + created_at, updated_at + FROM iace_monitoring_events WHERE id = $1 + `, id).Scan( + &me.ID, &me.ProjectID, &eventType, &me.Title, &me.Description, + &me.Severity, &me.ImpactAssessment, &me.Status, + &me.ResolvedAt, &me.ResolvedBy, &metadata, + &me.CreatedAt, &me.UpdatedAt, + ) + if err == pgx.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("get monitoring event: %w", err) + } + + me.EventType = MonitoringEventType(eventType) + json.Unmarshal(metadata, &me.Metadata) + + return &me, nil +} + +// ============================================================================ +// Audit Trail Operations +// ============================================================================ + +// AddAuditEntry adds an immutable audit trail entry +func (s *Store) AddAuditEntry(ctx context.Context, projectID uuid.UUID, entityType string, entityID uuid.UUID, action AuditAction, userID string, oldValues, newValues json.RawMessage) error { + id := uuid.New() + now := time.Now().UTC() + + userUUID, err := uuid.Parse(userID) + if err != nil { + return fmt.Errorf("invalid user_id UUID: %w", err) + } + + // Compute a simple hash for integrity: sha256(entityType + entityID + action + timestamp) + hashInput := fmt.Sprintf("%s:%s:%s:%s:%s", projectID, entityType, entityID, string(action), now.Format(time.RFC3339Nano)) + // Use a simple deterministic hash representation + hash := fmt.Sprintf("%x", hashInput) + + _, err = s.pool.Exec(ctx, ` + INSERT INTO iace_audit_trail ( + id, project_id, entity_type, entity_id, + action, user_id, old_values, new_values, + hash, created_at + ) VALUES ( + $1, $2, $3, $4, + $5, $6, $7, $8, + $9, $10 + ) + `, + id, projectID, entityType, entityID, + string(action), userUUID, oldValues, newValues, + hash, now, + ) + if err != nil { + return fmt.Errorf("add audit entry: %w", err) + } + + return nil +} + +// ListAuditTrail lists all audit trail entries for a project, newest first +func (s *Store) ListAuditTrail(ctx context.Context, projectID uuid.UUID) ([]AuditTrailEntry, error) { + rows, err := s.pool.Query(ctx, ` + SELECT + id, project_id, entity_type, entity_id, + action, user_id, old_values, new_values, + hash, created_at + FROM iace_audit_trail WHERE project_id = $1 + ORDER BY created_at DESC + `, projectID) + if err != nil { + return nil, fmt.Errorf("list audit trail: %w", err) + } + defer rows.Close() + + var entries []AuditTrailEntry + for rows.Next() { + var e AuditTrailEntry + var action string + + err := rows.Scan( + &e.ID, &e.ProjectID, &e.EntityType, &e.EntityID, + &action, &e.UserID, &e.OldValues, &e.NewValues, + &e.Hash, &e.CreatedAt, + ) + if err != nil { + return nil, fmt.Errorf("list audit trail scan: %w", err) + } + + e.Action = AuditAction(action) + entries = append(entries, e) + } + + return entries, nil +} + +// ============================================================================ +// Hazard Library Operations +// ============================================================================ + +// ListHazardLibrary lists hazard library entries, optionally filtered by category and component type +func (s *Store) ListHazardLibrary(ctx context.Context, category string, componentType string) ([]HazardLibraryEntry, error) { + query := ` + SELECT + id, category, name, description, + default_severity, default_probability, + applicable_component_types, regulation_references, + suggested_mitigations, is_builtin, tenant_id, + created_at + FROM iace_hazard_library WHERE 1=1` + + args := []interface{}{} + argIdx := 1 + + if category != "" { + query += fmt.Sprintf(" AND category = $%d", argIdx) + args = append(args, category) + argIdx++ + } + if componentType != "" { + query += fmt.Sprintf(" AND applicable_component_types @> $%d::jsonb", argIdx) + componentTypeJSON, _ := json.Marshal([]string{componentType}) + args = append(args, string(componentTypeJSON)) + argIdx++ + } + + query += " ORDER BY category ASC, name ASC" + + rows, err := s.pool.Query(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("list hazard library: %w", err) + } + defer rows.Close() + + var entries []HazardLibraryEntry + for rows.Next() { + var e HazardLibraryEntry + var applicableComponentTypes, regulationReferences, suggestedMitigations []byte + + err := rows.Scan( + &e.ID, &e.Category, &e.Name, &e.Description, + &e.DefaultSeverity, &e.DefaultProbability, + &applicableComponentTypes, ®ulationReferences, + &suggestedMitigations, &e.IsBuiltin, &e.TenantID, + &e.CreatedAt, + ) + if err != nil { + return nil, fmt.Errorf("list hazard library scan: %w", err) + } + + json.Unmarshal(applicableComponentTypes, &e.ApplicableComponentTypes) + json.Unmarshal(regulationReferences, &e.RegulationReferences) + json.Unmarshal(suggestedMitigations, &e.SuggestedMitigations) + + if e.ApplicableComponentTypes == nil { + e.ApplicableComponentTypes = []string{} + } + if e.RegulationReferences == nil { + e.RegulationReferences = []string{} + } + + entries = append(entries, e) + } + + return entries, nil +} + +// GetHazardLibraryEntry retrieves a single hazard library entry by ID +func (s *Store) GetHazardLibraryEntry(ctx context.Context, id uuid.UUID) (*HazardLibraryEntry, error) { + var e HazardLibraryEntry + var applicableComponentTypes, regulationReferences, suggestedMitigations []byte + + err := s.pool.QueryRow(ctx, ` + SELECT + id, category, name, description, + default_severity, default_probability, + applicable_component_types, regulation_references, + suggested_mitigations, is_builtin, tenant_id, + created_at + FROM iace_hazard_library WHERE id = $1 + `, id).Scan( + &e.ID, &e.Category, &e.Name, &e.Description, + &e.DefaultSeverity, &e.DefaultProbability, + &applicableComponentTypes, ®ulationReferences, + &suggestedMitigations, &e.IsBuiltin, &e.TenantID, + &e.CreatedAt, + ) + if err == pgx.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("get hazard library entry: %w", err) + } + + json.Unmarshal(applicableComponentTypes, &e.ApplicableComponentTypes) + json.Unmarshal(regulationReferences, &e.RegulationReferences) + json.Unmarshal(suggestedMitigations, &e.SuggestedMitigations) + + if e.ApplicableComponentTypes == nil { + e.ApplicableComponentTypes = []string{} + } + if e.RegulationReferences == nil { + e.RegulationReferences = []string{} + } + + return &e, nil +} + +// ============================================================================ +// Risk Summary (Aggregated View) +// ============================================================================ + +// GetRiskSummary computes an aggregated risk overview for a project +func (s *Store) GetRiskSummary(ctx context.Context, projectID uuid.UUID) (*RiskSummaryResponse, error) { + // Get all hazards for the project + hazards, err := s.ListHazards(ctx, projectID) + if err != nil { + return nil, fmt.Errorf("get risk summary - list hazards: %w", err) + } + + summary := &RiskSummaryResponse{ + TotalHazards: len(hazards), + AllAcceptable: true, + } + + if len(hazards) == 0 { + summary.OverallRiskLevel = RiskLevelNegligible + return summary, nil + } + + highestRisk := RiskLevelNegligible + + for _, h := range hazards { + latest, err := s.GetLatestAssessment(ctx, h.ID) + if err != nil { + return nil, fmt.Errorf("get risk summary - get assessment for hazard %s: %w", h.ID, err) + } + if latest == nil { + // Hazard without assessment counts as unassessed; consider it not acceptable + summary.AllAcceptable = false + continue + } + + switch latest.RiskLevel { + case RiskLevelCritical: + summary.Critical++ + case RiskLevelHigh: + summary.High++ + case RiskLevelMedium: + summary.Medium++ + case RiskLevelLow: + summary.Low++ + case RiskLevelNegligible: + summary.Negligible++ + } + + if !latest.IsAcceptable { + summary.AllAcceptable = false + } + + // Track highest risk level + if riskLevelSeverity(latest.RiskLevel) > riskLevelSeverity(highestRisk) { + highestRisk = latest.RiskLevel + } + } + + summary.OverallRiskLevel = highestRisk + + return summary, nil +} + +// riskLevelSeverity returns a numeric severity for risk level comparison +func riskLevelSeverity(rl RiskLevel) int { + switch rl { + case RiskLevelCritical: + return 5 + case RiskLevelHigh: + return 4 + case RiskLevelMedium: + return 3 + case RiskLevelLow: + return 2 + case RiskLevelNegligible: + return 1 + default: + return 0 + } +}