Files
breakpilot-compliance/backend-compliance/tests/test_control_generator.py
Benjamin Admin d22c47c9eb
Some checks failed
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Failing after 35s
CI/CD / test-python-backend-compliance (push) Successful in 34s
CI/CD / test-python-document-crawler (push) Successful in 22s
CI/CD / test-python-dsms-gateway (push) Successful in 19s
CI/CD / validate-canonical-controls (push) Successful in 11s
CI/CD / Deploy (push) Has been skipped
feat(pipeline): Anthropic Batch API, source/regulation filter, cost optimization
- Add Anthropic API support to decomposition Pass 0a/0b (prompt caching, content batching)
- Add Anthropic Batch API (50% cost reduction, async 24h processing)
- Add source_filter (ILIKE on source_citation) for regulation-based filtering
- Add category_filter to Pass 0a for selective decomposition
- Add regulation_filter to control_generator for RAG scan phase filtering
  (prefix match on regulation_code — enables CE + Code Review focus)
- New API endpoints: batch-submit-0a, batch-submit-0b, batch-status, batch-process
- 83 new tests (all passing)

Cost reduction: $2,525 → ~$600-700 with all optimizations combined.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-17 13:22:01 +01:00

1067 lines
43 KiB
Python

"""Tests for Control Generator Pipeline."""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from compliance.services.control_generator import (
_classify_regulation,
_detect_domain,
_parse_llm_json,
GeneratorConfig,
GeneratedControl,
ControlGeneratorPipeline,
REGULATION_LICENSE_MAP,
)
from compliance.services.anchor_finder import AnchorFinder, OpenAnchor
from compliance.services.rag_client import RAGSearchResult
# =============================================================================
# License Mapping Tests
# =============================================================================
class TestLicenseMapping:
    """Verify that regulation codes classify onto the correct license rule."""

    def test_rule1_eu_law(self):
        classification = _classify_regulation("eu_2016_679")
        assert classification["rule"] == 1
        assert classification["name"] == "DSGVO"

    def test_rule1_nist(self):
        classification = _classify_regulation("nist_sp_800_53")
        assert classification["rule"] == 1
        assert "NIST" in classification["name"]

    def test_rule1_german_law(self):
        classification = _classify_regulation("bdsg")
        assert classification["rule"] == 1
        assert classification["name"] == "BDSG"

    def test_rule2_owasp(self):
        classification = _classify_regulation("owasp_asvs")
        assert classification["rule"] == 2
        assert "OWASP" in classification["name"]
        # Rule 2 (CC-licensed) sources must carry an attribution entry.
        assert "attribution" in classification

    def test_rule2_enisa_prefix(self):
        classification = _classify_regulation("enisa_iot_security")
        assert classification["rule"] == 2
        assert "ENISA" in classification["name"]

    def test_rule3_bsi_prefix(self):
        classification = _classify_regulation("bsi_tr03161")
        assert classification["rule"] == 3
        assert classification["name"] == "INTERNAL_ONLY"

    def test_rule3_iso_prefix(self):
        assert _classify_regulation("iso_27001")["rule"] == 3

    def test_rule3_etsi_prefix(self):
        assert _classify_regulation("etsi_en_303_645")["rule"] == 3

    def test_unknown_defaults_to_rule3(self):
        # Unknown sources fall back to the most restrictive rule.
        assert _classify_regulation("some_unknown_source")["rule"] == 3

    def test_case_insensitive(self):
        # Upper-case codes must classify the same as lower-case ones.
        assert _classify_regulation("EU_2016_679")["rule"] == 1

    def test_all_mapped_regulations_have_valid_rules(self):
        for code, info in REGULATION_LICENSE_MAP.items():
            assert info["rule"] in (1, 2, 3), f"{code} has invalid rule {info['rule']}"

    def test_rule3_never_exposes_names(self):
        for prefix in ("bsi_test", "iso_test", "etsi_test"):
            info = _classify_regulation(prefix)
            assert info["name"] == "INTERNAL_ONLY", f"{prefix} exposes name: {info['name']}"
# =============================================================================
# Domain Detection Tests
# =============================================================================
class TestDomainDetection:
    """Keyword-driven domain tagging of requirement text."""

    def test_auth_domain(self):
        text = "Multi-factor authentication and password policy"
        assert _detect_domain(text) == "AUTH"

    def test_crypto_domain(self):
        text = "TLS 1.3 encryption and certificate management"
        assert _detect_domain(text) == "CRYPT"

    def test_network_domain(self):
        text = "Firewall rules and network segmentation"
        assert _detect_domain(text) == "NET"

    def test_data_domain(self):
        text = "DSGVO personenbezogene Daten Datenschutz"
        assert _detect_domain(text) == "DATA"

    def test_default_domain(self):
        # Text without any domain keywords falls back to the generic bucket.
        text = "random unrelated text xyz"
        assert _detect_domain(text) == "SEC"
# =============================================================================
# JSON Parsing Tests
# =============================================================================
class TestJsonParsing:
    """Exercise _parse_llm_json against common LLM output shapes."""

    def test_parse_plain_json(self):
        parsed = _parse_llm_json('{"title": "Test", "objective": "Test obj"}')
        assert parsed["title"] == "Test"

    def test_parse_markdown_fenced_json(self):
        payload = '```json\n{"title": "Test"}\n```'
        assert _parse_llm_json(payload)["title"] == "Test"

    def test_parse_json_with_preamble(self):
        payload = 'Here is the result:\n{"title": "Test"}'
        assert _parse_llm_json(payload)["title"] == "Test"

    def test_parse_invalid_json(self):
        # Unparseable input degrades to an empty dict rather than raising.
        assert _parse_llm_json("not json at all") == {}
# =============================================================================
# GeneratedControl Rule Tests
# =============================================================================
class TestGeneratedControlRules:
    """Tests that enforce the 3-rule licensing constraints."""

    def test_rule1_has_original_text(self):
        control = GeneratedControl(license_rule=1)
        control.source_original_text = "Original EU law text"
        control.source_citation = {"source": "DSGVO Art. 35", "license": "EU_LAW"}
        control.customer_visible = True
        # Rule 1 (free-use law text) keeps original wording, citation, visibility.
        assert control.customer_visible is True
        assert control.source_original_text is not None
        assert control.source_citation is not None

    def test_rule2_has_citation(self):
        control = GeneratedControl(license_rule=2)
        control.source_citation = {"source": "OWASP ASVS V2.1", "license": "CC-BY-SA-4.0"}
        control.customer_visible = True
        assert control.source_citation is not None
        assert "CC-BY-SA" in control.source_citation["license"]

    def test_rule3_no_original_no_citation(self):
        control = GeneratedControl(license_rule=3)
        control.source_original_text = None
        control.source_citation = None
        control.customer_visible = False
        control.generation_metadata = {"processing_path": "llm_reform", "license_rule": 3}
        assert control.source_original_text is None
        assert control.source_citation is None
        assert control.customer_visible is False
        # The serialized metadata must never leak restricted source names.
        serialized = json.dumps(control.generation_metadata)
        assert "bsi" not in serialized.lower()
        assert "iso" not in serialized.lower()
        assert "TR-03161" not in serialized
# =============================================================================
# Anchor Finder Tests
# =============================================================================
class TestAnchorFinder:
    """Anchor discovery must only surface license-safe frameworks."""

    @pytest.mark.asyncio
    async def test_rag_anchor_search_filters_restricted(self):
        """Only Rule 1+2 sources are returned as anchors."""
        owasp_hit = RAGSearchResult(
            text="OWASP requirement",
            regulation_code="owasp_asvs",
            regulation_name="OWASP ASVS",
            regulation_short="OWASP",
            category="requirement",
            article="V2.1.1",
            paragraph="",
            source_url="https://owasp.org",
            score=0.9,
        )
        bsi_hit = RAGSearchResult(
            text="BSI requirement",
            regulation_code="bsi_tr03161",
            regulation_name="BSI TR-03161",
            regulation_short="BSI",
            category="requirement",
            article="O.Auth_1",
            paragraph="",
            source_url="",
            score=0.85,
        )
        rag_stub = AsyncMock()
        rag_stub.search.return_value = [owasp_hit, bsi_hit]
        finder = AnchorFinder(rag_client=rag_stub)
        control = GeneratedControl(title="Test Auth Control", tags=["auth"])
        anchors = await finder.find_anchors(control, skip_web=True)
        # The Rule-3 BSI hit must be filtered; only the OWASP anchor survives.
        assert len(anchors) == 1
        assert anchors[0].framework == "OWASP ASVS"

    @pytest.mark.asyncio
    async def test_web_search_identifies_frameworks(self):
        finder = AnchorFinder()
        expectations = {
            "https://owasp.org/asvs": "OWASP",
            "https://csrc.nist.gov/sp800-53": "NIST",
            "https://www.enisa.europa.eu/pub": "ENISA",
        }
        for url, framework in expectations.items():
            assert finder._identify_framework_from_url(url) == framework
        # Unrecognized hosts map to no framework at all.
        assert finder._identify_framework_from_url("https://random-site.com") is None
# =============================================================================
# Pipeline Integration Tests (Mocked)
# =============================================================================
class TestPipelineMocked:
    """Tests for the pipeline with mocked DB and external services.

    Fix: the `_llm_chat` patches here previously used plain
    `patch(..., return_value=...)` while the batch tests further down use
    `new_callable=AsyncMock`. Plain `patch()` only works because it
    auto-detects async targets (Python >= 3.8); being explicit removes that
    implicit dependency and makes the file consistent.
    """

    def _make_chunk(self, regulation_code: str = "owasp_asvs", article: str = "V2.1.1"):
        """Build a minimal RAG hit; tests overwrite text/name afterwards as needed."""
        return RAGSearchResult(
            text="Applications must implement multi-factor authentication.",
            regulation_code=regulation_code,
            regulation_name="OWASP ASVS",
            regulation_short="OWASP",
            category="requirement",
            article=article,
            paragraph="",
            source_url="https://owasp.org",
            score=0.9,
        )

    @pytest.mark.asyncio
    async def test_rule1_processing_path(self):
        """Rule 1 chunks produce controls with original text."""
        chunk = self._make_chunk(regulation_code="eu_2016_679", article="Art. 35")
        chunk.text = "Die Datenschutz-Folgenabschaetzung ist durchzufuehren."
        chunk.regulation_name = "DSGVO"
        mock_db = MagicMock()
        # No existing row for this chunk hash → not previously processed.
        mock_db.execute.return_value.fetchone.return_value = None
        pipeline = ControlGeneratorPipeline(db=mock_db)
        license_info = pipeline._classify_license(chunk)
        assert license_info["rule"] == 1

    @pytest.mark.asyncio
    async def test_rule3_processing_blocks_source_info(self):
        """Rule 3 must never store original text or source names."""
        mock_db = MagicMock()
        mock_rag = AsyncMock()
        pipeline = ControlGeneratorPipeline(db=mock_db, rag_client=mock_rag)
        # Simulate LLM response
        llm_response = json.dumps({
            "title": "Secure Password Storage",
            "objective": "Passwords must be hashed with modern algorithms.",
            "rationale": "Prevents credential theft.",
            "requirements": ["Use bcrypt or argon2"],
            "test_procedure": ["Verify hash algorithm"],
            "evidence": ["Config review"],
            "severity": "high",
            "tags": ["auth", "password"],
        })
        # Explicit AsyncMock — consistent with the batch tests below.
        with patch(
            "compliance.services.control_generator._llm_chat",
            new_callable=AsyncMock,
            return_value=llm_response,
        ):
            chunk = self._make_chunk(regulation_code="bsi_tr03161", article="O.Auth_1")
            config = GeneratorConfig(max_controls=1)
            control = await pipeline._llm_reformulate(chunk, config)
        assert control.license_rule == 3
        assert control.source_original_text is None
        assert control.source_citation is None
        assert control.customer_visible is False
        # Verify no BSI references in metadata
        metadata_str = json.dumps(control.generation_metadata)
        assert "bsi" not in metadata_str.lower()
        assert "BSI" not in metadata_str
        assert "TR-03161" not in metadata_str

    @pytest.mark.asyncio
    async def test_chunk_hash_deduplication(self):
        """Same chunk text produces same hash — no double processing."""
        import hashlib
        text = "Test requirement text"
        h1 = hashlib.sha256(text.encode()).hexdigest()
        h2 = hashlib.sha256(text.encode()).hexdigest()
        assert h1 == h2

    def test_config_defaults(self):
        """GeneratorConfig ships with the documented defaults."""
        config = GeneratorConfig()
        assert config.max_controls == 0
        assert config.batch_size == 5
        assert config.skip_processed is True
        assert config.dry_run is False

    @pytest.mark.asyncio
    async def test_structure_free_use_produces_citation(self):
        """Rule 1 structuring includes source citation."""
        mock_db = MagicMock()
        pipeline = ControlGeneratorPipeline(db=mock_db)
        llm_response = json.dumps({
            "title": "DSFA Pflicht",
            "objective": "DSFA bei hohem Risiko durchfuehren.",
            "rationale": "Gesetzliche Pflicht nach DSGVO.",
            "requirements": ["DSFA durchfuehren"],
            "test_procedure": ["DSFA Bericht pruefen"],
            "evidence": ["DSFA Dokumentation"],
            "severity": "high",
            "tags": ["dsfa", "dsgvo"],
        })
        chunk = self._make_chunk(regulation_code="eu_2016_679", article="Art. 35")
        chunk.text = "Art. 35 DSGVO: Datenschutz-Folgenabschaetzung"
        chunk.regulation_name = "DSGVO"
        license_info = _classify_regulation("eu_2016_679")
        # Explicit AsyncMock — consistent with the batch tests below.
        with patch(
            "compliance.services.control_generator._llm_chat",
            new_callable=AsyncMock,
            return_value=llm_response,
        ):
            control = await pipeline._structure_free_use(chunk, license_info)
        assert control.license_rule == 1
        assert control.source_original_text is not None
        assert control.source_citation is not None
        assert "DSGVO" in control.source_citation["source"]
        assert control.customer_visible is True
# =============================================================================
# JSON Array Parsing Tests (_parse_llm_json_array)
# =============================================================================
from compliance.services.control_generator import _parse_llm_json_array
class TestParseJsonArray:
    """Tests for _parse_llm_json_array — batch LLM response parsing."""

    def test_clean_json_array(self):
        """A well-formed JSON array should be returned directly."""
        payload = json.dumps([
            {"title": "Control A", "objective": "Obj A"},
            {"title": "Control B", "objective": "Obj B"},
        ])
        parsed = _parse_llm_json_array(payload)
        assert len(parsed) == 2
        assert [item["title"] for item in parsed] == ["Control A", "Control B"]

    def test_json_array_in_markdown_code_block(self):
        """JSON array wrapped in ```json ... ``` markdown fence."""
        inner = json.dumps([
            {"title": "Fenced A", "chunk_index": 1},
            {"title": "Fenced B", "chunk_index": 2},
        ])
        payload = f"Here is the result:\n```json\n{inner}\n```\nDone."
        parsed = _parse_llm_json_array(payload)
        assert len(parsed) == 2
        assert parsed[0]["title"] == "Fenced A"
        assert parsed[1]["title"] == "Fenced B"

    def test_markdown_code_block_without_json_tag(self):
        """Markdown fence without explicit 'json' language tag."""
        inner = json.dumps([{"title": "NoTag", "objective": "test"}])
        parsed = _parse_llm_json_array(f"```\n{inner}\n```")
        assert len(parsed) == 1
        assert parsed[0]["title"] == "NoTag"

    def test_wrapper_object_controls_key(self):
        """LLM wraps array in {"controls": [...]} — should unwrap."""
        payload = json.dumps({
            "controls": [
                {"title": "Wrapped A", "objective": "O1"},
                {"title": "Wrapped B", "objective": "O2"},
            ]
        })
        parsed = _parse_llm_json_array(payload)
        assert len(parsed) == 2
        assert parsed[0]["title"] == "Wrapped A"

    def test_wrapper_object_results_key(self):
        """LLM wraps array in {"results": [...]} — should unwrap."""
        payload = json.dumps({
            "results": [
                {"title": "R1"},
                {"title": "R2"},
                {"title": "R3"},
            ]
        })
        assert len(_parse_llm_json_array(payload)) == 3

    def test_wrapper_object_items_key(self):
        """LLM wraps array in {"items": [...]} — should unwrap."""
        payload = json.dumps({
            "items": [{"title": "Item1"}]
        })
        parsed = _parse_llm_json_array(payload)
        assert len(parsed) == 1
        assert parsed[0]["title"] == "Item1"

    def test_wrapper_object_data_key(self):
        """LLM wraps array in {"data": [...]} — should unwrap."""
        payload = json.dumps({
            "data": [{"title": "D1"}, {"title": "D2"}]
        })
        assert len(_parse_llm_json_array(payload)) == 2

    def test_single_dict_returned_as_list(self):
        """A single JSON object (no array) is wrapped in a list."""
        payload = json.dumps({"title": "SingleControl", "objective": "Obj"})
        parsed = _parse_llm_json_array(payload)
        assert len(parsed) == 1
        assert parsed[0]["title"] == "SingleControl"

    def test_individual_json_objects_fallback(self):
        """Multiple separate JSON objects (not in array) are collected."""
        payload = (
            'Here are the controls:\n'
            '{"title": "Ctrl1", "objective": "A"}\n'
            '{"title": "Ctrl2", "objective": "B"}\n'
        )
        parsed = _parse_llm_json_array(payload)
        assert len(parsed) == 2
        assert {item["title"] for item in parsed} == {"Ctrl1", "Ctrl2"}

    def test_individual_objects_require_title(self):
        """Fallback individual-object parsing only includes objects with 'title'."""
        payload = (
            '{"title": "HasTitle", "objective": "Yes"}\n'
            '{"no_title_field": "skip_me"}\n'
        )
        parsed = _parse_llm_json_array(payload)
        assert len(parsed) == 1
        assert parsed[0]["title"] == "HasTitle"

    def test_empty_string_returns_empty_list(self):
        """Empty input returns empty list."""
        assert _parse_llm_json_array("") == []

    def test_invalid_input_returns_empty_list(self):
        """Completely invalid input returns empty list."""
        assert _parse_llm_json_array("this is not json at all, no braces anywhere") == []

    def test_garbage_with_no_json_returns_empty(self):
        """Random non-JSON text should return empty list."""
        assert _parse_llm_json_array("Hier ist meine Antwort: leider kann ich das nicht.") == []

    def test_bracket_block_extraction(self):
        """Array embedded in preamble text is extracted via bracket matching."""
        payload = 'Some preamble text...\n[{"title": "Extracted", "objective": "X"}]\nSome trailing text.'
        parsed = _parse_llm_json_array(payload)
        assert len(parsed) == 1
        assert parsed[0]["title"] == "Extracted"

    def test_nested_objects_in_array(self):
        """Array elements with nested objects (like scope) are parsed correctly."""
        payload = json.dumps([
            {
                "title": "Nested",
                "objective": "Test",
                "scope": {"regions": ["EU", "DE"]},
                "requirements": ["Req1"],
            }
        ])
        parsed = _parse_llm_json_array(payload)
        assert len(parsed) == 1
        assert parsed[0]["scope"]["regions"] == ["EU", "DE"]
# =============================================================================
# Batch Size Configuration Tests
# =============================================================================
class TestBatchSizeConfig:
    """Behaviour of the batch_size knob on GeneratorConfig."""

    def test_default_batch_size(self):
        assert GeneratorConfig().batch_size == 5

    def test_custom_batch_size(self):
        assert GeneratorConfig(batch_size=10).batch_size == 10

    def test_batch_size_of_one(self):
        assert GeneratorConfig(batch_size=1).batch_size == 1

    def test_batch_size_used_in_pipeline_constant(self):
        """Verify that pipeline uses config.batch_size (not a hardcoded value)."""
        config = GeneratorConfig(batch_size=3)
        # Mirrors the pipeline logic: BATCH_SIZE = config.batch_size or 5
        assert (config.batch_size or 5) == 3

    def test_batch_size_zero_falls_back_to_five(self):
        """batch_size=0 triggers `or 5` fallback in the pipeline loop."""
        config = GeneratorConfig(batch_size=0)
        # 0 is falsy, so the pipeline's `or 5` kicks in.
        assert (config.batch_size or 5) == 5
# =============================================================================
# Batch Processing Loop Tests (Mocked)
# =============================================================================
class TestBatchProcessingLoop:
    """Tests for _process_batch, _structure_batch, _reformulate_batch with mocked LLM."""

    # NOTE(review): several tests patch "compliance.services.anchor_finder.AnchorFinder".
    # If control_generator imports AnchorFinder by name into its own namespace, the
    # effective patch target would be "compliance.services.control_generator.AnchorFinder"
    # instead — confirm against the production import style.

    def _make_chunk(self, regulation_code="owasp_asvs", article="V2.1.1", text="Test requirement"):
        # Minimal RAG hit; name/short fields switch based on the regulation code.
        return RAGSearchResult(
            text=text,
            regulation_code=regulation_code,
            regulation_name="OWASP ASVS" if "owasp" in regulation_code else "Test Source",
            regulation_short="OWASP" if "owasp" in regulation_code else "TEST",
            category="requirement",
            article=article,
            paragraph="",
            source_url="https://example.com",
            score=0.9,
        )

    @pytest.mark.asyncio
    async def test_process_batch_splits_by_license_rule(self):
        """_process_batch routes Rule 1+2 to _structure_batch and Rule 3 to _reformulate_batch."""
        mock_db = MagicMock()
        pipeline = ControlGeneratorPipeline(db=mock_db)
        pipeline._existing_controls = []
        chunk_r1 = self._make_chunk("eu_2016_679", "Art. 35", "DSGVO text")
        chunk_r3 = self._make_chunk("bsi_tr03161", "O.Auth_1", "BSI text")
        # Each batch item pairs a chunk with its precomputed license info.
        batch_items = [
            (chunk_r1, {"rule": 1, "name": "DSGVO", "license": "EU_LAW"}),
            (chunk_r3, {"rule": 3, "name": "INTERNAL_ONLY"}),
        ]
        # Mock _structure_batch and _reformulate_batch
        structure_ctrl = GeneratedControl(title="Structured", license_rule=1, release_state="draft")
        reform_ctrl = GeneratedControl(title="Reformed", license_rule=3, release_state="draft")
        mock_finder_instance = AsyncMock()
        mock_finder_instance.find_anchors = AsyncMock(return_value=[])
        mock_finder_cls = MagicMock(return_value=mock_finder_instance)
        with patch.object(pipeline, "_structure_batch", new_callable=AsyncMock, return_value=[structure_ctrl]) as mock_struct, \
             patch.object(pipeline, "_reformulate_batch", new_callable=AsyncMock, return_value=[reform_ctrl]) as mock_reform, \
             patch.object(pipeline, "_check_harmonization", new_callable=AsyncMock, return_value=[]), \
             patch("compliance.services.anchor_finder.AnchorFinder", mock_finder_cls), \
             patch("compliance.services.control_generator.check_similarity", new_callable=AsyncMock) as mock_sim:
            # PASS similarity keeps the Rule 3 control in "draft".
            mock_sim.return_value = MagicMock(status="PASS", token_overlap=0.1, ngram_jaccard=0.1, lcs_ratio=0.1)
            config = GeneratorConfig(batch_size=5)
            result = await pipeline._process_batch(batch_items, config, "test-job-123")
        mock_struct.assert_called_once()
        mock_reform.assert_called_once()
        # structure_batch received only the Rule 1 chunk
        assert len(mock_struct.call_args[0][0]) == 1
        # reformulate_batch received only the Rule 3 chunk
        assert len(mock_reform.call_args[0][0]) == 1

    @pytest.mark.asyncio
    async def test_structure_batch_calls_llm_and_parses(self):
        """_structure_batch sends prompt to LLM and parses array response."""
        mock_db = MagicMock()
        pipeline = ControlGeneratorPipeline(db=mock_db)
        chunks = [
            self._make_chunk("eu_2016_679", "Art. 5", "Datensparsamkeit und Zweckbindung"),
            self._make_chunk("eu_2016_679", "Art. 35", "DSFA bei hohem Risiko"),
        ]
        license_infos = [
            {"rule": 1, "name": "DSGVO", "license": "EU_LAW"},
            {"rule": 1, "name": "DSGVO", "license": "EU_LAW"},
        ]
        # LLM answers with one element per chunk, keyed by 1-based chunk_index.
        llm_response = json.dumps([
            {
                "chunk_index": 1,
                "title": "Datensparsamkeit",
                "objective": "Nur notwendige Daten erheben.",
                "rationale": "DSGVO Grundprinzip.",
                "requirements": ["Datenminimierung"],
                "test_procedure": ["Datenbestand pruefen"],
                "evidence": ["Verarbeitungsverzeichnis"],
                "severity": "high",
                "tags": ["dsgvo", "datenschutz"],
            },
            {
                "chunk_index": 2,
                "title": "DSFA Pflicht",
                "objective": "DSFA bei hohem Risiko durchfuehren.",
                "rationale": "Gesetzliche Pflicht.",
                "requirements": ["DSFA erstellen"],
                "test_procedure": ["DSFA Bericht pruefen"],
                "evidence": ["DSFA Dokumentation"],
                "severity": "high",
                "tags": ["dsfa"],
            },
        ])
        with patch("compliance.services.control_generator._llm_chat", new_callable=AsyncMock, return_value=llm_response):
            controls = await pipeline._structure_batch(chunks, license_infos)
        assert len(controls) == 2
        assert controls[0] is not None
        assert controls[0].title == "Datensparsamkeit"
        assert controls[0].license_rule == 1
        # Rule 1 keeps the original source text and is customer visible.
        assert controls[0].source_original_text is not None
        assert controls[0].customer_visible is True
        assert controls[0].generation_metadata["processing_path"] == "structured_batch"
        assert controls[0].generation_metadata["batch_size"] == 2
        assert controls[1] is not None
        assert controls[1].title == "DSFA Pflicht"
        assert controls[1].license_rule == 1

    @pytest.mark.asyncio
    async def test_reformulate_batch_calls_llm_and_strips_source(self):
        """_reformulate_batch produces Rule 3 controls without source info."""
        mock_db = MagicMock()
        pipeline = ControlGeneratorPipeline(db=mock_db)
        chunks = [
            self._make_chunk("bsi_tr03161", "O.Auth_1", "Multi-factor authentication for apps"),
        ]
        config = GeneratorConfig(batch_size=5)
        llm_response = json.dumps([
            {
                "chunk_index": 1,
                "title": "Starke Authentifizierung",
                "objective": "Mehrstufige Anmeldung implementieren.",
                "rationale": "Schutz vor unbefugtem Zugriff.",
                "requirements": ["MFA einrichten"],
                "test_procedure": ["MFA Funktionstest"],
                "evidence": ["MFA Konfiguration"],
                "severity": "critical",
                "tags": ["auth", "mfa"],
            }
        ])
        with patch("compliance.services.control_generator._llm_chat", new_callable=AsyncMock, return_value=llm_response):
            controls = await pipeline._reformulate_batch(chunks, config)
        assert len(controls) == 1
        ctrl = controls[0]
        assert ctrl is not None
        assert ctrl.title == "Starke Authentifizierung"
        assert ctrl.license_rule == 3
        # Rule 3: no original text, no citation, never customer visible.
        assert ctrl.source_original_text is None
        assert ctrl.source_citation is None
        assert ctrl.customer_visible is False
        assert ctrl.generation_metadata["processing_path"] == "llm_reform_batch"
        # Must not contain BSI references
        metadata_str = json.dumps(ctrl.generation_metadata)
        assert "bsi" not in metadata_str.lower()
        assert "TR-03161" not in metadata_str

    @pytest.mark.asyncio
    async def test_structure_batch_maps_by_chunk_index(self):
        """Controls are mapped back to the correct chunk via chunk_index."""
        mock_db = MagicMock()
        pipeline = ControlGeneratorPipeline(db=mock_db)
        chunks = [
            self._make_chunk("eu_2016_679", "Art. 5", "First chunk"),
            self._make_chunk("eu_2016_679", "Art. 6", "Second chunk"),
            self._make_chunk("eu_2016_679", "Art. 7", "Third chunk"),
        ]
        license_infos = [{"rule": 1, "name": "DSGVO", "license": "EU_LAW"}] * 3
        # LLM returns them in reversed order
        llm_response = json.dumps([
            {
                "chunk_index": 3,
                "title": "Third Control",
                "objective": "Obj3",
                "rationale": "Rat3",
                "requirements": [],
                "test_procedure": [],
                "evidence": [],
                "severity": "low",
                "tags": [],
            },
            {
                "chunk_index": 1,
                "title": "First Control",
                "objective": "Obj1",
                "rationale": "Rat1",
                "requirements": [],
                "test_procedure": [],
                "evidence": [],
                "severity": "high",
                "tags": [],
            },
            {
                "chunk_index": 2,
                "title": "Second Control",
                "objective": "Obj2",
                "rationale": "Rat2",
                "requirements": [],
                "test_procedure": [],
                "evidence": [],
                "severity": "medium",
                "tags": [],
            },
        ])
        with patch("compliance.services.control_generator._llm_chat", new_callable=AsyncMock, return_value=llm_response):
            controls = await pipeline._structure_batch(chunks, license_infos)
        assert len(controls) == 3
        # Verify correct mapping despite shuffled chunk_index
        assert controls[0].title == "First Control"
        assert controls[1].title == "Second Control"
        assert controls[2].title == "Third Control"

    @pytest.mark.asyncio
    async def test_structure_batch_falls_back_to_position(self):
        """If chunk_index is missing, controls are assigned by position."""
        mock_db = MagicMock()
        pipeline = ControlGeneratorPipeline(db=mock_db)
        chunks = [
            self._make_chunk("eu_2016_679", "Art. 5", "Chunk A"),
            self._make_chunk("eu_2016_679", "Art. 6", "Chunk B"),
        ]
        license_infos = [{"rule": 1, "name": "DSGVO", "license": "EU_LAW"}] * 2
        # No chunk_index in response
        llm_response = json.dumps([
            {
                "title": "PositionA",
                "objective": "ObjA",
                "rationale": "RatA",
                "requirements": [],
                "test_procedure": [],
                "evidence": [],
                "severity": "medium",
                "tags": [],
            },
            {
                "title": "PositionB",
                "objective": "ObjB",
                "rationale": "RatB",
                "requirements": [],
                "test_procedure": [],
                "evidence": [],
                "severity": "medium",
                "tags": [],
            },
        ])
        with patch("compliance.services.control_generator._llm_chat", new_callable=AsyncMock, return_value=llm_response):
            controls = await pipeline._structure_batch(chunks, license_infos)
        assert len(controls) == 2
        assert controls[0].title == "PositionA"
        assert controls[1].title == "PositionB"

    @pytest.mark.asyncio
    async def test_process_batch_only_structure_no_reform(self):
        """Batch with only Rule 1+2 items skips _reformulate_batch."""
        mock_db = MagicMock()
        pipeline = ControlGeneratorPipeline(db=mock_db)
        pipeline._existing_controls = []
        chunk = self._make_chunk("eu_2016_679", "Art. 5", "DSGVO text")
        batch_items = [
            (chunk, {"rule": 1, "name": "DSGVO", "license": "EU_LAW"}),
        ]
        ctrl = GeneratedControl(title="OnlyStructure", license_rule=1, release_state="draft")
        mock_finder_instance = AsyncMock()
        mock_finder_instance.find_anchors = AsyncMock(return_value=[])
        mock_finder_cls = MagicMock(return_value=mock_finder_instance)
        with patch.object(pipeline, "_structure_batch", new_callable=AsyncMock, return_value=[ctrl]) as mock_struct, \
             patch.object(pipeline, "_reformulate_batch", new_callable=AsyncMock) as mock_reform, \
             patch.object(pipeline, "_check_harmonization", new_callable=AsyncMock, return_value=[]), \
             patch("compliance.services.anchor_finder.AnchorFinder", mock_finder_cls):
            config = GeneratorConfig()
            result = await pipeline._process_batch(batch_items, config, "job-1")
        mock_struct.assert_called_once()
        mock_reform.assert_not_called()
        assert len(result) == 1
        assert result[0].title == "OnlyStructure"

    @pytest.mark.asyncio
    async def test_process_batch_only_reform_no_structure(self):
        """Batch with only Rule 3 items skips _structure_batch."""
        mock_db = MagicMock()
        pipeline = ControlGeneratorPipeline(db=mock_db)
        pipeline._existing_controls = []
        chunk = self._make_chunk("bsi_tr03161", "O.Auth_1", "BSI text")
        batch_items = [
            (chunk, {"rule": 3, "name": "INTERNAL_ONLY"}),
        ]
        ctrl = GeneratedControl(title="OnlyReform", license_rule=3, release_state="draft")
        mock_finder_instance = AsyncMock()
        mock_finder_instance.find_anchors = AsyncMock(return_value=[])
        mock_finder_cls = MagicMock(return_value=mock_finder_instance)
        with patch.object(pipeline, "_structure_batch", new_callable=AsyncMock) as mock_struct, \
             patch.object(pipeline, "_reformulate_batch", new_callable=AsyncMock, return_value=[ctrl]) as mock_reform, \
             patch.object(pipeline, "_check_harmonization", new_callable=AsyncMock, return_value=[]), \
             patch("compliance.services.anchor_finder.AnchorFinder", mock_finder_cls), \
             patch("compliance.services.control_generator.check_similarity", new_callable=AsyncMock) as mock_sim:
            # Rule 3 output goes through the similarity gate; PASS keeps it.
            mock_sim.return_value = MagicMock(status="PASS", token_overlap=0.1, ngram_jaccard=0.1, lcs_ratio=0.1)
            config = GeneratorConfig()
            result = await pipeline._process_batch(batch_items, config, "job-2")
        mock_struct.assert_not_called()
        mock_reform.assert_called_once()
        assert len(result) == 1
        assert result[0].title == "OnlyReform"

    @pytest.mark.asyncio
    async def test_process_batch_mixed_rules(self):
        """Batch with mixed Rule 1 and Rule 3 items calls both sub-methods."""
        mock_db = MagicMock()
        pipeline = ControlGeneratorPipeline(db=mock_db)
        pipeline._existing_controls = []
        chunk_r1 = self._make_chunk("eu_2016_679", "Art. 5", "DSGVO")
        chunk_r2 = self._make_chunk("owasp_asvs", "V2.1", "OWASP")
        chunk_r3a = self._make_chunk("bsi_tr03161", "O.Auth_1", "BSI A")
        chunk_r3b = self._make_chunk("iso_27001", "A.9.1", "ISO B")
        # Deliberately interleaved so the split is by rule, not by position.
        batch_items = [
            (chunk_r1, {"rule": 1, "name": "DSGVO", "license": "EU_LAW"}),
            (chunk_r3a, {"rule": 3, "name": "INTERNAL_ONLY"}),
            (chunk_r2, {"rule": 2, "name": "OWASP ASVS", "license": "CC-BY-SA-4.0", "attribution": "OWASP Foundation"}),
            (chunk_r3b, {"rule": 3, "name": "INTERNAL_ONLY"}),
        ]
        struct_ctrls = [
            GeneratedControl(title="DSGVO Ctrl", license_rule=1, release_state="draft"),
            GeneratedControl(title="OWASP Ctrl", license_rule=2, release_state="draft"),
        ]
        reform_ctrls = [
            GeneratedControl(title="BSI Ctrl", license_rule=3, release_state="draft"),
            GeneratedControl(title="ISO Ctrl", license_rule=3, release_state="draft"),
        ]
        mock_finder_instance = AsyncMock()
        mock_finder_instance.find_anchors = AsyncMock(return_value=[])
        mock_finder_cls = MagicMock(return_value=mock_finder_instance)
        with patch.object(pipeline, "_structure_batch", new_callable=AsyncMock, return_value=struct_ctrls) as mock_struct, \
             patch.object(pipeline, "_reformulate_batch", new_callable=AsyncMock, return_value=reform_ctrls) as mock_reform, \
             patch.object(pipeline, "_check_harmonization", new_callable=AsyncMock, return_value=[]), \
             patch("compliance.services.anchor_finder.AnchorFinder", mock_finder_cls), \
             patch("compliance.services.control_generator.check_similarity", new_callable=AsyncMock) as mock_sim:
            mock_sim.return_value = MagicMock(status="PASS", token_overlap=0.05, ngram_jaccard=0.05, lcs_ratio=0.05)
            config = GeneratorConfig()
            result = await pipeline._process_batch(batch_items, config, "job-mixed")
        # Both methods called
        mock_struct.assert_called_once()
        mock_reform.assert_called_once()
        # structure_batch gets 2 items (Rule 1 + Rule 2)
        assert len(mock_struct.call_args[0][0]) == 2
        # reformulate_batch gets 2 items (Rule 3 + Rule 3)
        assert len(mock_reform.call_args[0][0]) == 2
        # Result has 4 controls total
        assert len(result) == 4

    @pytest.mark.asyncio
    async def test_process_batch_empty_batch(self):
        """Empty batch returns empty list."""
        mock_db = MagicMock()
        pipeline = ControlGeneratorPipeline(db=mock_db)
        pipeline._existing_controls = []
        config = GeneratorConfig()
        result = await pipeline._process_batch([], config, "job-empty")
        assert result == []

    @pytest.mark.asyncio
    async def test_reformulate_batch_too_close_flagged(self):
        """Rule 3 controls that are too similar to source get flagged."""
        mock_db = MagicMock()
        pipeline = ControlGeneratorPipeline(db=mock_db)
        pipeline._existing_controls = []
        # Control objective deliberately equals the source chunk text.
        chunk = self._make_chunk("bsi_tr03161", "O.Auth_1", "Authentication must use MFA")
        batch_items = [
            (chunk, {"rule": 3, "name": "INTERNAL_ONLY"}),
        ]
        ctrl = GeneratedControl(
            title="Auth MFA",
            objective="Authentication must use MFA",
            rationale="Security",
            license_rule=3,
            release_state="draft",
            generation_metadata={},
        )
        # Simulate similarity FAIL (too close to source)
        fail_report = MagicMock(status="FAIL", token_overlap=0.85, ngram_jaccard=0.9, lcs_ratio=0.88)
        mock_finder_instance = AsyncMock()
        mock_finder_instance.find_anchors = AsyncMock(return_value=[])
        mock_finder_cls = MagicMock(return_value=mock_finder_instance)
        with patch.object(pipeline, "_structure_batch", new_callable=AsyncMock), \
             patch.object(pipeline, "_reformulate_batch", new_callable=AsyncMock, return_value=[ctrl]), \
             patch.object(pipeline, "_check_harmonization", new_callable=AsyncMock, return_value=[]), \
             patch("compliance.services.anchor_finder.AnchorFinder", mock_finder_cls), \
             patch("compliance.services.control_generator.check_similarity", new_callable=AsyncMock, return_value=fail_report):
            config = GeneratorConfig()
            result = await pipeline._process_batch(batch_items, config, "job-tooclose")
        # Failed similarity demotes the control and records the verdict.
        assert len(result) == 1
        assert result[0].release_state == "too_close"
        assert result[0].generation_metadata["similarity_status"] == "FAIL"
# =============================================================================
# Regulation Filter Tests
# =============================================================================
class TestRegulationFilter:
    """Tests for regulation_filter in GeneratorConfig.

    ``regulation_filter`` is a list of prefixes matched against each RAG
    chunk's ``regulation_code`` during the ``_scan_rag`` phase; chunks whose
    code matches no prefix are skipped. ``None`` disables filtering.
    """

    @staticmethod
    def _make_db_stub():
        """Return a DB mock whose cursor iterates no rows and fetchall() -> [].

        Both stubs are configured on the SAME ``execute.return_value`` mock.
        (Previously the ``fetchall`` stub was set first and then clobbered by
        reassigning ``execute.return_value`` to a fresh MagicMock, leaving
        ``fetchall()`` returning an auto-created mock instead of ``[]``.)
        ``side_effect`` hands out a fresh iterator per call so the cursor can
        be iterated more than once.
        """
        mock_db = MagicMock()
        cursor = mock_db.execute.return_value
        cursor.fetchall.return_value = []
        cursor.__iter__ = MagicMock(side_effect=lambda *_: iter([]))
        return mock_db

    def test_config_accepts_regulation_filter(self):
        """An explicit prefix list is stored unchanged on the config."""
        config = GeneratorConfig(regulation_filter=["owasp_", "nist_", "eu_2023_1230"])
        assert config.regulation_filter == ["owasp_", "nist_", "eu_2023_1230"]

    def test_config_default_none(self):
        """regulation_filter defaults to None (no filtering)."""
        config = GeneratorConfig()
        assert config.regulation_filter is None

    @pytest.mark.asyncio
    async def test_scan_rag_filters_by_regulation(self):
        """Verify _scan_rag skips chunks not matching regulation_filter."""
        mock_db = self._make_db_stub()
        # Mock Qdrant scroll response with mixed regulation_codes.
        qdrant_points = {
            "result": {
                "points": [
                    {"id": "1", "payload": {
                        "chunk_text": "OWASP ASVS requirement for input validation " * 5,
                        "regulation_code": "owasp_asvs",
                        "regulation_name": "OWASP ASVS",
                    }},
                    {"id": "2", "payload": {
                        "chunk_text": "AML anti-money laundering requirement for banks " * 5,
                        "regulation_code": "amlr",
                        "regulation_name": "AML-Verordnung",
                    }},
                    {"id": "3", "payload": {
                        "chunk_text": "NIST secure software development framework req " * 5,
                        "regulation_code": "nist_sp_800_218",
                        "regulation_name": "NIST SSDF",
                    }},
                ],
                "next_page_offset": None,
            }
        }
        with patch("compliance.services.control_generator.httpx.AsyncClient") as mock_client_cls:
            mock_client = AsyncMock()
            mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
            mock_resp = MagicMock()
            mock_resp.status_code = 200
            mock_resp.json.return_value = qdrant_points
            mock_client.post.return_value = mock_resp
            pipeline = ControlGeneratorPipeline(db=mock_db, rag_client=MagicMock())
            # With filter: only owasp_ and nist_ prefixes survive.
            config = GeneratorConfig(
                collections=["bp_compliance_ce"],
                regulation_filter=["owasp_", "nist_"],
            )
            results = await pipeline._scan_rag(config)
            # Should only get 2 chunks (owasp + nist), not amlr.
            assert len(results) == 2
            codes = {r.regulation_code for r in results}
            assert "owasp_asvs" in codes
            assert "nist_sp_800_218" in codes
            assert "amlr" not in codes

    @pytest.mark.asyncio
    async def test_scan_rag_no_filter_returns_all(self):
        """Verify _scan_rag returns all chunks when no regulation_filter."""
        mock_db = self._make_db_stub()
        qdrant_points = {
            "result": {
                "points": [
                    {"id": "1", "payload": {
                        "chunk_text": "OWASP requirement for secure authentication " * 5,
                        "regulation_code": "owasp_asvs",
                    }},
                    {"id": "2", "payload": {
                        "chunk_text": "AML compliance requirement for financial inst " * 5,
                        "regulation_code": "amlr",
                    }},
                ],
                "next_page_offset": None,
            }
        }
        with patch("compliance.services.control_generator.httpx.AsyncClient") as mock_client_cls:
            mock_client = AsyncMock()
            mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
            mock_resp = MagicMock()
            mock_resp.status_code = 200
            mock_resp.json.return_value = qdrant_points
            mock_client.post.return_value = mock_resp
            pipeline = ControlGeneratorPipeline(db=mock_db, rag_client=MagicMock())
            config = GeneratorConfig(
                collections=["bp_compliance_ce"],
                regulation_filter=None,
            )
            results = await pipeline._scan_rag(config)
            # No filter: both chunks come back regardless of regulation_code.
            assert len(results) == 2