package iace import ( "encoding/json" "testing" ) func TestClassifyAIAct(t *testing.T) { c := NewClassifier() tests := []struct { name string project *Project components []Component wantResult string wantRiskLevel string wantReqsEmpty bool wantConfidence float64 }{ { name: "no AI components returns not_applicable", project: &Project{MachineName: "TestMachine"}, components: []Component{ {Name: "PLC", ComponentType: ComponentTypeSoftware}, {Name: "Ethernet", ComponentType: ComponentTypeNetwork}, }, wantResult: "not_applicable", wantRiskLevel: "none", wantReqsEmpty: true, wantConfidence: 0.95, }, { name: "no components at all returns not_applicable", project: &Project{MachineName: "EmptyMachine"}, components: []Component{}, wantResult: "not_applicable", wantRiskLevel: "none", wantReqsEmpty: true, wantConfidence: 0.95, }, { name: "AI model not safety relevant returns limited_risk", project: &Project{MachineName: "VisionMachine"}, components: []Component{ {Name: "QualityChecker", ComponentType: ComponentTypeAIModel, IsSafetyRelevant: false}, }, wantResult: "limited_risk", wantRiskLevel: "medium", wantReqsEmpty: false, wantConfidence: 0.85, }, { name: "safety-relevant AI model returns high_risk", project: &Project{MachineName: "SafetyMachine"}, components: []Component{ {Name: "SafetyAI", ComponentType: ComponentTypeAIModel, IsSafetyRelevant: true}, }, wantResult: "high_risk", wantRiskLevel: "high", wantReqsEmpty: false, wantConfidence: 0.9, }, { name: "mixed components with safety-relevant AI returns high_risk", project: &Project{MachineName: "ComplexMachine"}, components: []Component{ {Name: "PLC", ComponentType: ComponentTypeSoftware}, {Name: "BasicAI", ComponentType: ComponentTypeAIModel, IsSafetyRelevant: false}, {Name: "SafetyAI", ComponentType: ComponentTypeAIModel, IsSafetyRelevant: true}, {Name: "Cam", ComponentType: ComponentTypeSensor}, }, wantResult: "high_risk", wantRiskLevel: "high", wantReqsEmpty: false, wantConfidence: 0.9, }, { name: "non-AI safety-relevant component does not trigger AI act", project: &Project{MachineName: "SafetySoftwareMachine"}, components: []Component{ {Name: "SafetyPLC", ComponentType: ComponentTypeSoftware, IsSafetyRelevant: true}, }, wantResult: "not_applicable", wantRiskLevel: "none", wantReqsEmpty: true, wantConfidence: 0.95, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := c.ClassifyAIAct(tt.project, tt.components) if result.Regulation != RegulationAIAct { t.Errorf("Regulation = %q, want %q", result.Regulation, RegulationAIAct) } if result.ClassificationResult != tt.wantResult { t.Errorf("ClassificationResult = %q, want %q", result.ClassificationResult, tt.wantResult) } if result.RiskLevel != tt.wantRiskLevel { t.Errorf("RiskLevel = %q, want %q", result.RiskLevel, tt.wantRiskLevel) } if (result.Requirements == nil || len(result.Requirements) == 0) != tt.wantReqsEmpty { t.Errorf("Requirements empty = %v, want %v", result.Requirements == nil || len(result.Requirements) == 0, tt.wantReqsEmpty) } if result.Confidence != tt.wantConfidence { t.Errorf("Confidence = %f, want %f", result.Confidence, tt.wantConfidence) } if result.Reasoning == "" { t.Error("Reasoning should not be empty") } }) } } func TestClassifyMachineryRegulation(t *testing.T) { c := NewClassifier() tests := []struct { name string project *Project components []Component wantResult string wantRiskLevel string wantReqsLen int }{ { name: "no CE target and no safety SW returns standard", project: &Project{MachineName: "BasicMachine", CEMarkingTarget: ""}, components: []Component{{Name: "App", ComponentType: ComponentTypeSoftware}}, wantResult: "standard", wantRiskLevel: "low", wantReqsLen: 3, }, { name: "CE target set returns applicable", project: &Project{MachineName: "CEMachine", CEMarkingTarget: "2023/1230"}, components: []Component{{Name: "App", ComponentType: ComponentTypeSoftware}}, wantResult: "applicable", wantRiskLevel: "medium", wantReqsLen: 5, }, { name: "safety-relevant software overrides CE target to annex_iii", project: &Project{MachineName: "SafetyMachine", CEMarkingTarget: "2023/1230"}, components: []Component{{Name: "SafetyPLC", ComponentType: ComponentTypeSoftware, IsSafetyRelevant: true}}, wantResult: "annex_iii", wantRiskLevel: "high", wantReqsLen: 7, }, { name: "safety-relevant firmware returns annex_iii", project: &Project{MachineName: "FirmwareMachine", CEMarkingTarget: ""}, components: []Component{{Name: "SafetyFW", ComponentType: ComponentTypeFirmware, IsSafetyRelevant: true}}, wantResult: "annex_iii", wantRiskLevel: "high", wantReqsLen: 7, }, { name: "safety-relevant non-SW component does not trigger annex_iii", project: &Project{MachineName: "SensorMachine", CEMarkingTarget: ""}, components: []Component{ {Name: "SafetySensor", ComponentType: ComponentTypeSensor, IsSafetyRelevant: true}, }, wantResult: "standard", wantRiskLevel: "low", wantReqsLen: 3, }, { name: "AI model safety-relevant does not trigger annex_iii (not software/firmware type)", project: &Project{MachineName: "AIModelMachine", CEMarkingTarget: ""}, components: []Component{ {Name: "SafetyAI", ComponentType: ComponentTypeAIModel, IsSafetyRelevant: true}, }, wantResult: "standard", wantRiskLevel: "low", wantReqsLen: 3, }, { name: "empty components with no CE target returns standard", project: &Project{MachineName: "EmptyMachine"}, components: []Component{}, wantResult: "standard", wantRiskLevel: "low", wantReqsLen: 3, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := c.ClassifyMachineryRegulation(tt.project, tt.components) if result.Regulation != RegulationMachineryRegulation { t.Errorf("Regulation = %q, want %q", result.Regulation, RegulationMachineryRegulation) } if result.ClassificationResult != tt.wantResult { t.Errorf("ClassificationResult = %q, want %q", result.ClassificationResult, tt.wantResult) } if result.RiskLevel != tt.wantRiskLevel { t.Errorf("RiskLevel = %q, want %q", result.RiskLevel, tt.wantRiskLevel) } if len(result.Requirements) != tt.wantReqsLen { t.Errorf("Requirements length = %d, want %d", len(result.Requirements), tt.wantReqsLen) } if result.Reasoning == "" { t.Error("Reasoning should not be empty") } }) } } func TestClassifyCRA(t *testing.T) { c := NewClassifier() tests := []struct { name string project *Project components []Component wantResult string wantRiskLevel string wantReqsNil bool }{ { name: "no networked components returns not_applicable", project: &Project{MachineName: "OfflineMachine"}, components: []Component{{Name: "PLC", ComponentType: ComponentTypeSoftware, IsNetworked: false}}, wantResult: "not_applicable", wantRiskLevel: "none", wantReqsNil: true, }, { name: "empty components returns not_applicable", project: &Project{MachineName: "EmptyMachine"}, components: []Component{}, wantResult: "not_applicable", wantRiskLevel: "none", wantReqsNil: true, }, { name: "networked generic software returns default", project: &Project{MachineName: "GenericNetworkedMachine"}, components: []Component{ {Name: "App", ComponentType: ComponentTypeSoftware, IsNetworked: true}, }, wantResult: "default", wantRiskLevel: "low", wantReqsNil: false, }, { name: "networked controller returns class_i", project: &Project{MachineName: "ControllerMachine"}, components: []Component{ {Name: "MainPLC", ComponentType: ComponentTypeController, IsNetworked: true}, }, wantResult: "class_i", wantRiskLevel: "medium", wantReqsNil: false, }, { name: "networked network component returns class_i", project: &Project{MachineName: "NetworkMachine"}, components: []Component{ {Name: "Switch", ComponentType: ComponentTypeNetwork, IsNetworked: true}, }, wantResult: "class_i", wantRiskLevel: "medium", wantReqsNil: false, }, { name: "networked sensor returns class_i", project: &Project{MachineName: "SensorMachine"}, components: []Component{ {Name: "IoTSensor", ComponentType: ComponentTypeSensor, IsNetworked: true}, }, wantResult: "class_i", wantRiskLevel: "medium", wantReqsNil: false, }, { name: "safety-relevant networked component returns class_ii", project: &Project{MachineName: "SafetyNetworkedMachine"}, components: []Component{ {Name: "SafetyNet", ComponentType: ComponentTypeSoftware, IsNetworked: true, IsSafetyRelevant: true}, }, wantResult: "class_ii", wantRiskLevel: "high", wantReqsNil: false, }, { name: "safety-relevant overrides critical type", project: &Project{MachineName: "MixedMachine"}, components: []Component{ {Name: "PLC", ComponentType: ComponentTypeController, IsNetworked: true, IsSafetyRelevant: true}, }, wantResult: "class_ii", wantRiskLevel: "high", wantReqsNil: false, }, { name: "non-networked critical type is not_applicable", project: &Project{MachineName: "OfflineControllerMachine"}, components: []Component{ {Name: "PLC", ComponentType: ComponentTypeController, IsNetworked: false}, }, wantResult: "not_applicable", wantRiskLevel: "none", wantReqsNil: true, }, { name: "HMI networked but not critical type returns default", project: &Project{MachineName: "HMIMachine"}, components: []Component{ {Name: "Panel", ComponentType: ComponentTypeHMI, IsNetworked: true}, }, wantResult: "default", wantRiskLevel: "low", wantReqsNil: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := c.ClassifyCRA(tt.project, tt.components) if result.Regulation != RegulationCRA { t.Errorf("Regulation = %q, want %q", result.Regulation, RegulationCRA) } if result.ClassificationResult != tt.wantResult { t.Errorf("ClassificationResult = %q, want %q", result.ClassificationResult, tt.wantResult) } if result.RiskLevel != tt.wantRiskLevel { t.Errorf("RiskLevel = %q, want %q", result.RiskLevel, tt.wantRiskLevel) } if (result.Requirements == nil) != tt.wantReqsNil { t.Errorf("Requirements nil = %v, want %v", result.Requirements == nil, tt.wantReqsNil) } if result.Reasoning == "" { t.Error("Reasoning should not be empty") } }) } } func TestClassifyNIS2(t *testing.T) { c := NewClassifier() tests := []struct { name string metadata json.RawMessage wantResult string }{ { name: "nil metadata returns not_applicable", metadata: nil, wantResult: "not_applicable", }, { name: "empty JSON object returns not_applicable", metadata: json.RawMessage(`{}`), wantResult: "not_applicable", }, { name: "invalid JSON returns not_applicable", metadata: json.RawMessage(`not-json`), wantResult: "not_applicable", }, { name: "kritis_supplier true returns indirect_obligation", metadata: json.RawMessage(`{"kritis_supplier": true}`), wantResult: "indirect_obligation", }, { name: "kritis_supplier false returns not_applicable", metadata: json.RawMessage(`{"kritis_supplier": false}`), wantResult: "not_applicable", }, { name: "critical_sector_clients non-empty array returns indirect_obligation", metadata: json.RawMessage(`{"critical_sector_clients": ["energy"]}`), wantResult: "indirect_obligation", }, { name: "critical_sector_clients empty array returns not_applicable", metadata: json.RawMessage(`{"critical_sector_clients": []}`), wantResult: "not_applicable", }, { name: "critical_sector_clients bool true returns indirect_obligation", metadata: json.RawMessage(`{"critical_sector_clients": true}`), wantResult: "indirect_obligation", }, { name: "critical_sector_clients bool false returns not_applicable", metadata: json.RawMessage(`{"critical_sector_clients": false}`), wantResult: "not_applicable", }, { name: "target_sectors with critical sector returns indirect_obligation", metadata: json.RawMessage(`{"target_sectors": ["health"]}`), wantResult: "indirect_obligation", }, { name: "target_sectors energy returns indirect_obligation", metadata: json.RawMessage(`{"target_sectors": ["energy"]}`), wantResult: "indirect_obligation", }, { name: "target_sectors transport returns indirect_obligation", metadata: json.RawMessage(`{"target_sectors": ["transport"]}`), wantResult: "indirect_obligation", }, { name: "target_sectors banking returns indirect_obligation", metadata: json.RawMessage(`{"target_sectors": ["banking"]}`), wantResult: "indirect_obligation", }, { name: "target_sectors water returns indirect_obligation", metadata: json.RawMessage(`{"target_sectors": ["water"]}`), wantResult: "indirect_obligation", }, { name: "target_sectors digital_infra returns indirect_obligation", metadata: json.RawMessage(`{"target_sectors": ["digital_infra"]}`), wantResult: "indirect_obligation", }, { name: "target_sectors non-critical sector returns not_applicable", metadata: json.RawMessage(`{"target_sectors": ["retail"]}`), wantResult: "not_applicable", }, { name: "target_sectors empty array returns not_applicable", metadata: json.RawMessage(`{"target_sectors": []}`), wantResult: "not_applicable", }, { name: "target_sectors case insensitive match", metadata: json.RawMessage(`{"target_sectors": ["Health"]}`), wantResult: "indirect_obligation", }, { name: "kritis_supplier takes precedence over target_sectors", metadata: json.RawMessage(`{"kritis_supplier": true, "target_sectors": ["retail"]}`), wantResult: "indirect_obligation", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { project := &Project{ MachineName: "TestMachine", Metadata: tt.metadata, } result := c.ClassifyNIS2(project, nil) if result.Regulation != RegulationNIS2 { t.Errorf("Regulation = %q, want %q", result.Regulation, RegulationNIS2) } if result.ClassificationResult != tt.wantResult { t.Errorf("ClassificationResult = %q, want %q", result.ClassificationResult, tt.wantResult) } if result.Reasoning == "" { t.Error("Reasoning should not be empty") } if tt.wantResult == "indirect_obligation" { if result.RiskLevel != "medium" { t.Errorf("RiskLevel = %q, want %q", result.RiskLevel, "medium") } if result.Requirements == nil || len(result.Requirements) == 0 { t.Error("Requirements should not be empty for indirect_obligation") } } else { if result.RiskLevel != "none" { t.Errorf("RiskLevel = %q, want %q", result.RiskLevel, "none") } if result.Requirements != nil { t.Errorf("Requirements should be nil for not_applicable, got %v", result.Requirements) } } }) } } func TestClassifyAll(t *testing.T) { c := NewClassifier() tests := []struct { name string project *Project components []Component }{ { name: "returns exactly 4 results for empty project", project: &Project{MachineName: "TestMachine"}, components: []Component{}, }, { name: "returns exactly 4 results for complex project", project: &Project{MachineName: "ComplexMachine", CEMarkingTarget: "2023/1230", Metadata: json.RawMessage(`{"kritis_supplier": true}`)}, components: []Component{ {Name: "SafetyAI", ComponentType: ComponentTypeAIModel, IsSafetyRelevant: true, IsNetworked: true}, {Name: "PLC", ComponentType: ComponentTypeController, IsNetworked: true}, {Name: "SafetyFW", ComponentType: ComponentTypeFirmware, IsSafetyRelevant: true}, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { results := c.ClassifyAll(tt.project, tt.components) if len(results) != 4 { t.Fatalf("ClassifyAll returned %d results, want 4", len(results)) } expectedRegulations := map[RegulationType]bool{ RegulationAIAct: false, RegulationMachineryRegulation: false, RegulationCRA: false, RegulationNIS2: false, } for _, r := range results { if _, ok := expectedRegulations[r.Regulation]; !ok { t.Errorf("unexpected regulation %q in results", r.Regulation) } expectedRegulations[r.Regulation] = true } for reg, found := range expectedRegulations { if !found { t.Errorf("missing regulation %q in results", reg) } } // Verify order: AI Act, Machinery, CRA, NIS2 if results[0].Regulation != RegulationAIAct { t.Errorf("results[0].Regulation = %q, want %q", results[0].Regulation, RegulationAIAct) } if results[1].Regulation != RegulationMachineryRegulation { t.Errorf("results[1].Regulation = %q, want %q", results[1].Regulation, RegulationMachineryRegulation) } if results[2].Regulation != RegulationCRA { t.Errorf("results[2].Regulation = %q, want %q", results[2].Regulation, RegulationCRA) } if results[3].Regulation != RegulationNIS2 { t.Errorf("results[3].Regulation = %q, want %q", results[3].Regulation, RegulationNIS2) } }) } }