package gap import ( "context" "encoding/json" "fmt" "strings" "time" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgxpool" ) // Store handles database operations for gap analysis. type Store struct { pool *pgxpool.Pool } // NewStore creates a new Store. func NewStore(pool *pgxpool.Pool) *Store { return &Store{pool: pool} } // ── Product Profile CRUD ──────────────────────────────────────────── // CreateProfile saves a product profile. func (s *Store) CreateProfile(p *ProductProfile) error { ctx := context.Background() p.ID = uuid.New() p.CreatedAt = time.Now() p.UpdatedAt = time.Now() techJSON, _ := json.Marshal(p.Technologies) dataJSON, _ := json.Marshal(p.DataProcessing) marketsJSON, _ := json.Marshal(p.Markets) certsJSON, _ := json.Marshal(p.ExistingCertifications) normsJSON, _ := json.Marshal(p.AppliedNorms) _, err := s.pool.Exec(ctx, ` INSERT INTO compliance.gap_projects (id, tenant_id, name, description, product_type, technologies, data_processing, markets, connected_to_internet, has_software_updates, uses_ai, processes_personal_data, is_critical_infra_supplier, existing_certifications, applied_norms, has_risk_assessment, has_technical_file, has_operating_manual, has_sbom, has_vuln_management, has_update_mechanism, has_incident_response, has_supply_chain_mgmt, ce_marking_since, product_age, iace_project_id, created_at, updated_at) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28)`, p.ID, p.TenantID, p.Name, p.Description, p.ProductType, techJSON, dataJSON, marketsJSON, p.ConnectedToInternet, p.HasSoftwareUpdates, p.UsesAI, p.ProcessesPersonalData, p.IsCriticalInfraSupplier, certsJSON, normsJSON, p.HasRiskAssessment, p.HasTechnicalFile, p.HasOperatingManual, p.HasSBOM, p.HasVulnManagement, p.HasUpdateMechanism, p.HasIncidentResponse, p.HasSupplyChainMgmt, p.CEMarkingSince, p.ProductAge, p.IACEProjectID, p.CreatedAt, p.UpdatedAt, ) return err } // GetProfile loads a product profile by ID. func (s *Store) GetProfile(id uuid.UUID) (*ProductProfile, error) { ctx := context.Background() p := &ProductProfile{} var techJSON, dataJSON, marketsJSON, certsJSON []byte err := s.pool.QueryRow(ctx, ` SELECT id, tenant_id, name, description, product_type, technologies, data_processing, markets, connected_to_internet, has_software_updates, uses_ai, processes_personal_data, is_critical_infra_supplier, existing_certifications, created_at, updated_at FROM compliance.gap_projects WHERE id = $1`, id, ).Scan( &p.ID, &p.TenantID, &p.Name, &p.Description, &p.ProductType, &techJSON, &dataJSON, &marketsJSON, &p.ConnectedToInternet, &p.HasSoftwareUpdates, &p.UsesAI, &p.ProcessesPersonalData, &p.IsCriticalInfraSupplier, &certsJSON, &p.CreatedAt, &p.UpdatedAt, ) if err != nil { return nil, err } json.Unmarshal(techJSON, &p.Technologies) json.Unmarshal(dataJSON, &p.DataProcessing) json.Unmarshal(marketsJSON, &p.Markets) json.Unmarshal(certsJSON, &p.ExistingCertifications) return p, nil } // ListProfiles lists profiles for a tenant. func (s *Store) ListProfiles(tenantID uuid.UUID) ([]ProductProfile, error) { ctx := context.Background() rows, err := s.pool.Query(ctx, ` SELECT id, name, description, product_type, created_at FROM compliance.gap_projects WHERE tenant_id = $1 ORDER BY created_at DESC`, tenantID) if err != nil { return nil, err } defer rows.Close() var profiles []ProductProfile for rows.Next() { var p ProductProfile if err := rows.Scan(&p.ID, &p.Name, &p.Description, &p.ProductType, &p.CreatedAt); err != nil { return nil, err } profiles = append(profiles, p) } return profiles, nil } // ── Master Control Queries ────────────────────────────────────────── // FetchApplicableMCs queries Master Controls relevant for the given // scope signals and regulations. func (s *Store) FetchApplicableMCs(signals []string, regs []ApplicableRegulation) ([]MCGroup, error) { if len(regs) == 0 { return nil, nil } ctx := context.Background() sourceNames := regulationToSourceNames(regs) if len(sourceNames) == 0 { return nil, nil } // Build parameterized query placeholders := make([]string, len(sourceNames)) args := make([]interface{}, len(sourceNames)) for i, name := range sourceNames { placeholders[i] = fmt.Sprintf("$%d", i+1) args[i] = name } query := fmt.Sprintf(` SELECT DISTINCT mc.master_control_id, mc.canonical_name, mc.total_controls, pc.source_citation::jsonb->>'source' as regulation_source FROM compliance.master_controls mc JOIN compliance.master_control_members mcm ON mcm.master_control_uuid = mc.id JOIN compliance.canonical_controls cc ON cc.id = mcm.control_uuid LEFT JOIN compliance.canonical_controls pc ON pc.id = cc.parent_control_uuid WHERE pc.source_citation::jsonb->>'source' IN (%s) GROUP BY mc.master_control_id, mc.canonical_name, mc.total_controls, pc.source_citation::jsonb->>'source' ORDER BY mc.total_controls DESC LIMIT 500`, strings.Join(placeholders, ",")) rows, err := s.pool.Query(ctx, query, args...) if err != nil { return nil, fmt.Errorf("query MCs: %w", err) } defer rows.Close() var groups []MCGroup for rows.Next() { var g MCGroup var regSource *string if err := rows.Scan(&g.MasterControlID, &g.CanonicalName, &g.ControlCount, ®Source); err != nil { return nil, err } g.Title = formatTitle(g.CanonicalName) g.Severity = inferSeverity(g.CanonicalName) if regSource != nil { g.Regulation = sourceToRegID(*regSource) } groups = append(groups, g) } return groups, nil } // ── Helpers ───────────────────────────────────────────────────────── func regulationToSourceNames(regs []ApplicableRegulation) []string { mapping := map[RegulationID][]string{ RegCRA: {"Cyber Resilience Act (CRA)"}, RegAIAct: {"KI-Verordnung (EU) 2024/1689"}, RegNIS2: {"NIS2-Richtlinie (EU) 2022/2555"}, RegDSGVO: {"DSGVO (EU) 2016/679"}, RegDataAct: {"Data Act"}, RegMiCA: {"Markets in Crypto-Assets (MiCA)"}, RegPSD2: {"Zahlungsdiensterichtlinie 2"}, RegAML: {"Geldwaeschegesetz (GwG)", "AML-Verordnung"}, RegMDR: {"Medizinprodukteverordnung (EU) 2017/745 (MDR)"}, RegMachinery: {"Maschinenverordnung (EU) 2023/1230"}, RegTDDDG: {"TDDDG"}, RegLkSG: {"Lieferkettensorgfaltspflichtengesetz (LkSG)"}, } var names []string for _, reg := range regs { if sources, ok := mapping[reg.ID]; ok { names = append(names, sources...) } } return names } func sourceToRegID(source string) RegulationID { switch { case strings.Contains(source, "CRA") || strings.Contains(source, "Cyber Resilience"): return RegCRA case strings.Contains(source, "KI-Verordnung"): return RegAIAct case strings.Contains(source, "NIS2"): return RegNIS2 case strings.Contains(source, "DSGVO"): return RegDSGVO case strings.Contains(source, "Data Act"): return RegDataAct case strings.Contains(source, "MiCA") || strings.Contains(source, "Crypto"): return RegMiCA case strings.Contains(source, "Zahlungsdienst"): return RegPSD2 case strings.Contains(source, "Geldwäsche") || strings.Contains(source, "AML"): return RegAML case strings.Contains(source, "Medizinprodukt"): return RegMDR case strings.Contains(source, "Maschinenverordnung"): return RegMachinery case strings.Contains(source, "TDDDG"): return RegTDDDG default: return RegDSGVO } } // CheckIACECoverage checks if an IACE project has verified mitigations // covering the given MC topic. func (s *Store) CheckIACECoverage(projectID uuid.UUID, mcTopic string) string { ctx := context.Background() // Map MC topics to IACE hazard categories iaceCategory := mcTopicToIACECategory(mcTopic) if iaceCategory == "" { return "" } var verifiedCount, implementedCount int err := s.pool.QueryRow(ctx, ` SELECT COUNT(CASE WHEN m.status = 'verified' THEN 1 END), COUNT(CASE WHEN m.status = 'implemented' THEN 1 END) FROM iace_mitigations m JOIN iace_hazards h ON h.id = m.hazard_id WHERE h.project_id = $1 AND (h.category ILIKE $2 OR h.sub_category ILIKE $2)`, projectID, "%"+iaceCategory+"%", ).Scan(&verifiedCount, &implementedCount) if err != nil || (verifiedCount == 0 && implementedCount == 0) { return "" } if verifiedCount > 0 { return "verified" } return "implemented" } func mcTopicToIACECategory(topic string) string { mapping := map[string]string{ "encryption": "cyber", "access_control": "software", "network_security": "cyber", "vulnerability": "cyber", "product_safety": "mechanical", "physical_security": "electrical", "monitoring": "software", "incident": "organizational", "risk_management": "general", } for prefix, cat := range mapping { if strings.HasPrefix(topic, prefix) { return cat } } return "" } func formatTitle(name string) string { return strings.ReplaceAll( strings.ReplaceAll(name, "_", " "), " ", " ") } func inferSeverity(name string) string { high := []string{"encryption", "access_control", "incident", "vulnerability", "authentication", "key_management", "data_breach"} for _, h := range high { if strings.Contains(name, h) { return "HIGH" } } return "MEDIUM" }