Some checks failed
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Failing after 36s
CI/CD / test-python-backend-compliance (push) Successful in 36s
CI/CD / test-python-document-crawler (push) Successful in 27s
CI/CD / test-python-dsms-gateway (push) Successful in 18s
CI/CD / validate-canonical-controls (push) Successful in 11s
CI/CD / Deploy (push) Has been skipped
Phase 4: source_type (law/guideline/standard/restricted) on source_citation - NIST/OWASP/ENISA correctly shown as "Standard" instead of "Gesetzliche Grundlage" - Dynamic frontend labels based on source_type - Backfill endpoint POST /v1/canonical/generate/backfill-source-type Phase v3: Scoped Control Applicability - 3 new fields: applicable_industries, applicable_company_size, scope_conditions - LLM prompt extended with 39 industries, 5 company sizes, 10 scope signals - All 5 generation paths (Rule 1/2/3, batch structure, batch reform) updated - _build_control_from_json: parsing + validation (string→list, size validation) - _store_control: writes 3 new JSONB columns - API: response models, create/update requests, SELECT queries extended - Migration 063: 3 new JSONB columns with GIN indexes - 110 generator tests + 28 route tests = 138 total, all passing Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1671 lines
69 KiB
Python
1671 lines
69 KiB
Python
"""Tests for Control Generator Pipeline."""
|
|
|
|
import json
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
from compliance.services.control_generator import (
|
|
_classify_regulation,
|
|
_detect_domain,
|
|
_detect_recital,
|
|
_parse_llm_json,
|
|
_parse_llm_json_array,
|
|
GeneratorConfig,
|
|
GeneratedControl,
|
|
ControlGeneratorPipeline,
|
|
REGULATION_LICENSE_MAP,
|
|
PIPELINE_VERSION,
|
|
)
|
|
from compliance.services.anchor_finder import AnchorFinder, OpenAnchor
|
|
from compliance.services.rag_client import RAGSearchResult
|
|
|
|
|
|
# =============================================================================
|
|
# License Mapping Tests
|
|
# =============================================================================
|
|
|
|
class TestLicenseMapping:
    """Tests for regulation_code → license rule classification."""

    def test_rule1_eu_law(self):
        # EU primary law is freely usable (Rule 1).
        classified = _classify_regulation("eu_2016_679")
        assert classified["rule"] == 1
        assert classified["name"] == "DSGVO"
        assert classified["source_type"] == "law"

    def test_rule1_nist(self):
        # NIST publications are public-domain → Rule 1, but typed as standard.
        classified = _classify_regulation("nist_sp_800_53")
        assert classified["rule"] == 1
        assert "NIST" in classified["name"]
        assert classified["source_type"] == "standard"

    def test_rule1_german_law(self):
        classified = _classify_regulation("bdsg")
        assert classified["rule"] == 1
        assert classified["name"] == "BDSG"
        assert classified["source_type"] == "law"

    def test_rule2_owasp(self):
        # Rule 2 sources carry an attribution requirement.
        classified = _classify_regulation("owasp_asvs")
        assert classified["rule"] == 2
        assert "OWASP" in classified["name"]
        assert "attribution" in classified
        assert classified["source_type"] == "standard"

    def test_rule2_enisa_prefix(self):
        classified = _classify_regulation("enisa_iot_security")
        assert classified["rule"] == 2
        assert "ENISA" in classified["name"]
        assert classified["source_type"] == "standard"

    def test_rule3_bsi_prefix(self):
        # Rule 3 sources are restricted: name must be masked.
        classified = _classify_regulation("bsi_tr03161")
        assert classified["rule"] == 3
        assert classified["name"] == "INTERNAL_ONLY"
        assert classified["source_type"] == "restricted"

    def test_rule3_iso_prefix(self):
        classified = _classify_regulation("iso_27001")
        assert classified["rule"] == 3
        assert classified["source_type"] == "restricted"

    def test_rule3_etsi_prefix(self):
        classified = _classify_regulation("etsi_en_303_645")
        assert classified["rule"] == 3
        assert classified["source_type"] == "restricted"

    def test_unknown_defaults_to_rule3(self):
        # Unknown codes fall back to the most restrictive rule.
        classified = _classify_regulation("some_unknown_source")
        assert classified["rule"] == 3
        assert classified["source_type"] == "restricted"

    def test_case_insensitive(self):
        classified = _classify_regulation("EU_2016_679")
        assert classified["rule"] == 1
        assert classified["source_type"] == "law"

    def test_all_mapped_regulations_have_valid_rules(self):
        for code, entry in REGULATION_LICENSE_MAP.items():
            assert entry["rule"] in (1, 2, 3), f"{code} has invalid rule {entry['rule']}"

    def test_all_mapped_regulations_have_source_type(self):
        allowed = {"law", "guideline", "standard", "restricted"}
        for code, entry in REGULATION_LICENSE_MAP.items():
            assert "source_type" in entry, f"{code} missing source_type"
            assert entry["source_type"] in allowed, f"{code} has invalid source_type {entry['source_type']}"

    def test_rule3_never_exposes_names(self):
        # Any Rule-3 prefix must always yield the masked placeholder name.
        for prefix in ("bsi_test", "iso_test", "etsi_test"):
            entry = _classify_regulation(prefix)
            assert entry["name"] == "INTERNAL_ONLY", f"{prefix} exposes name: {entry['name']}"
|
|
|
|
|
|
# =============================================================================
|
|
# Domain Detection Tests
|
|
# =============================================================================
|
|
|
|
class TestDomainDetection:
    """Keyword heuristics map control text onto a short domain code."""

    def test_auth_domain(self):
        text = "Multi-factor authentication and password policy"
        assert _detect_domain(text) == "AUTH"

    def test_crypto_domain(self):
        text = "TLS 1.3 encryption and certificate management"
        assert _detect_domain(text) == "CRYP"

    def test_network_domain(self):
        text = "Firewall rules and network segmentation"
        assert _detect_domain(text) == "NET"

    def test_data_domain(self):
        text = "DSGVO personenbezogene Daten Datenschutz"
        assert _detect_domain(text) == "DATA"

    def test_default_domain(self):
        # Text matching no keyword falls back to the generic "SEC" domain.
        text = "random unrelated text xyz"
        assert _detect_domain(text) == "SEC"
|
|
|
|
|
|
# =============================================================================
|
|
# JSON Parsing Tests
|
|
# =============================================================================
|
|
|
|
class TestJsonParsing:
    """_parse_llm_json must tolerate the usual LLM response wrappers."""

    def test_parse_plain_json(self):
        parsed = _parse_llm_json('{"title": "Test", "objective": "Test obj"}')
        assert parsed["title"] == "Test"

    def test_parse_markdown_fenced_json(self):
        # Object wrapped in a ```json fence must still be extracted.
        response = '```json\n{"title": "Test"}\n```'
        parsed = _parse_llm_json(response)
        assert parsed["title"] == "Test"

    def test_parse_json_with_preamble(self):
        # Chatty preamble before the JSON payload is tolerated.
        response = 'Here is the result:\n{"title": "Test"}'
        parsed = _parse_llm_json(response)
        assert parsed["title"] == "Test"

    def test_parse_invalid_json(self):
        # Unparseable input degrades to an empty dict instead of raising.
        assert _parse_llm_json("not json at all") == {}
|
|
|
|
|
|
# =============================================================================
|
|
# GeneratedControl Rule Tests
|
|
# =============================================================================
|
|
|
|
class TestGeneratedControlRules:
    """Tests that enforce the 3-rule licensing constraints."""

    def test_rule1_has_original_text(self):
        # Rule 1 (free use): original text + citation + customer visibility.
        control = GeneratedControl(license_rule=1)
        control.source_original_text = "Original EU law text"
        control.source_citation = {"source": "DSGVO Art. 35", "license": "EU_LAW"}
        control.customer_visible = True

        assert control.source_original_text is not None
        assert control.source_citation is not None
        assert control.customer_visible is True

    def test_rule2_has_citation(self):
        # Rule 2 (attribution): citation with a CC-BY-SA license string.
        control = GeneratedControl(license_rule=2)
        control.source_citation = {"source": "OWASP ASVS V2.1", "license": "CC-BY-SA-4.0"}
        control.customer_visible = True

        assert control.source_citation is not None
        assert "CC-BY-SA" in control.source_citation["license"]

    def test_rule3_no_original_no_citation(self):
        # Rule 3 (restricted): no original text, no citation, not visible.
        control = GeneratedControl(license_rule=3)
        control.source_original_text = None
        control.source_citation = None
        control.customer_visible = False
        control.generation_metadata = {"processing_path": "llm_reform", "license_rule": 3}

        assert control.source_original_text is None
        assert control.source_citation is None
        assert control.customer_visible is False
        # generation_metadata must NOT contain source names
        serialized = json.dumps(control.generation_metadata)
        assert "bsi" not in serialized.lower()
        assert "iso" not in serialized.lower()
        assert "TR-03161" not in serialized
|
|
|
|
|
|
# =============================================================================
|
|
# Anchor Finder Tests
|
|
# =============================================================================
|
|
|
|
class TestAnchorFinder:
    """Anchor discovery must only surface publishable (Rule 1+2) sources."""

    @pytest.mark.asyncio
    async def test_rag_anchor_search_filters_restricted(self):
        """Only Rule 1+2 sources are returned as anchors."""
        owasp_hit = RAGSearchResult(
            text="OWASP requirement",
            regulation_code="owasp_asvs",
            regulation_name="OWASP ASVS",
            regulation_short="OWASP",
            category="requirement",
            article="V2.1.1",
            paragraph="",
            source_url="https://owasp.org",
            score=0.9,
        )
        bsi_hit = RAGSearchResult(
            text="BSI requirement",
            regulation_code="bsi_tr03161",
            regulation_name="BSI TR-03161",
            regulation_short="BSI",
            category="requirement",
            article="O.Auth_1",
            paragraph="",
            source_url="",
            score=0.85,
        )

        rag = AsyncMock()
        rag.search.return_value = [owasp_hit, bsi_hit]

        finder = AnchorFinder(rag_client=rag)
        control = GeneratedControl(title="Test Auth Control", tags=["auth"])

        anchors = await finder.find_anchors(control, skip_web=True)

        # Only OWASP should be returned (Rule 2), BSI should be filtered out (Rule 3)
        assert len(anchors) == 1
        assert anchors[0].framework == "OWASP ASVS"

    @pytest.mark.asyncio
    async def test_web_search_identifies_frameworks(self):
        finder = AnchorFinder()

        # Known framework hosts map to their framework label.
        expected_by_url = {
            "https://owasp.org/asvs": "OWASP",
            "https://csrc.nist.gov/sp800-53": "NIST",
            "https://www.enisa.europa.eu/pub": "ENISA",
        }
        for url, framework in expected_by_url.items():
            assert finder._identify_framework_from_url(url) == framework
        # Unknown hosts yield no framework.
        assert finder._identify_framework_from_url("https://random-site.com") is None
|
|
|
|
|
|
# =============================================================================
|
|
# Pipeline Integration Tests (Mocked)
|
|
# =============================================================================
|
|
|
|
class TestPipelineMocked:
    """Tests for the pipeline with mocked DB and external services.

    Fix: the ``_llm_chat`` patches previously relied on ``patch()``
    auto-detecting the async target; they now pass
    ``new_callable=AsyncMock`` explicitly, consistent with the patches in
    ``TestBatchProcessingLoop`` that target the same function.
    """

    def _make_chunk(self, regulation_code: str = "owasp_asvs", article: str = "V2.1.1"):
        """Build a default OWASP-flavored RAG chunk fixture."""
        return RAGSearchResult(
            text="Applications must implement multi-factor authentication.",
            regulation_code=regulation_code,
            regulation_name="OWASP ASVS",
            regulation_short="OWASP",
            category="requirement",
            article=article,
            paragraph="",
            source_url="https://owasp.org",
            score=0.9,
        )

    @pytest.mark.asyncio
    async def test_rule1_processing_path(self):
        """Rule 1 chunks produce controls with original text."""
        chunk = self._make_chunk(regulation_code="eu_2016_679", article="Art. 35")
        chunk.text = "Die Datenschutz-Folgenabschaetzung ist durchzufuehren."
        chunk.regulation_name = "DSGVO"

        mock_db = MagicMock()
        # Simulate "no previously processed chunk" in the DB.
        mock_db.execute.return_value.fetchone.return_value = None

        pipeline = ControlGeneratorPipeline(db=mock_db)
        license_info = pipeline._classify_license(chunk)

        assert license_info["rule"] == 1

    @pytest.mark.asyncio
    async def test_rule3_processing_blocks_source_info(self):
        """Rule 3 must never store original text or source names."""
        mock_db = MagicMock()
        mock_rag = AsyncMock()

        pipeline = ControlGeneratorPipeline(db=mock_db, rag_client=mock_rag)

        # Simulate LLM response
        llm_response = json.dumps({
            "title": "Secure Password Storage",
            "objective": "Passwords must be hashed with modern algorithms.",
            "rationale": "Prevents credential theft.",
            "requirements": ["Use bcrypt or argon2"],
            "test_procedure": ["Verify hash algorithm"],
            "evidence": ["Config review"],
            "severity": "high",
            "tags": ["auth", "password"],
        })

        # _llm_chat is awaited by the pipeline — patch it with an explicit
        # AsyncMock rather than relying on patch() auto-detection.
        with patch("compliance.services.control_generator._llm_chat", new_callable=AsyncMock, return_value=llm_response):
            chunk = self._make_chunk(regulation_code="bsi_tr03161", article="O.Auth_1")
            config = GeneratorConfig(max_controls=1)
            control = await pipeline._llm_reformulate(chunk, config)

        assert control.license_rule == 3
        assert control.source_original_text is None
        assert control.source_citation is None
        assert control.customer_visible is False
        # Verify no BSI references in metadata
        metadata_str = json.dumps(control.generation_metadata)
        assert "bsi" not in metadata_str.lower()
        assert "BSI" not in metadata_str
        assert "TR-03161" not in metadata_str

    @pytest.mark.asyncio
    async def test_chunk_hash_deduplication(self):
        """Same chunk text produces same hash — no double processing."""
        import hashlib
        text = "Test requirement text"
        h1 = hashlib.sha256(text.encode()).hexdigest()
        h2 = hashlib.sha256(text.encode()).hexdigest()
        assert h1 == h2

    def test_config_defaults(self):
        """Defaults: unlimited controls (0), batch of 5, dedupe on, live run."""
        config = GeneratorConfig()
        assert config.max_controls == 0
        assert config.batch_size == 5
        assert config.skip_processed is True
        assert config.dry_run is False

    @pytest.mark.asyncio
    async def test_structure_free_use_produces_citation(self):
        """Rule 1 structuring includes source citation."""
        mock_db = MagicMock()
        pipeline = ControlGeneratorPipeline(db=mock_db)

        llm_response = json.dumps({
            "title": "DSFA Pflicht",
            "objective": "DSFA bei hohem Risiko durchfuehren.",
            "rationale": "Gesetzliche Pflicht nach DSGVO.",
            "requirements": ["DSFA durchfuehren"],
            "test_procedure": ["DSFA Bericht pruefen"],
            "evidence": ["DSFA Dokumentation"],
            "severity": "high",
            "tags": ["dsfa", "dsgvo"],
        })

        chunk = self._make_chunk(regulation_code="eu_2016_679", article="Art. 35")
        chunk.text = "Art. 35 DSGVO: Datenschutz-Folgenabschaetzung"
        chunk.regulation_name = "DSGVO"
        license_info = _classify_regulation("eu_2016_679")

        # Explicit AsyncMock for the awaited _llm_chat helper (see class docstring).
        with patch("compliance.services.control_generator._llm_chat", new_callable=AsyncMock, return_value=llm_response):
            control = await pipeline._structure_free_use(chunk, license_info)

        assert control.license_rule == 1
        assert control.source_original_text is not None
        assert control.source_citation is not None
        assert "DSGVO" in control.source_citation["source"]
        assert control.customer_visible is True
|
|
|
|
|
|
# =============================================================================
|
|
# JSON Array Parsing Tests (_parse_llm_json_array)
|
|
# =============================================================================
|
|
|
|
from compliance.services.control_generator import _parse_llm_json_array
|
|
|
|
|
|
class TestParseJsonArray:
    """Tests for _parse_llm_json_array — batch LLM response parsing."""

    def test_clean_json_array(self):
        """A well-formed JSON array should be returned directly."""
        payload = json.dumps([
            {"title": "Control A", "objective": "Obj A"},
            {"title": "Control B", "objective": "Obj B"},
        ])
        parsed = _parse_llm_json_array(payload)
        assert len(parsed) == 2
        assert parsed[0]["title"] == "Control A"
        assert parsed[1]["title"] == "Control B"

    def test_json_array_in_markdown_code_block(self):
        """JSON array wrapped in ```json ... ``` markdown fence."""
        body = json.dumps([
            {"title": "Fenced A", "chunk_index": 1},
            {"title": "Fenced B", "chunk_index": 2},
        ])
        payload = f"Here is the result:\n```json\n{body}\n```\nDone."
        parsed = _parse_llm_json_array(payload)
        assert len(parsed) == 2
        assert parsed[0]["title"] == "Fenced A"
        assert parsed[1]["title"] == "Fenced B"

    def test_markdown_code_block_without_json_tag(self):
        """Markdown fence without explicit 'json' language tag."""
        body = json.dumps([{"title": "NoTag", "objective": "test"}])
        parsed = _parse_llm_json_array(f"```\n{body}\n```")
        assert len(parsed) == 1
        assert parsed[0]["title"] == "NoTag"

    def test_wrapper_object_controls_key(self):
        """LLM wraps array in {"controls": [...]} — should unwrap."""
        payload = json.dumps({
            "controls": [
                {"title": "Wrapped A", "objective": "O1"},
                {"title": "Wrapped B", "objective": "O2"},
            ]
        })
        parsed = _parse_llm_json_array(payload)
        assert len(parsed) == 2
        assert parsed[0]["title"] == "Wrapped A"

    def test_wrapper_object_results_key(self):
        """LLM wraps array in {"results": [...]} — should unwrap."""
        payload = json.dumps({
            "results": [
                {"title": "R1"},
                {"title": "R2"},
                {"title": "R3"},
            ]
        })
        assert len(_parse_llm_json_array(payload)) == 3

    def test_wrapper_object_items_key(self):
        """LLM wraps array in {"items": [...]} — should unwrap."""
        payload = json.dumps({"items": [{"title": "Item1"}]})
        parsed = _parse_llm_json_array(payload)
        assert len(parsed) == 1
        assert parsed[0]["title"] == "Item1"

    def test_wrapper_object_data_key(self):
        """LLM wraps array in {"data": [...]} — should unwrap."""
        payload = json.dumps({"data": [{"title": "D1"}, {"title": "D2"}]})
        assert len(_parse_llm_json_array(payload)) == 2

    def test_single_dict_returned_as_list(self):
        """A single JSON object (no array) is wrapped in a list."""
        payload = json.dumps({"title": "SingleControl", "objective": "Obj"})
        parsed = _parse_llm_json_array(payload)
        assert len(parsed) == 1
        assert parsed[0]["title"] == "SingleControl"

    def test_individual_json_objects_fallback(self):
        """Multiple separate JSON objects (not in array) are collected."""
        payload = (
            'Here are the controls:\n'
            '{"title": "Ctrl1", "objective": "A"}\n'
            '{"title": "Ctrl2", "objective": "B"}\n'
        )
        parsed = _parse_llm_json_array(payload)
        assert len(parsed) == 2
        collected = {entry["title"] for entry in parsed}
        assert "Ctrl1" in collected
        assert "Ctrl2" in collected

    def test_individual_objects_require_title(self):
        """Fallback individual-object parsing only includes objects with 'title'."""
        payload = (
            '{"title": "HasTitle", "objective": "Yes"}\n'
            '{"no_title_field": "skip_me"}\n'
        )
        parsed = _parse_llm_json_array(payload)
        assert len(parsed) == 1
        assert parsed[0]["title"] == "HasTitle"

    def test_empty_string_returns_empty_list(self):
        """Empty input returns empty list."""
        assert _parse_llm_json_array("") == []

    def test_invalid_input_returns_empty_list(self):
        """Completely invalid input returns empty list."""
        assert _parse_llm_json_array("this is not json at all, no braces anywhere") == []

    def test_garbage_with_no_json_returns_empty(self):
        """Random non-JSON text should return empty list."""
        assert _parse_llm_json_array("Hier ist meine Antwort: leider kann ich das nicht.") == []

    def test_bracket_block_extraction(self):
        """Array embedded in preamble text is extracted via bracket matching."""
        payload = 'Some preamble text...\n[{"title": "Extracted", "objective": "X"}]\nSome trailing text.'
        parsed = _parse_llm_json_array(payload)
        assert len(parsed) == 1
        assert parsed[0]["title"] == "Extracted"

    def test_nested_objects_in_array(self):
        """Array elements with nested objects (like scope) are parsed correctly."""
        payload = json.dumps([
            {
                "title": "Nested",
                "objective": "Test",
                "scope": {"regions": ["EU", "DE"]},
                "requirements": ["Req1"],
            }
        ])
        parsed = _parse_llm_json_array(payload)
        assert len(parsed) == 1
        assert parsed[0]["scope"]["regions"] == ["EU", "DE"]
|
|
|
|
|
|
# =============================================================================
|
|
# Batch Size Configuration Tests
|
|
# =============================================================================
|
|
|
|
class TestBatchSizeConfig:
    """Tests for batch_size configuration on GeneratorConfig."""

    def test_default_batch_size(self):
        assert GeneratorConfig().batch_size == 5

    def test_custom_batch_size(self):
        assert GeneratorConfig(batch_size=10).batch_size == 10

    def test_batch_size_of_one(self):
        assert GeneratorConfig(batch_size=1).batch_size == 1

    def test_batch_size_used_in_pipeline_constant(self):
        """Verify that pipeline uses config.batch_size (not a hardcoded value)."""
        cfg = GeneratorConfig(batch_size=3)
        # Mirrors the pipeline expression: BATCH_SIZE = config.batch_size or 5
        effective = cfg.batch_size or 5
        assert effective == 3

    def test_batch_size_zero_falls_back_to_five(self):
        """batch_size=0 triggers `or 5` fallback in the pipeline loop."""
        cfg = GeneratorConfig(batch_size=0)
        # 0 is falsy, so the `or` fallback must engage.
        effective = cfg.batch_size or 5
        assert effective == 5
|
|
|
|
|
|
# =============================================================================
|
|
# Batch Processing Loop Tests (Mocked)
|
|
# =============================================================================
|
|
|
|
class TestBatchProcessingLoop:
|
|
"""Tests for _process_batch, _structure_batch, _reformulate_batch with mocked LLM."""
|
|
|
|
def _make_chunk(self, regulation_code="owasp_asvs", article="V2.1.1", text="Test requirement"):
|
|
return RAGSearchResult(
|
|
text=text,
|
|
regulation_code=regulation_code,
|
|
regulation_name="OWASP ASVS" if "owasp" in regulation_code else "Test Source",
|
|
regulation_short="OWASP" if "owasp" in regulation_code else "TEST",
|
|
category="requirement",
|
|
article=article,
|
|
paragraph="",
|
|
source_url="https://example.com",
|
|
score=0.9,
|
|
)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_process_batch_splits_by_license_rule(self):
|
|
"""_process_batch routes Rule 1+2 to _structure_batch and Rule 3 to _reformulate_batch."""
|
|
mock_db = MagicMock()
|
|
pipeline = ControlGeneratorPipeline(db=mock_db)
|
|
pipeline._existing_controls = []
|
|
|
|
chunk_r1 = self._make_chunk("eu_2016_679", "Art. 35", "DSGVO text")
|
|
chunk_r3 = self._make_chunk("bsi_tr03161", "O.Auth_1", "BSI text")
|
|
|
|
batch_items = [
|
|
(chunk_r1, {"rule": 1, "name": "DSGVO", "license": "EU_LAW"}),
|
|
(chunk_r3, {"rule": 3, "name": "INTERNAL_ONLY"}),
|
|
]
|
|
|
|
# Mock _structure_batch and _reformulate_batch
|
|
structure_ctrl = GeneratedControl(title="Structured", license_rule=1, release_state="draft")
|
|
reform_ctrl = GeneratedControl(title="Reformed", license_rule=3, release_state="draft")
|
|
|
|
mock_finder_instance = AsyncMock()
|
|
mock_finder_instance.find_anchors = AsyncMock(return_value=[])
|
|
mock_finder_cls = MagicMock(return_value=mock_finder_instance)
|
|
|
|
with patch.object(pipeline, "_structure_batch", new_callable=AsyncMock, return_value=[structure_ctrl]) as mock_struct, \
|
|
patch.object(pipeline, "_reformulate_batch", new_callable=AsyncMock, return_value=[reform_ctrl]) as mock_reform, \
|
|
patch.object(pipeline, "_check_harmonization", new_callable=AsyncMock, return_value=[]), \
|
|
patch("compliance.services.anchor_finder.AnchorFinder", mock_finder_cls), \
|
|
patch("compliance.services.control_generator.check_similarity", new_callable=AsyncMock) as mock_sim:
|
|
mock_sim.return_value = MagicMock(status="PASS", token_overlap=0.1, ngram_jaccard=0.1, lcs_ratio=0.1)
|
|
config = GeneratorConfig(batch_size=5)
|
|
result = await pipeline._process_batch(batch_items, config, "test-job-123")
|
|
|
|
mock_struct.assert_called_once()
|
|
mock_reform.assert_called_once()
|
|
# structure_batch received only the Rule 1 chunk
|
|
assert len(mock_struct.call_args[0][0]) == 1
|
|
# reformulate_batch received only the Rule 3 chunk
|
|
assert len(mock_reform.call_args[0][0]) == 1
|
|
|
|
    @pytest.mark.asyncio
    async def test_structure_batch_calls_llm_and_parses(self):
        """_structure_batch sends prompt to LLM and parses array response."""
        mock_db = MagicMock()
        pipeline = ControlGeneratorPipeline(db=mock_db)

        # Two Rule-1 (DSGVO) chunks; the LLM is expected to answer with one
        # structured control per chunk, tagged with a 1-based chunk_index.
        chunks = [
            self._make_chunk("eu_2016_679", "Art. 5", "Datensparsamkeit und Zweckbindung"),
            self._make_chunk("eu_2016_679", "Art. 35", "DSFA bei hohem Risiko"),
        ]
        license_infos = [
            {"rule": 1, "name": "DSGVO", "license": "EU_LAW"},
            {"rule": 1, "name": "DSGVO", "license": "EU_LAW"},
        ]

        # Canned batch response: a JSON array with one object per chunk.
        llm_response = json.dumps([
            {
                "chunk_index": 1,
                "title": "Datensparsamkeit",
                "objective": "Nur notwendige Daten erheben.",
                "rationale": "DSGVO Grundprinzip.",
                "requirements": ["Datenminimierung"],
                "test_procedure": ["Datenbestand pruefen"],
                "evidence": ["Verarbeitungsverzeichnis"],
                "severity": "high",
                "tags": ["dsgvo", "datenschutz"],
            },
            {
                "chunk_index": 2,
                "title": "DSFA Pflicht",
                "objective": "DSFA bei hohem Risiko durchfuehren.",
                "rationale": "Gesetzliche Pflicht.",
                "requirements": ["DSFA erstellen"],
                "test_procedure": ["DSFA Bericht pruefen"],
                "evidence": ["DSFA Dokumentation"],
                "severity": "high",
                "tags": ["dsfa"],
            },
        ])

        with patch("compliance.services.control_generator._llm_chat", new_callable=AsyncMock, return_value=llm_response):
            controls = await pipeline._structure_batch(chunks, license_infos)

        # One control per input chunk, in chunk order.
        assert len(controls) == 2
        assert controls[0] is not None
        assert controls[0].title == "Datensparsamkeit"
        assert controls[0].license_rule == 1
        # Rule 1 keeps the original source text and is customer visible.
        assert controls[0].source_original_text is not None
        assert controls[0].customer_visible is True
        # Batch path is recorded in the generation metadata.
        assert controls[0].generation_metadata["processing_path"] == "structured_batch"
        assert controls[0].generation_metadata["batch_size"] == 2

        assert controls[1] is not None
        assert controls[1].title == "DSFA Pflicht"
        assert controls[1].license_rule == 1
|
|
|
|
    @pytest.mark.asyncio
    async def test_reformulate_batch_calls_llm_and_strips_source(self):
        """_reformulate_batch produces Rule 3 controls without source info."""
        mock_db = MagicMock()
        pipeline = ControlGeneratorPipeline(db=mock_db)

        # A single restricted (Rule 3, BSI) chunk.
        chunks = [
            self._make_chunk("bsi_tr03161", "O.Auth_1", "Multi-factor authentication for apps"),
        ]
        config = GeneratorConfig(batch_size=5)

        # Canned LLM batch response for the one chunk.
        llm_response = json.dumps([
            {
                "chunk_index": 1,
                "title": "Starke Authentifizierung",
                "objective": "Mehrstufige Anmeldung implementieren.",
                "rationale": "Schutz vor unbefugtem Zugriff.",
                "requirements": ["MFA einrichten"],
                "test_procedure": ["MFA Funktionstest"],
                "evidence": ["MFA Konfiguration"],
                "severity": "critical",
                "tags": ["auth", "mfa"],
            }
        ])

        with patch("compliance.services.control_generator._llm_chat", new_callable=AsyncMock, return_value=llm_response):
            controls = await pipeline._reformulate_batch(chunks, config)

        assert len(controls) == 1
        ctrl = controls[0]
        assert ctrl is not None
        assert ctrl.title == "Starke Authentifizierung"
        # Rule 3: everything that could identify the source is stripped.
        assert ctrl.license_rule == 3
        assert ctrl.source_original_text is None
        assert ctrl.source_citation is None
        assert ctrl.customer_visible is False
        assert ctrl.generation_metadata["processing_path"] == "llm_reform_batch"
        # Must not contain BSI references
        metadata_str = json.dumps(ctrl.generation_metadata)
        assert "bsi" not in metadata_str.lower()
        assert "TR-03161" not in metadata_str
|
|
|
|
    @pytest.mark.asyncio
    async def test_structure_batch_maps_by_chunk_index(self):
        """Controls are mapped back to the correct chunk via chunk_index."""
        mock_db = MagicMock()
        pipeline = ControlGeneratorPipeline(db=mock_db)

        chunks = [
            self._make_chunk("eu_2016_679", "Art. 5", "First chunk"),
            self._make_chunk("eu_2016_679", "Art. 6", "Second chunk"),
            self._make_chunk("eu_2016_679", "Art. 7", "Third chunk"),
        ]
        license_infos = [{"rule": 1, "name": "DSGVO", "license": "EU_LAW"}] * 3

        # LLM returns them in reversed order
        # (chunk_index 3, 1, 2 — the mapper must re-sort them by index).
        llm_response = json.dumps([
            {
                "chunk_index": 3,
                "title": "Third Control",
                "objective": "Obj3",
                "rationale": "Rat3",
                "requirements": [],
                "test_procedure": [],
                "evidence": [],
                "severity": "low",
                "tags": [],
            },
            {
                "chunk_index": 1,
                "title": "First Control",
                "objective": "Obj1",
                "rationale": "Rat1",
                "requirements": [],
                "test_procedure": [],
                "evidence": [],
                "severity": "high",
                "tags": [],
            },
            {
                "chunk_index": 2,
                "title": "Second Control",
                "objective": "Obj2",
                "rationale": "Rat2",
                "requirements": [],
                "test_procedure": [],
                "evidence": [],
                "severity": "medium",
                "tags": [],
            },
        ])

        with patch("compliance.services.control_generator._llm_chat", new_callable=AsyncMock, return_value=llm_response):
            controls = await pipeline._structure_batch(chunks, license_infos)

        assert len(controls) == 3
        # Verify correct mapping despite shuffled chunk_index
        assert controls[0].title == "First Control"
        assert controls[1].title == "Second Control"
        assert controls[2].title == "Third Control"
|
|
|
|
    @pytest.mark.asyncio
    async def test_structure_batch_falls_back_to_position(self):
        """If chunk_index is missing, controls are assigned by position."""
        mock_db = MagicMock()
        pipeline = ControlGeneratorPipeline(db=mock_db)

        chunks = [
            self._make_chunk("eu_2016_679", "Art. 5", "Chunk A"),
            self._make_chunk("eu_2016_679", "Art. 6", "Chunk B"),
        ]
        license_infos = [{"rule": 1, "name": "DSGVO", "license": "EU_LAW"}] * 2

        # No chunk_index in response
        # — positional mapping (array order == chunk order) must kick in.
        llm_response = json.dumps([
            {
                "title": "PositionA",
                "objective": "ObjA",
                "rationale": "RatA",
                "requirements": [],
                "test_procedure": [],
                "evidence": [],
                "severity": "medium",
                "tags": [],
            },
            {
                "title": "PositionB",
                "objective": "ObjB",
                "rationale": "RatB",
                "requirements": [],
                "test_procedure": [],
                "evidence": [],
                "severity": "medium",
                "tags": [],
            },
        ])

        with patch("compliance.services.control_generator._llm_chat", new_callable=AsyncMock, return_value=llm_response):
            controls = await pipeline._structure_batch(chunks, license_infos)

        assert len(controls) == 2
        assert controls[0].title == "PositionA"
        assert controls[1].title == "PositionB"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_process_batch_only_structure_no_reform(self):
|
|
"""Batch with only Rule 1+2 items skips _reformulate_batch."""
|
|
mock_db = MagicMock()
|
|
pipeline = ControlGeneratorPipeline(db=mock_db)
|
|
pipeline._existing_controls = []
|
|
|
|
chunk = self._make_chunk("eu_2016_679", "Art. 5", "DSGVO text")
|
|
batch_items = [
|
|
(chunk, {"rule": 1, "name": "DSGVO", "license": "EU_LAW"}),
|
|
]
|
|
|
|
ctrl = GeneratedControl(title="OnlyStructure", license_rule=1, release_state="draft")
|
|
|
|
mock_finder_instance = AsyncMock()
|
|
mock_finder_instance.find_anchors = AsyncMock(return_value=[])
|
|
mock_finder_cls = MagicMock(return_value=mock_finder_instance)
|
|
|
|
with patch.object(pipeline, "_structure_batch", new_callable=AsyncMock, return_value=[ctrl]) as mock_struct, \
|
|
patch.object(pipeline, "_reformulate_batch", new_callable=AsyncMock) as mock_reform, \
|
|
patch.object(pipeline, "_check_harmonization", new_callable=AsyncMock, return_value=[]), \
|
|
patch("compliance.services.anchor_finder.AnchorFinder", mock_finder_cls):
|
|
config = GeneratorConfig()
|
|
result, qa_count = await pipeline._process_batch(batch_items, config, "job-1")
|
|
|
|
mock_struct.assert_called_once()
|
|
mock_reform.assert_not_called()
|
|
assert len(result) == 1
|
|
assert result[0].title == "OnlyStructure"
|
|
|
|
    @pytest.mark.asyncio
    async def test_process_batch_only_reform_no_structure(self):
        """Batch with only Rule 3 items skips _structure_batch.

        Rule 3 chunks take the reformulation path and additionally run the
        similarity QA gate (check_similarity), mocked here to PASS.
        """
        mock_db = MagicMock()
        pipeline = ControlGeneratorPipeline(db=mock_db)
        pipeline._existing_controls = []

        chunk = self._make_chunk("bsi_tr03161", "O.Auth_1", "BSI text")
        batch_items = [
            (chunk, {"rule": 3, "name": "INTERNAL_ONLY"}),
        ]

        ctrl = GeneratedControl(title="OnlyReform", license_rule=3, release_state="draft")

        mock_finder_instance = AsyncMock()
        mock_finder_instance.find_anchors = AsyncMock(return_value=[])
        mock_finder_cls = MagicMock(return_value=mock_finder_instance)

        with patch.object(pipeline, "_structure_batch", new_callable=AsyncMock) as mock_struct, \
                patch.object(pipeline, "_reformulate_batch", new_callable=AsyncMock, return_value=[ctrl]) as mock_reform, \
                patch.object(pipeline, "_check_harmonization", new_callable=AsyncMock, return_value=[]), \
                patch("compliance.services.anchor_finder.AnchorFinder", mock_finder_cls), \
                patch("compliance.services.control_generator.check_similarity", new_callable=AsyncMock) as mock_sim:
            # Low similarity scores -> the control clears the too-close gate.
            mock_sim.return_value = MagicMock(status="PASS", token_overlap=0.1, ngram_jaccard=0.1, lcs_ratio=0.1)
            config = GeneratorConfig()
            result, qa_count = await pipeline._process_batch(batch_items, config, "job-2")

        mock_struct.assert_not_called()
        mock_reform.assert_called_once()
        assert len(result) == 1
        assert result[0].title == "OnlyReform"
|
|
|
|
    @pytest.mark.asyncio
    async def test_process_batch_mixed_rules(self):
        """Batch with mixed Rule 1 and Rule 3 items calls both sub-methods.

        The batch is split by license rule: Rule 1/2 chunks go to
        _structure_batch, Rule 3 chunks go to _reformulate_batch, and the
        combined result contains the controls from both paths.
        """
        mock_db = MagicMock()
        pipeline = ControlGeneratorPipeline(db=mock_db)
        pipeline._existing_controls = []

        chunk_r1 = self._make_chunk("eu_2016_679", "Art. 5", "DSGVO")
        chunk_r2 = self._make_chunk("owasp_asvs", "V2.1", "OWASP")
        chunk_r3a = self._make_chunk("bsi_tr03161", "O.Auth_1", "BSI A")
        chunk_r3b = self._make_chunk("iso_27001", "A.9.1", "ISO B")

        # Rules deliberately interleaved so the split logic is exercised,
        # not just contiguous-ordering behavior.
        batch_items = [
            (chunk_r1, {"rule": 1, "name": "DSGVO", "license": "EU_LAW"}),
            (chunk_r3a, {"rule": 3, "name": "INTERNAL_ONLY"}),
            (chunk_r2, {"rule": 2, "name": "OWASP ASVS", "license": "CC-BY-SA-4.0", "attribution": "OWASP Foundation"}),
            (chunk_r3b, {"rule": 3, "name": "INTERNAL_ONLY"}),
        ]

        struct_ctrls = [
            GeneratedControl(title="DSGVO Ctrl", license_rule=1, release_state="draft"),
            GeneratedControl(title="OWASP Ctrl", license_rule=2, release_state="draft"),
        ]
        reform_ctrls = [
            GeneratedControl(title="BSI Ctrl", license_rule=3, release_state="draft"),
            GeneratedControl(title="ISO Ctrl", license_rule=3, release_state="draft"),
        ]

        mock_finder_instance = AsyncMock()
        mock_finder_instance.find_anchors = AsyncMock(return_value=[])
        mock_finder_cls = MagicMock(return_value=mock_finder_instance)

        with patch.object(pipeline, "_structure_batch", new_callable=AsyncMock, return_value=struct_ctrls) as mock_struct, \
                patch.object(pipeline, "_reformulate_batch", new_callable=AsyncMock, return_value=reform_ctrls) as mock_reform, \
                patch.object(pipeline, "_check_harmonization", new_callable=AsyncMock, return_value=[]), \
                patch("compliance.services.anchor_finder.AnchorFinder", mock_finder_cls), \
                patch("compliance.services.control_generator.check_similarity", new_callable=AsyncMock) as mock_sim:
            mock_sim.return_value = MagicMock(status="PASS", token_overlap=0.05, ngram_jaccard=0.05, lcs_ratio=0.05)
            config = GeneratorConfig()
            result, qa_count = await pipeline._process_batch(batch_items, config, "job-mixed")

        # Both methods called
        mock_struct.assert_called_once()
        mock_reform.assert_called_once()
        # structure_batch gets 2 items (Rule 1 + Rule 2)
        assert len(mock_struct.call_args[0][0]) == 2
        # reformulate_batch gets 2 items (Rule 3 + Rule 3)
        assert len(mock_reform.call_args[0][0]) == 2
        # Result has 4 controls total
        assert len(result) == 4
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_process_batch_empty_batch(self):
|
|
"""Empty batch returns empty list."""
|
|
mock_db = MagicMock()
|
|
pipeline = ControlGeneratorPipeline(db=mock_db)
|
|
pipeline._existing_controls = []
|
|
|
|
config = GeneratorConfig()
|
|
result, qa_count = await pipeline._process_batch([], config, "job-empty")
|
|
assert result == []
|
|
assert qa_count == 0
|
|
|
|
    @pytest.mark.asyncio
    async def test_reformulate_batch_too_close_flagged(self):
        """Rule 3 controls that are too similar to source get flagged.

        A FAIL similarity report must demote the control to
        release_state='too_close' and record the status in its generation
        metadata instead of dropping the control entirely.
        """
        mock_db = MagicMock()
        pipeline = ControlGeneratorPipeline(db=mock_db)
        pipeline._existing_controls = []

        chunk = self._make_chunk("bsi_tr03161", "O.Auth_1", "Authentication must use MFA")
        batch_items = [
            (chunk, {"rule": 3, "name": "INTERNAL_ONLY"}),
        ]

        # Objective intentionally repeats the source wording verbatim.
        ctrl = GeneratedControl(
            title="Auth MFA",
            objective="Authentication must use MFA",
            rationale="Security",
            license_rule=3,
            release_state="draft",
            generation_metadata={},
        )

        # Simulate similarity FAIL (too close to source)
        fail_report = MagicMock(status="FAIL", token_overlap=0.85, ngram_jaccard=0.9, lcs_ratio=0.88)

        mock_finder_instance = AsyncMock()
        mock_finder_instance.find_anchors = AsyncMock(return_value=[])
        mock_finder_cls = MagicMock(return_value=mock_finder_instance)

        with patch.object(pipeline, "_structure_batch", new_callable=AsyncMock), \
                patch.object(pipeline, "_reformulate_batch", new_callable=AsyncMock, return_value=[ctrl]), \
                patch.object(pipeline, "_check_harmonization", new_callable=AsyncMock, return_value=[]), \
                patch("compliance.services.anchor_finder.AnchorFinder", mock_finder_cls), \
                patch("compliance.services.control_generator.check_similarity", new_callable=AsyncMock, return_value=fail_report):
            config = GeneratorConfig()
            result, qa_count = await pipeline._process_batch(batch_items, config, "job-tooclose")

        assert len(result) == 1
        assert result[0].release_state == "too_close"
        assert result[0].generation_metadata["similarity_status"] == "FAIL"
|
|
|
|
|
|
# =============================================================================
|
|
# Regulation Filter Tests
|
|
# =============================================================================
|
|
|
|
class TestRegulationFilter:
    """Tests for regulation_filter in GeneratorConfig.

    The async tests stub the Qdrant scroll endpoint (httpx.AsyncClient) with
    a fixed point set, so only the client-side filtering in _scan_rag is
    exercised — no network or database access happens.
    """

    @staticmethod
    def _make_db_mock():
        """DB stub whose execute() result iterates as an empty row set.

        Note: an earlier revision also configured
        ``mock_db.execute.return_value.fetchall.return_value = []`` and then
        immediately replaced ``execute.return_value`` — discarding that
        setup. Only the ``__iter__`` configuration is actually used.
        """
        mock_db = MagicMock()
        mock_db.execute.return_value = MagicMock()
        mock_db.execute.return_value.__iter__ = MagicMock(return_value=iter([]))
        return mock_db

    def test_config_accepts_regulation_filter(self):
        config = GeneratorConfig(regulation_filter=["owasp_", "nist_", "eu_2023_1230"])
        assert config.regulation_filter == ["owasp_", "nist_", "eu_2023_1230"]

    def test_config_default_none(self):
        config = GeneratorConfig()
        assert config.regulation_filter is None

    @pytest.mark.asyncio
    async def test_scan_rag_filters_by_regulation(self):
        """Verify _scan_rag skips chunks not matching regulation_filter."""
        mock_db = self._make_db_mock()

        # Mock Qdrant scroll response with mixed regulation_codes
        qdrant_points = {
            "result": {
                "points": [
                    {"id": "1", "payload": {
                        "chunk_text": "OWASP ASVS requirement for input validation " * 5,
                        "regulation_code": "owasp_asvs",
                        "regulation_name": "OWASP ASVS",
                    }},
                    {"id": "2", "payload": {
                        "chunk_text": "AML anti-money laundering requirement for banks " * 5,
                        "regulation_code": "amlr",
                        "regulation_name": "AML-Verordnung",
                    }},
                    {"id": "3", "payload": {
                        "chunk_text": "NIST secure software development framework req " * 5,
                        "regulation_code": "nist_sp_800_218",
                        "regulation_name": "NIST SSDF",
                    }},
                ],
                "next_page_offset": None,
            }
        }

        with patch("compliance.services.control_generator.httpx.AsyncClient") as mock_client_cls:
            mock_client = AsyncMock()
            mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)

            mock_resp = MagicMock()
            mock_resp.status_code = 200
            mock_resp.json.return_value = qdrant_points
            mock_client.post.return_value = mock_resp

            pipeline = ControlGeneratorPipeline(db=mock_db, rag_client=MagicMock())

            # With filter: only owasp_ and nist_ prefixes
            config = GeneratorConfig(
                collections=["bp_compliance_ce"],
                regulation_filter=["owasp_", "nist_"],
            )
            results = await pipeline._scan_rag(config)

        # Should only get 2 chunks (owasp + nist), not amlr
        assert len(results) == 2
        codes = {r.regulation_code for r in results}
        assert "owasp_asvs" in codes
        assert "nist_sp_800_218" in codes
        assert "amlr" not in codes

    @pytest.mark.asyncio
    async def test_scan_rag_filters_out_empty_regulation_code(self):
        """Chunks without regulation_code must be skipped when filter is active."""
        mock_db = self._make_db_mock()

        qdrant_points = {
            "result": {
                "points": [
                    {"id": "1", "payload": {
                        "chunk_text": "OWASP ASVS requirement for input validation " * 5,
                        "regulation_code": "owasp_asvs",
                    }},
                    {"id": "2", "payload": {
                        "chunk_text": "Some template without regulation code at all " * 5,
                        # No regulation_id, regulation_code, source_id, or source_code
                    }},
                    {"id": "3", "payload": {
                        "chunk_text": "Another chunk with empty regulation code value " * 5,
                        "regulation_code": "",
                    }},
                ],
                "next_page_offset": None,
            }
        }

        with patch("compliance.services.control_generator.httpx.AsyncClient") as mock_client_cls:
            mock_client = AsyncMock()
            mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)

            mock_resp = MagicMock()
            mock_resp.status_code = 200
            mock_resp.json.return_value = qdrant_points
            mock_client.post.return_value = mock_resp

            pipeline = ControlGeneratorPipeline(db=mock_db, rag_client=MagicMock())
            config = GeneratorConfig(
                collections=["bp_compliance_ce"],
                regulation_filter=["owasp_"],
            )
            results = await pipeline._scan_rag(config)

        # Only the owasp chunk should pass — empty reg_code chunks are filtered out
        assert len(results) == 1
        assert results[0].regulation_code == "owasp_asvs"

    @pytest.mark.asyncio
    async def test_scan_rag_no_filter_returns_all(self):
        """Verify _scan_rag returns all chunks when no regulation_filter."""
        mock_db = self._make_db_mock()

        qdrant_points = {
            "result": {
                "points": [
                    {"id": "1", "payload": {
                        "chunk_text": "OWASP requirement for secure authentication " * 5,
                        "regulation_code": "owasp_asvs",
                    }},
                    {"id": "2", "payload": {
                        "chunk_text": "AML compliance requirement for financial inst " * 5,
                        "regulation_code": "amlr",
                    }},
                ],
                "next_page_offset": None,
            }
        }

        with patch("compliance.services.control_generator.httpx.AsyncClient") as mock_client_cls:
            mock_client = AsyncMock()
            mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)

            mock_resp = MagicMock()
            mock_resp.status_code = 200
            mock_resp.json.return_value = qdrant_points
            mock_client.post.return_value = mock_resp

            pipeline = ControlGeneratorPipeline(db=mock_db, rag_client=MagicMock())
            config = GeneratorConfig(
                collections=["bp_compliance_ce"],
                regulation_filter=None,
            )
            results = await pipeline._scan_rag(config)

        assert len(results) == 2
|
|
|
|
|
|
# =============================================================================
|
|
# Pipeline Version Tests
|
|
# =============================================================================
|
|
|
|
class TestPipelineVersion:
    """Tests for pipeline_version propagation in DB writes and null handling."""

    def test_pipeline_version_constant_is_3(self):
        # Pin the module constant so bumps are a deliberate, reviewed change.
        assert PIPELINE_VERSION == 3

    def test_store_control_includes_pipeline_version(self):
        """_store_control must pass pipeline_version=PIPELINE_VERSION to the INSERT."""
        mock_db = MagicMock()
        # Framework lookup returns a UUID
        fw_row = MagicMock()
        fw_row.__getitem__ = lambda self, idx: "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
        mock_db.execute.return_value.fetchone.return_value = fw_row

        pipeline = ControlGeneratorPipeline(db=mock_db, rag_client=MagicMock())

        control = GeneratedControl(
            control_id="SEC-TEST-001",
            title="Test Control",
            objective="Test objective",
        )
        pipeline._store_control(control, job_id="00000000-0000-0000-0000-000000000001")

        # The second call to db.execute is the INSERT
        calls = mock_db.execute.call_args_list
        assert len(calls) >= 2, f"Expected at least 2 db.execute calls, got {len(calls)}"
        insert_call = calls[1]
        params = insert_call[0][1]  # positional arg 1 = params dict
        assert "pipeline_version" in params
        assert params["pipeline_version"] == PIPELINE_VERSION

    def test_mark_chunk_processed_includes_pipeline_version(self):
        """_mark_chunk_processed must pass pipeline_version=PIPELINE_VERSION to the INSERT."""
        mock_db = MagicMock()
        pipeline = ControlGeneratorPipeline(db=mock_db, rag_client=MagicMock())

        # Minimal chunk stub carrying only the attributes the method reads.
        chunk = MagicMock()
        chunk.text = "Some chunk text for hashing"
        chunk.collection = "bp_compliance_ce"
        chunk.regulation_code = "eu_2016_679"

        license_info = {"license": "CC0-1.0", "rule": 1}

        pipeline._mark_chunk_processed(
            chunk=chunk,
            license_info=license_info,
            processing_path="structured_batch",
            control_ids=["SEC-TEST-001"],
            job_id="00000000-0000-0000-0000-000000000001",
        )

        calls = mock_db.execute.call_args_list
        assert len(calls) >= 1
        insert_call = calls[0]
        params = insert_call[0][1]
        assert "pipeline_version" in params
        assert params["pipeline_version"] == PIPELINE_VERSION

    @pytest.mark.asyncio
    async def test_structure_batch_handles_null_results(self):
        """When _parse_llm_json_array returns [dict, None, dict], the null entries produce None."""
        mock_db = MagicMock()
        pipeline = ControlGeneratorPipeline(db=mock_db, rag_client=MagicMock())

        # Three chunks
        chunks = []
        license_infos = []
        for i in range(3):
            c = MagicMock()
            c.text = f"Chunk text number {i} with enough content for processing"
            c.regulation_name = "DSGVO"
            c.regulation_code = "eu_2016_679"
            c.article = f"Art. {i + 1}"
            c.paragraph = ""
            c.source_url = ""
            c.collection = "bp_compliance_ce"
            chunks.append(c)
            license_infos.append({"rule": 1, "name": "DSGVO", "license": "CC0-1.0"})

        # LLM returns a JSON array: valid, null, valid
        llm_response = json.dumps([
            {
                "chunk_index": 1,
                "title": "Datenschutz-Kontrolle 1",
                "objective": "Schutz personenbezogener Daten",
                "rationale": "DSGVO-Konformitaet",
                "requirements": ["Req 1"],
                "test_procedure": ["Test 1"],
                "evidence": ["Nachweis 1"],
                "severity": "high",
                "tags": ["dsgvo"],
                "domain": "DATA",
                "category": "datenschutz",
                "target_audience": ["unternehmen"],
                "source_article": "Art. 1",
                "source_paragraph": "",
            },
            None,
            {
                "chunk_index": 3,
                "title": "Datenschutz-Kontrolle 3",
                "objective": "Transparenzpflicht",
                "rationale": "Information der Betroffenen",
                "requirements": ["Req 3"],
                "test_procedure": ["Test 3"],
                "evidence": ["Nachweis 3"],
                "severity": "medium",
                "tags": ["transparenz"],
                "domain": "DATA",
                "category": "datenschutz",
                "target_audience": ["unternehmen"],
                "source_article": "Art. 3",
                "source_paragraph": "",
            },
        ])

        with patch("compliance.services.control_generator._llm_chat", new_callable=AsyncMock) as mock_llm:
            mock_llm.return_value = llm_response
            controls = await pipeline._structure_batch(chunks, license_infos)

        # Positions must be preserved so chunks stay aligned with results.
        assert len(controls) == 3
        assert controls[0] is not None
        assert controls[1] is None  # Null entry from LLM
        assert controls[2] is not None

    @pytest.mark.asyncio
    async def test_reformulate_batch_handles_null_results(self):
        """When _parse_llm_json_array returns [dict, None, dict], the null entries produce None."""
        mock_db = MagicMock()
        pipeline = ControlGeneratorPipeline(db=mock_db, rag_client=MagicMock())

        chunks = []
        for i in range(3):
            c = MagicMock()
            c.text = f"Restricted chunk text number {i} with BSI content"
            c.regulation_name = "BSI TR-03161"
            c.regulation_code = "bsi_tr03161"
            c.article = f"Section {i + 1}"
            c.paragraph = ""
            c.source_url = ""
            c.collection = "bp_compliance_ce"
            chunks.append(c)

        config = GeneratorConfig(domain="SEC")

        llm_response = json.dumps([
            {
                "chunk_index": 1,
                "title": "Sicherheitskontrolle 1",
                "objective": "Authentifizierung absichern",
                "rationale": "Best Practice",
                "requirements": ["Req 1"],
                "test_procedure": ["Test 1"],
                "evidence": ["Nachweis 1"],
                "severity": "high",
                "tags": ["sicherheit"],
                "domain": "SEC",
                "category": "it-sicherheit",
                "target_audience": ["it-abteilung"],
            },
            None,
            {
                "chunk_index": 3,
                "title": "Sicherheitskontrolle 3",
                "objective": "Netzwerk segmentieren",
                "rationale": "Angriffsoberflaeche reduzieren",
                "requirements": ["Req 3"],
                "test_procedure": ["Test 3"],
                "evidence": ["Nachweis 3"],
                "severity": "medium",
                "tags": ["netzwerk"],
                "domain": "NET",
                "category": "netzwerksicherheit",
                "target_audience": ["it-abteilung"],
            },
        ])

        with patch("compliance.services.control_generator._llm_chat", new_callable=AsyncMock) as mock_llm:
            mock_llm.return_value = llm_response
            controls = await pipeline._reformulate_batch(chunks, config)

        assert len(controls) == 3
        assert controls[0] is not None
        assert controls[1] is None  # Null entry from LLM
        assert controls[2] is not None
|
|
|
|
|
|
# =============================================================================
|
|
# Recital (Erwägungsgrund) Detection Tests
|
|
# =============================================================================
|
|
|
|
class TestRecitalDetection:
    """Tests for _detect_recital — identifying Erwägungsgrund text in source."""

    def test_recital_number_detected(self):
        """A '(126)\\n' marker must flag the text as a recital suspect."""
        sample = "Daher ist es wichtig...\n(126)\nDie Konformitätsbewertung sollte..."
        detection = _detect_recital(sample)
        assert detection is not None
        assert detection["recital_suspect"] is True
        assert "126" in detection["recital_numbers"]

    def test_multiple_recital_numbers(self):
        """Every recital marker in the text must be captured."""
        detection = _detect_recital("(124)\nErster Punkt.\n(125)\nZweiter Punkt.\n(126)\nDritter Punkt.")
        assert detection is not None
        for number in ("124", "125", "126"):
            assert number in detection["recital_numbers"]

    def test_article_text_not_flagged(self):
        """Plain article text without recital markers yields None."""
        sample = ("Der Anbieter eines Hochrisiko-KI-Systems muss sicherstellen, "
                  "dass die technische Dokumentation erstellt wird.")
        assert _detect_recital(sample) is None

    def test_empty_text_returns_none(self):
        assert _detect_recital("") is None

    def test_none_text_returns_none(self):
        assert _detect_recital(None) is None

    def test_recital_phrases_detected(self):
        """Several recital-typical phrases together trigger phrase-based detection."""
        sample = ("In Erwägung nachstehender Gründe wurde beschlossen, "
                  "daher sollte der Anbieter folgende Maßnahmen ergreifen. "
                  "Es ist daher notwendig, die Konformität sicherzustellen.")
        detection = _detect_recital(sample)
        assert detection is not None
        assert detection["detection_method"] == "phrases"

    def test_single_phrase_not_enough(self):
        """One recital phrase on its own must not trigger detection."""
        assert _detect_recital("Daher sollte das System regelmäßig geprüft werden.") is None

    def test_combined_regex_and_phrases(self):
        """Marker plus phrases -> detection_method reports 'regex+phrases'."""
        detection = _detect_recital("(42)\nIn Erwägung nachstehender Gründe wurde entschieden...")
        assert detection is not None
        assert detection["detection_method"] == "regex+phrases"
        assert "42" in detection["recital_numbers"]

    def test_parenthesized_number_without_newline_ignored(self):
        """Parenthesized numbers lacking a trailing newline are not markers.

        e.g. 'gemäß Absatz (3) des Artikels' must not be flagged."""
        assert _detect_recital("Gemäß Absatz (3) des Artikels 52 muss der Anbieter sicherstellen...") is None

    def test_real_world_recital_text(self):
        """AI Act Erwägungsgrund (126)/(127) — a realistic multi-recital sample."""
        sample = (
            "(126)\n"
            "Um den Verwaltungsaufwand zu verringern und die Konformitätsbewertung "
            "zu vereinfachen, sollten bestimmte Hochrisiko-KI-Systeme, die von "
            "Anbietern zertifiziert oder für die eine Konformitätserklärung "
            "ausgestellt wurde, automatisch als konform mit den Anforderungen "
            "dieser Verordnung gelten, sofern sie den harmonisierten Normen oder "
            "gemeinsamen Spezifikationen entsprechen.\n"
            "(127)\n"
            "Es ist daher angezeigt, dass der Anbieter das entsprechende "
            "Konformitätsbewertungsverfahren anwendet."
        )
        detection = _detect_recital(sample)
        assert detection is not None
        for number in ("126", "127"):
            assert number in detection["recital_numbers"]
|
|
|
|
|
|
# =============================================================================
|
|
# Source Type Classification Tests
|
|
# =============================================================================
|
|
|
|
class TestSourceTypeClassification:
    """Tests that source_type correctly distinguishes law vs guideline vs standard vs restricted."""

    def test_eu_regulations_are_law(self):
        """Every EU regulation/directive code must classify as 'law'."""
        for code in ("eu_2016_679", "eu_2024_1689", "eu_2022_2555", "eu_2024_2847",
                     "eucsa", "dataact", "dora", "eaa"):
            info = _classify_regulation(code)
            assert info["source_type"] == "law", f"{code} should be law, got {info['source_type']}"

    def test_german_laws_are_law(self):
        """German national law codes must classify as 'law'."""
        for code in ("bdsg", "ttdsg", "tkg", "bgb_komplett", "hgb", "gewo"):
            info = _classify_regulation(code)
            assert info["source_type"] == "law", f"{code} should be law, got {info['source_type']}"

    def test_austrian_laws_are_law(self):
        """Austrian law codes must classify as 'law'."""
        for code in ("at_dsg", "at_abgb", "at_ecg", "at_tkg"):
            info = _classify_regulation(code)
            assert info["source_type"] == "law", f"{code} should be law, got {info['source_type']}"

    def test_nist_is_standard_not_law(self):
        """NIST frameworks are US standards, not EU law — classify as 'standard'."""
        for code in ("nist_sp_800_53", "nist_csf_2_0", "nist_ai_rmf", "nistir_8259a"):
            info = _classify_regulation(code)
            assert info["source_type"] == "standard", f"{code} should be standard, got {info['source_type']}"

    def test_cisa_is_standard(self):
        assert _classify_regulation("cisa_secure_by_design")["source_type"] == "standard"

    def test_owasp_is_standard(self):
        """OWASP frameworks are voluntary standards, not law."""
        for code in ("owasp_asvs", "owasp_top10", "owasp_samm"):
            info = _classify_regulation(code)
            assert info["source_type"] == "standard", f"{code} should be standard, got {info['source_type']}"

    def test_enisa_prefix_is_standard(self):
        assert _classify_regulation("enisa_threat_landscape")["source_type"] == "standard"

    def test_oecd_is_standard(self):
        assert _classify_regulation("oecd_ai_principles")["source_type"] == "standard"

    def test_edpb_is_guideline(self):
        """EDPB guidelines are authoritative yet non-binding soft law."""
        for code in ("edpb_01_2020", "edpb_dpbd_04_2019", "edpb_legitimate_interest"):
            info = _classify_regulation(code)
            assert info["source_type"] == "guideline", f"{code} should be guideline, got {info['source_type']}"

    def test_wp29_is_guideline(self):
        """WP29 (pre-EDPB) guidance is soft law."""
        for code in ("wp244_profiling", "wp260_transparency"):
            info = _classify_regulation(code)
            assert info["source_type"] == "guideline", f"{code} should be guideline, got {info['source_type']}"

    def test_blue_guide_is_guideline(self):
        assert _classify_regulation("eu_blue_guide_2022")["source_type"] == "guideline"

    def test_bsi_is_restricted(self):
        assert _classify_regulation("bsi_grundschutz")["source_type"] == "restricted"

    def test_iso_is_restricted(self):
        assert _classify_regulation("iso_27001")["source_type"] == "restricted"

    def test_etsi_is_restricted(self):
        assert _classify_regulation("etsi_en_303_645")["source_type"] == "restricted"

    def test_unknown_is_restricted(self):
        # Unknown codes fall back to the most conservative bucket.
        assert _classify_regulation("totally_unknown")["source_type"] == "restricted"

    def test_source_type_and_license_rule_are_independent(self):
        """source_type classifies legal authority; license_rule classifies copyright.

        NIST is Rule 1 (public domain, free use) but source_type='standard' (not a law)."""
        nist = _classify_regulation("nist_sp_800_53")
        edpb = _classify_regulation("edpb_01_2020")
        # The two axes are orthogonal: copyright freedom vs legal bindingness.
        assert (nist["rule"], nist["source_type"]) == (1, "standard")
        assert (edpb["rule"], edpb["source_type"]) == (1, "guideline")
|
|
|
|
|
|
# =============================================================================
|
|
# Scoped Control Applicability Tests (v3 Pipeline)
|
|
# =============================================================================
|
|
|
|
class TestApplicabilityFields:
|
|
"""Tests for applicable_industries, applicable_company_size, scope_conditions parsing."""
|
|
|
|
def _make_pipeline(self):
|
|
"""Create a pipeline with mocked DB."""
|
|
db = MagicMock()
|
|
pipeline = ControlGeneratorPipeline(db=db, rag_client=MagicMock())
|
|
return pipeline
|
|
|
|
def test_all_industries_parsed(self):
|
|
pipeline = self._make_pipeline()
|
|
data = {
|
|
"title": "Test",
|
|
"objective": "Test objective",
|
|
"applicable_industries": ["all"],
|
|
"applicable_company_size": ["all"],
|
|
"scope_conditions": None,
|
|
}
|
|
control = pipeline._build_control_from_json(data, "SEC")
|
|
assert control.applicable_industries == ["all"]
|
|
assert control.applicable_company_size == ["all"]
|
|
assert control.scope_conditions is None
|
|
|
|
def test_specific_industries_parsed(self):
|
|
pipeline = self._make_pipeline()
|
|
data = {
|
|
"title": "TKG Control",
|
|
"objective": "Telekommunikation",
|
|
"applicable_industries": ["Telekommunikation", "Energie"],
|
|
"applicable_company_size": ["medium", "large", "enterprise"],
|
|
"scope_conditions": None,
|
|
}
|
|
control = pipeline._build_control_from_json(data, "INC")
|
|
assert control.applicable_industries == ["Telekommunikation", "Energie"]
|
|
assert control.applicable_company_size == ["medium", "large", "enterprise"]
|
|
|
|
def test_scope_conditions_parsed(self):
|
|
pipeline = self._make_pipeline()
|
|
data = {
|
|
"title": "AI Act Control",
|
|
"objective": "KI-Risikomanagement",
|
|
"applicable_industries": ["all"],
|
|
"applicable_company_size": ["all"],
|
|
"scope_conditions": {
|
|
"requires_any": ["uses_ai"],
|
|
"description": "Nur bei KI-Einsatz relevant",
|
|
},
|
|
}
|
|
control = pipeline._build_control_from_json(data, "AI")
|
|
assert control.scope_conditions is not None
|
|
assert control.scope_conditions["requires_any"] == ["uses_ai"]
|
|
assert "KI" in control.scope_conditions["description"]
|
|
|
|
def test_missing_applicability_fields_are_none(self):
|
|
"""Old-style LLM response without applicability fields."""
|
|
pipeline = self._make_pipeline()
|
|
data = {
|
|
"title": "Legacy Control",
|
|
"objective": "Test",
|
|
}
|
|
control = pipeline._build_control_from_json(data, "SEC")
|
|
assert control.applicable_industries is None
|
|
assert control.applicable_company_size is None
|
|
assert control.scope_conditions is None
|
|
|
|
def test_string_industry_converted_to_list(self):
|
|
"""LLM sometimes returns a string instead of list."""
|
|
pipeline = self._make_pipeline()
|
|
data = {
|
|
"title": "Test",
|
|
"objective": "Test",
|
|
"applicable_industries": "Telekommunikation",
|
|
"applicable_company_size": "all",
|
|
}
|
|
control = pipeline._build_control_from_json(data, "SEC")
|
|
assert control.applicable_industries == ["Telekommunikation"]
|
|
assert control.applicable_company_size == ["all"]
|
|
|
|
def test_invalid_company_size_filtered(self):
|
|
"""Invalid size values should be filtered out."""
|
|
pipeline = self._make_pipeline()
|
|
data = {
|
|
"title": "Test",
|
|
"objective": "Test",
|
|
"applicable_company_size": ["medium", "huge", "large"],
|
|
}
|
|
control = pipeline._build_control_from_json(data, "SEC")
|
|
assert control.applicable_company_size == ["medium", "large"]
|
|
|
|
def test_all_invalid_sizes_results_in_none(self):
|
|
pipeline = self._make_pipeline()
|
|
data = {
|
|
"title": "Test",
|
|
"objective": "Test",
|
|
"applicable_company_size": ["huge", "tiny"],
|
|
}
|
|
control = pipeline._build_control_from_json(data, "SEC")
|
|
assert control.applicable_company_size is None
|
|
|
|
def test_scope_conditions_non_dict_ignored(self):
|
|
"""If LLM returns a string for scope_conditions, ignore it."""
|
|
pipeline = self._make_pipeline()
|
|
data = {
|
|
"title": "Test",
|
|
"objective": "Test",
|
|
"scope_conditions": "uses_ai",
|
|
}
|
|
control = pipeline._build_control_from_json(data, "SEC")
|
|
assert control.scope_conditions is None
|
|
|
|
def test_multiple_scope_signals(self):
|
|
pipeline = self._make_pipeline()
|
|
data = {
|
|
"title": "EHDS Control",
|
|
"objective": "Gesundheitsdaten",
|
|
"applicable_industries": ["Gesundheitswesen", "Pharma"],
|
|
"applicable_company_size": ["all"],
|
|
"scope_conditions": {
|
|
"requires_any": ["processes_health_data", "uses_ai"],
|
|
"description": "Gesundheitsdaten mit KI-Verarbeitung",
|
|
},
|
|
}
|
|
control = pipeline._build_control_from_json(data, "HLT")
|
|
assert len(control.scope_conditions["requires_any"]) == 2
|
|
assert "processes_health_data" in control.scope_conditions["requires_any"]
|
|
|
|
def test_pipeline_version_is_3(self):
|
|
"""v3 pipeline includes applicability fields."""
|
|
assert PIPELINE_VERSION == 3
|
|
|
|
def test_generated_control_dataclass_has_fields(self):
|
|
"""Verify the dataclass has the new fields with correct defaults."""
|
|
ctrl = GeneratedControl()
|
|
assert ctrl.applicable_industries is None
|
|
assert ctrl.applicable_company_size is None
|
|
assert ctrl.scope_conditions is None
|
|
|
|
def test_applicability_in_generation_metadata_not_leaked(self):
|
|
"""Applicability fields should be top-level, not in generation_metadata."""
|
|
pipeline = self._make_pipeline()
|
|
data = {
|
|
"title": "Test",
|
|
"objective": "Test",
|
|
"applicable_industries": ["all"],
|
|
"applicable_company_size": ["all"],
|
|
"scope_conditions": None,
|
|
}
|
|
control = pipeline._build_control_from_json(data, "SEC")
|
|
assert "applicable_industries" not in control.generation_metadata
|
|
assert "applicable_company_size" not in control.generation_metadata
|