feat(pipeline): Anthropic Batch API, source/regulation filter, cost optimization
Some checks failed
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Failing after 35s
CI/CD / test-python-backend-compliance (push) Successful in 34s
CI/CD / test-python-document-crawler (push) Successful in 22s
CI/CD / test-python-dsms-gateway (push) Successful in 19s
CI/CD / validate-canonical-controls (push) Successful in 11s
CI/CD / Deploy (push) Has been skipped

- Add Anthropic API support to decomposition Pass 0a/0b (prompt caching, content batching)
- Add Anthropic Batch API (50% cost reduction, async 24h processing)
- Add source_filter (ILIKE on source_citation) for regulation-based filtering
- Add category_filter to Pass 0a for selective decomposition
- Add regulation_filter to control_generator for RAG scan phase filtering
  (prefix match on regulation_code — enables CE + Code Review focus)
- New API endpoints: batch-submit-0a, batch-submit-0b, batch-status, batch-process
- 83 new tests (all passing)

Cost reduction: $2,525 → ~$600-700 with all optimizations combined.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-03-17 13:22:01 +01:00
parent 825e070ed9
commit d22c47c9eb
6 changed files with 1525 additions and 163 deletions

View File

@@ -947,3 +947,120 @@ class TestBatchProcessingLoop:
assert len(result) == 1
assert result[0].release_state == "too_close"
assert result[0].generation_metadata["similarity_status"] == "FAIL"
# =============================================================================
# Regulation Filter Tests
# =============================================================================
class TestRegulationFilter:
"""Tests for regulation_filter in GeneratorConfig."""
def test_config_accepts_regulation_filter(self):
config = GeneratorConfig(regulation_filter=["owasp_", "nist_", "eu_2023_1230"])
assert config.regulation_filter == ["owasp_", "nist_", "eu_2023_1230"]
def test_config_default_none(self):
config = GeneratorConfig()
assert config.regulation_filter is None
@pytest.mark.asyncio
async def test_scan_rag_filters_by_regulation(self):
"""Verify _scan_rag skips chunks not matching regulation_filter."""
mock_db = MagicMock()
mock_db.execute.return_value.fetchall.return_value = []
mock_db.execute.return_value = MagicMock()
mock_db.execute.return_value.__iter__ = MagicMock(return_value=iter([]))
# Mock Qdrant scroll response with mixed regulation_codes
qdrant_points = {
"result": {
"points": [
{"id": "1", "payload": {
"chunk_text": "OWASP ASVS requirement for input validation " * 5,
"regulation_code": "owasp_asvs",
"regulation_name": "OWASP ASVS",
}},
{"id": "2", "payload": {
"chunk_text": "AML anti-money laundering requirement for banks " * 5,
"regulation_code": "amlr",
"regulation_name": "AML-Verordnung",
}},
{"id": "3", "payload": {
"chunk_text": "NIST secure software development framework req " * 5,
"regulation_code": "nist_sp_800_218",
"regulation_name": "NIST SSDF",
}},
],
"next_page_offset": None,
}
}
with patch("compliance.services.control_generator.httpx.AsyncClient") as mock_client_cls:
mock_client = AsyncMock()
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = qdrant_points
mock_client.post.return_value = mock_resp
pipeline = ControlGeneratorPipeline(db=mock_db, rag_client=MagicMock())
# With filter: only owasp_ and nist_ prefixes
config = GeneratorConfig(
collections=["bp_compliance_ce"],
regulation_filter=["owasp_", "nist_"],
)
results = await pipeline._scan_rag(config)
# Should only get 2 chunks (owasp + nist), not amlr
assert len(results) == 2
codes = {r.regulation_code for r in results}
assert "owasp_asvs" in codes
assert "nist_sp_800_218" in codes
assert "amlr" not in codes
@pytest.mark.asyncio
async def test_scan_rag_no_filter_returns_all(self):
"""Verify _scan_rag returns all chunks when no regulation_filter."""
mock_db = MagicMock()
mock_db.execute.return_value.fetchall.return_value = []
mock_db.execute.return_value = MagicMock()
mock_db.execute.return_value.__iter__ = MagicMock(return_value=iter([]))
qdrant_points = {
"result": {
"points": [
{"id": "1", "payload": {
"chunk_text": "OWASP requirement for secure authentication " * 5,
"regulation_code": "owasp_asvs",
}},
{"id": "2", "payload": {
"chunk_text": "AML compliance requirement for financial inst " * 5,
"regulation_code": "amlr",
}},
],
"next_page_offset": None,
}
}
with patch("compliance.services.control_generator.httpx.AsyncClient") as mock_client_cls:
mock_client = AsyncMock()
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = qdrant_points
mock_client.post.return_value = mock_resp
pipeline = ControlGeneratorPipeline(db=mock_db, rag_client=MagicMock())
config = GeneratorConfig(
collections=["bp_compliance_ce"],
regulation_filter=None,
)
results = await pipeline._scan_rag(config)
assert len(results) == 2

View File

@@ -37,8 +37,11 @@ from compliance.services.decomposition_pass import (
_compute_extraction_confidence,
_normalize_severity,
_template_fallback,
_fallback_obligation,
_build_pass0a_prompt,
_build_pass0b_prompt,
_build_pass0a_batch_prompt,
_build_pass0b_batch_prompt,
_PASS0A_SYSTEM_PROMPT,
_PASS0B_SYSTEM_PROMPT,
DecompositionPass,
@@ -814,3 +817,342 @@ class TestMigration061:
assert "decomposition_method" in content
assert "candidate_id" in content
assert "quality_flags" in content
# ---------------------------------------------------------------------------
# BATCH PROMPT TESTS
# ---------------------------------------------------------------------------
class TestBatchPromptBuilders:
"""Tests for batch prompt builders."""
def test_pass0a_batch_prompt_contains_all_controls(self):
controls = [
{
"control_id": "AUTH-001",
"title": "MFA Control",
"objective": "Implement MFA",
"requirements": "- TOTP required",
"test_procedure": "- Test login",
"source_ref": "DSGVO Art. 32",
},
{
"control_id": "AUTH-002",
"title": "Password Policy",
"objective": "Enforce strong passwords",
"requirements": "- Min 12 chars",
"test_procedure": "- Test weak password",
"source_ref": "BSI IT-Grundschutz",
},
]
prompt = _build_pass0a_batch_prompt(controls)
assert "AUTH-001" in prompt
assert "AUTH-002" in prompt
assert "MFA Control" in prompt
assert "Password Policy" in prompt
assert "CONTROL 1" in prompt
assert "CONTROL 2" in prompt
assert "2 Controls" in prompt
def test_pass0a_batch_prompt_single_control(self):
controls = [
{
"control_id": "AUTH-001",
"title": "MFA",
"objective": "MFA",
"requirements": "",
"test_procedure": "",
"source_ref": "",
},
]
prompt = _build_pass0a_batch_prompt(controls)
assert "AUTH-001" in prompt
assert "1 Controls" in prompt
def test_pass0b_batch_prompt_contains_all_obligations(self):
obligations = [
{
"candidate_id": "OC-AUTH-001-01",
"obligation_text": "MFA implementieren",
"action": "implementieren",
"object": "MFA",
"parent_title": "Auth Controls",
"parent_category": "authentication",
"source_ref": "DSGVO Art. 32",
},
{
"candidate_id": "OC-AUTH-001-02",
"obligation_text": "MFA testen",
"action": "testen",
"object": "MFA",
"parent_title": "Auth Controls",
"parent_category": "authentication",
"source_ref": "DSGVO Art. 32",
},
]
prompt = _build_pass0b_batch_prompt(obligations)
assert "OC-AUTH-001-01" in prompt
assert "OC-AUTH-001-02" in prompt
assert "PFLICHT 1" in prompt
assert "PFLICHT 2" in prompt
assert "2 Pflichten" in prompt
class TestFallbackObligation:
"""Tests for _fallback_obligation helper."""
def test_uses_objective_when_available(self):
ctrl = {"title": "MFA", "objective": "Implement MFA for all users"}
result = _fallback_obligation(ctrl)
assert result["obligation_text"] == "Implement MFA for all users"
assert result["action"] == "sicherstellen"
def test_uses_title_when_no_objective(self):
ctrl = {"title": "MFA Control", "objective": ""}
result = _fallback_obligation(ctrl)
assert result["obligation_text"] == "MFA Control"
# ---------------------------------------------------------------------------
# ANTHROPIC BATCHING INTEGRATION TESTS
# ---------------------------------------------------------------------------
class TestDecompositionPassAnthropicBatch:
"""Tests for batched Anthropic API calls in Pass 0a/0b."""
@pytest.mark.asyncio
async def test_pass0a_anthropic_batched(self):
"""Test Pass 0a with Anthropic API and batch_size=2."""
mock_db = MagicMock()
mock_rows = MagicMock()
mock_rows.fetchall.return_value = [
("uuid-1", "CTRL-001", "MFA Control", "Implement MFA",
"", "", "", "security"),
("uuid-2", "CTRL-002", "Encryption", "Encrypt data at rest",
"", "", "", "security"),
]
mock_db.execute.return_value = mock_rows
# Anthropic returns JSON object keyed by control_id
batched_response = json.dumps({
"CTRL-001": [
{"obligation_text": "MFA muss implementiert werden",
"action": "implementieren", "object": "MFA",
"normative_strength": "must",
"is_test_obligation": False, "is_reporting_obligation": False},
],
"CTRL-002": [
{"obligation_text": "Daten müssen verschlüsselt werden",
"action": "verschlüsseln", "object": "Daten",
"normative_strength": "must",
"is_test_obligation": False, "is_reporting_obligation": False},
],
})
with patch(
"compliance.services.decomposition_pass._llm_anthropic",
new_callable=AsyncMock,
) as mock_llm:
mock_llm.return_value = batched_response
decomp = DecompositionPass(db=mock_db)
stats = await decomp.run_pass0a(
limit=10, batch_size=2, use_anthropic=True,
)
assert stats["controls_processed"] == 2
assert stats["obligations_extracted"] == 2
assert stats["llm_calls"] == 1 # Only 1 API call for 2 controls
assert stats["provider"] == "anthropic"
@pytest.mark.asyncio
async def test_pass0a_anthropic_single(self):
"""Test Pass 0a with Anthropic API, batch_size=1 (no batching)."""
mock_db = MagicMock()
mock_rows = MagicMock()
mock_rows.fetchall.return_value = [
("uuid-1", "CTRL-001", "MFA Control", "Implement MFA",
"", "", "", "security"),
]
mock_db.execute.return_value = mock_rows
response = json.dumps([
{"obligation_text": "MFA muss implementiert werden",
"action": "implementieren", "object": "MFA",
"normative_strength": "must",
"is_test_obligation": False, "is_reporting_obligation": False},
])
with patch(
"compliance.services.decomposition_pass._llm_anthropic",
new_callable=AsyncMock,
) as mock_llm:
mock_llm.return_value = response
decomp = DecompositionPass(db=mock_db)
stats = await decomp.run_pass0a(
limit=10, batch_size=1, use_anthropic=True,
)
assert stats["controls_processed"] == 1
assert stats["llm_calls"] == 1
assert stats["provider"] == "anthropic"
@pytest.mark.asyncio
async def test_pass0b_anthropic_batched(self):
"""Test Pass 0b with Anthropic API and batch_size=2."""
mock_db = MagicMock()
mock_rows = MagicMock()
mock_rows.fetchall.return_value = [
("oc-uuid-1", "OC-CTRL-001-01", "parent-uuid-1",
"MFA implementieren", "implementieren", "MFA",
False, False, "Auth", "security",
'{"source": "DSGVO", "article": "Art. 32"}',
"high", "CTRL-001"),
("oc-uuid-2", "OC-CTRL-001-02", "parent-uuid-1",
"MFA testen", "testen", "MFA",
True, False, "Auth", "security",
'{"source": "DSGVO", "article": "Art. 32"}',
"high", "CTRL-001"),
]
mock_seq = MagicMock()
mock_seq.fetchone.return_value = (0,)
call_count = [0]
def side_effect(*args, **kwargs):
call_count[0] += 1
if call_count[0] == 1:
return mock_rows # SELECT candidates
# _next_atomic_seq calls (every 3rd after first: 2, 5, 8, ...)
if call_count[0] in (2, 5):
return mock_seq
return MagicMock() # INSERT/UPDATE
mock_db.execute.side_effect = side_effect
batched_response = json.dumps({
"OC-CTRL-001-01": {
"title": "MFA implementieren",
"objective": "MFA fuer alle Konten.",
"requirements": ["TOTP einrichten"],
"test_procedure": ["Login testen"],
"evidence": ["Konfigurationsnachweis"],
"severity": "high",
"category": "security",
},
"OC-CTRL-001-02": {
"title": "MFA-Wirksamkeit testen",
"objective": "Regelmaessige MFA-Tests.",
"requirements": ["Testplan erstellen"],
"test_procedure": ["Testdurchfuehrung"],
"evidence": ["Testprotokoll"],
"severity": "high",
"category": "security",
},
})
with patch(
"compliance.services.decomposition_pass._llm_anthropic",
new_callable=AsyncMock,
) as mock_llm:
mock_llm.return_value = batched_response
decomp = DecompositionPass(db=mock_db)
stats = await decomp.run_pass0b(
limit=10, batch_size=2, use_anthropic=True,
)
assert stats["controls_created"] == 2
assert stats["llm_calls"] == 1
assert stats["provider"] == "anthropic"
# ---------------------------------------------------------------------------
# SOURCE FILTER TESTS
# ---------------------------------------------------------------------------
class TestSourceFilter:
"""Tests for source_filter parameter in Pass 0a."""
@pytest.mark.asyncio
async def test_pass0a_source_filter_builds_ilike_query(self):
"""Verify source_filter adds ILIKE clauses to query."""
mock_db = MagicMock()
mock_rows = MagicMock()
mock_rows.fetchall.return_value = [
("uuid-1", "CTRL-001", "Machine Safety", "Ensure safety",
"", "", '{"source": "Maschinenverordnung (EU) 2023/1230"}', "security"),
]
mock_db.execute.return_value = mock_rows
response = json.dumps([
{"obligation_text": "Sicherheit gewaehrleisten",
"action": "gewaehrleisten", "object": "Sicherheit",
"normative_strength": "must",
"is_test_obligation": False, "is_reporting_obligation": False},
])
with patch(
"compliance.services.decomposition_pass._llm_anthropic",
new_callable=AsyncMock,
) as mock_llm:
mock_llm.return_value = response
decomp = DecompositionPass(db=mock_db)
stats = await decomp.run_pass0a(
limit=10, batch_size=1, use_anthropic=True,
source_filter="Maschinenverordnung,Cyber Resilience Act",
)
assert stats["controls_processed"] == 1
# Verify the SQL query contained ILIKE clauses
call_args = mock_db.execute.call_args_list[0]
query_str = str(call_args[0][0])
assert "ILIKE" in query_str
@pytest.mark.asyncio
async def test_pass0a_source_filter_none_no_clause(self):
"""Verify no ILIKE clause when source_filter is None."""
mock_db = MagicMock()
mock_rows = MagicMock()
mock_rows.fetchall.return_value = []
mock_db.execute.return_value = mock_rows
decomp = DecompositionPass(db=mock_db)
stats = await decomp.run_pass0a(
limit=10, use_anthropic=True, source_filter=None,
)
call_args = mock_db.execute.call_args_list[0]
query_str = str(call_args[0][0])
assert "ILIKE" not in query_str
@pytest.mark.asyncio
async def test_pass0a_combined_category_and_source_filter(self):
"""Verify both category_filter and source_filter can be used together."""
mock_db = MagicMock()
mock_rows = MagicMock()
mock_rows.fetchall.return_value = []
mock_db.execute.return_value = mock_rows
decomp = DecompositionPass(db=mock_db)
await decomp.run_pass0a(
limit=10, use_anthropic=True,
category_filter="security,operations",
source_filter="Maschinenverordnung",
)
call_args = mock_db.execute.call_args_list[0]
query_str = str(call_args[0][0])
assert "IN :cats" in query_str
assert "ILIKE" in query_str