feat(control-library): document-grouped batching, generation strategy tracking, sort by source
All checks were successful
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Successful in 31s
CI/CD / test-python-backend-compliance (push) Successful in 31s
CI/CD / test-python-document-crawler (push) Successful in 21s
CI/CD / test-python-dsms-gateway (push) Successful in 18s
CI/CD / validate-canonical-controls (push) Successful in 11s
CI/CD / Deploy (push) Successful in 2s

- Group chunks by regulation_code before batching for better LLM context
- Add generation_strategy column (ungrouped=v1, document_grouped=v2)
- Add v1/v2 badge to control cards in frontend
- Add sort-by-source option with visual group headers
- Add frontend page tests (18 tests)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-03-15 15:10:52 +01:00
parent 0d95c3bb44
commit c8fd9cc780
9 changed files with 1000 additions and 137 deletions

View File

@@ -80,6 +80,7 @@ class ControlResponse(BaseModel):
category: Optional[str] = None
target_audience: Optional[str] = None
generation_metadata: Optional[dict] = None
generation_strategy: Optional[str] = "ungrouped"
created_at: str
updated_at: str
@@ -161,7 +162,7 @@ _CONTROL_COLS = """id, framework_id, control_id, title, objective, rationale,
evidence_confidence, open_anchors, release_state, tags,
license_rule, source_original_text, source_citation,
customer_visible, verification_method, category,
target_audience, generation_metadata,
target_audience, generation_metadata, generation_strategy,
created_at, updated_at"""
@@ -297,8 +298,14 @@ async def list_controls(
verification_method: Optional[str] = Query(None),
category: Optional[str] = Query(None),
target_audience: Optional[str] = Query(None),
source: Optional[str] = Query(None, description="Filter by source_citation->source"),
search: Optional[str] = Query(None, description="Full-text search in control_id, title, objective"),
sort: Optional[str] = Query("control_id", description="Sort field: control_id, created_at, severity"),
order: Optional[str] = Query("asc", description="Sort order: asc or desc"),
limit: Optional[int] = Query(None, ge=1, le=5000, description="Max results"),
offset: Optional[int] = Query(None, ge=0, description="Offset for pagination"),
):
"""List all canonical controls, with optional filters."""
"""List canonical controls with filters, search, sorting and pagination."""
query = f"""
SELECT {_CONTROL_COLS}
FROM canonical_controls
@@ -324,8 +331,35 @@ async def list_controls(
if target_audience:
query += " AND target_audience = :ta"
params["ta"] = target_audience
if source:
if source == "__none__":
query += " AND (source_citation IS NULL OR source_citation->>'source' IS NULL OR source_citation->>'source' = '')"
else:
query += " AND source_citation->>'source' = :src"
params["src"] = source
if search:
query += " AND (control_id ILIKE :q OR title ILIKE :q OR objective ILIKE :q)"
params["q"] = f"%{search}%"
query += " ORDER BY control_id"
# Sorting
sort_col = "control_id"
if sort in ("created_at", "updated_at", "severity", "control_id"):
sort_col = sort
elif sort == "source":
sort_col = "source_citation->>'source'"
sort_dir = "DESC" if order and order.lower() == "desc" else "ASC"
if sort == "source":
# Group by source first, then by control_id within each source
query += f" ORDER BY {sort_col} {sort_dir} NULLS LAST, control_id ASC"
else:
query += f" ORDER BY {sort_col} {sort_dir}"
if limit is not None:
query += " LIMIT :lim"
params["lim"] = limit
if offset is not None:
query += " OFFSET :off"
params["off"] = offset
with SessionLocal() as db:
rows = db.execute(text(query), params).fetchall()
@@ -333,6 +367,87 @@ async def list_controls(
return [_control_row(r) for r in rows]
@router.get("/controls-count")
async def count_controls(
severity: Optional[str] = Query(None),
domain: Optional[str] = Query(None),
release_state: Optional[str] = Query(None),
verification_method: Optional[str] = Query(None),
category: Optional[str] = Query(None),
target_audience: Optional[str] = Query(None),
source: Optional[str] = Query(None),
search: Optional[str] = Query(None),
):
"""Count controls matching filters (for pagination)."""
query = "SELECT count(*) FROM canonical_controls WHERE 1=1"
params: dict[str, Any] = {}
if severity:
query += " AND severity = :sev"
params["sev"] = severity
if domain:
query += " AND LEFT(control_id, LENGTH(:dom)) = :dom"
params["dom"] = domain.upper()
if release_state:
query += " AND release_state = :rs"
params["rs"] = release_state
if verification_method:
query += " AND verification_method = :vm"
params["vm"] = verification_method
if category:
query += " AND category = :cat"
params["cat"] = category
if target_audience:
query += " AND target_audience = :ta"
params["ta"] = target_audience
if source:
if source == "__none__":
query += " AND (source_citation IS NULL OR source_citation->>'source' IS NULL OR source_citation->>'source' = '')"
else:
query += " AND source_citation->>'source' = :src"
params["src"] = source
if search:
query += " AND (control_id ILIKE :q OR title ILIKE :q OR objective ILIKE :q)"
params["q"] = f"%{search}%"
with SessionLocal() as db:
total = db.execute(text(query), params).scalar()
return {"total": total}
@router.get("/controls-meta")
async def controls_meta():
"""Return aggregated metadata for filter dropdowns (domains, sources, counts)."""
with SessionLocal() as db:
total = db.execute(text("SELECT count(*) FROM canonical_controls")).scalar()
domains = db.execute(text("""
SELECT UPPER(SPLIT_PART(control_id, '-', 1)) as domain, count(*) as cnt
FROM canonical_controls
GROUP BY domain ORDER BY domain
""")).fetchall()
sources = db.execute(text("""
SELECT source_citation->>'source' as src, count(*) as cnt
FROM canonical_controls
WHERE source_citation->>'source' IS NOT NULL AND source_citation->>'source' != ''
GROUP BY src ORDER BY cnt DESC
""")).fetchall()
no_source = db.execute(text("""
SELECT count(*) FROM canonical_controls
WHERE source_citation IS NULL OR source_citation->>'source' IS NULL OR source_citation->>'source' = ''
""")).scalar()
return {
"total": total,
"domains": [{"domain": r[0], "count": r[1]} for r in domains],
"sources": [{"source": r[0], "count": r[1]} for r in sources],
"no_source_count": no_source,
}
@router.get("/controls/{control_id}")
async def get_control(control_id: str):
"""Get a single canonical control by its control_id (e.g. AUTH-001)."""
@@ -661,6 +776,7 @@ def _control_row(r) -> dict:
"category": r.category,
"target_audience": r.target_audience,
"generation_metadata": r.generation_metadata,
"generation_strategy": getattr(r, "generation_strategy", "ungrouped"),
"created_at": r.created_at.isoformat() if r.created_at else None,
"updated_at": r.updated_at.isoformat() if r.updated_at else None,
}

View File

@@ -23,6 +23,7 @@ import logging
import os
import re
import uuid
from collections import defaultdict
from dataclasses import dataclass, field, asdict
from datetime import datetime, timezone
from typing import Dict, List, Optional, Set
@@ -368,6 +369,7 @@ class GeneratedControl:
source_citation: Optional[dict] = None
customer_visible: bool = True
generation_metadata: dict = field(default_factory=dict)
generation_strategy: str = "ungrouped" # ungrouped | document_grouped
# Classification fields
verification_method: Optional[str] = None # code_review, document, tool, hybrid
category: Optional[str] = None # one of 17 categories
@@ -940,6 +942,24 @@ Gib JSON zurück mit diesen Feldern:
license_infos: list[dict],
) -> list[Optional[GeneratedControl]]:
"""Structure multiple free-use/citation chunks in a single Anthropic call."""
# Build document context header if chunks share a regulation
regulations_in_batch = set(c.regulation_name for c in chunks)
doc_context = ""
if len(regulations_in_batch) == 1:
reg_name = next(iter(regulations_in_batch))
articles = sorted(set(c.article or "?" for c in chunks))
doc_context = (
f"\nDOKUMENTKONTEXT: Alle {len(chunks)} Chunks stammen aus demselben Gesetz: {reg_name}.\n"
f"Betroffene Artikel/Abschnitte: {', '.join(articles)}.\n"
f"Nutze diesen Zusammenhang fuer eine kohaerente, aufeinander abgestimmte Formulierung der Controls.\n"
f"Vermeide Redundanzen zwischen den Controls — jedes soll einen eigenen Aspekt abdecken.\n"
)
elif len(regulations_in_batch) <= 3:
doc_context = (
f"\nDOKUMENTKONTEXT: Die Chunks stammen aus {len(regulations_in_batch)} Gesetzen: "
f"{', '.join(regulations_in_batch)}.\n"
)
chunk_entries = []
for idx, (chunk, lic) in enumerate(zip(chunks, license_infos)):
source_name = lic.get("name", chunk.regulation_name)
@@ -952,20 +972,21 @@ Gib JSON zurück mit diesen Feldern:
joined = "\n\n".join(chunk_entries)
prompt = f"""Strukturiere die folgenden {len(chunks)} Gesetzestexte jeweils als eigenstaendiges Security/Compliance Control.
Du DARFST den Originaltext verwenden (Quellen sind jeweils angegeben).
{doc_context}
WICHTIG:
- Erstelle fuer JEDEN Chunk ein separates Control mit verstaendlicher, praxisorientierter Formulierung.
- Jedes Control muss eigenstaendig und vollstaendig sein — nicht auf andere Controls verweisen.
- Qualitaet ist wichtiger als Geschwindigkeit. Jedes Control muss die gleiche Qualitaet haben wie ein einzeln erstelltes.
- Antworte IMMER auf Deutsch.
Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Objekten. Jedes Objekt hat diese Felder:
- chunk_index: 1-basierter Index des Chunks (1, 2, 3, ...)
- title: Kurzer praegnanter Titel (max 100 Zeichen)
- objective: Was soll erreicht werden? (1-3 Saetze)
- rationale: Warum ist das wichtig? (1-2 Saetze)
- requirements: Liste von konkreten Anforderungen (Strings)
- test_procedure: Liste von Pruefschritten (Strings)
- evidence: Liste von Nachweisdokumenten (Strings)
- title: Kurzer praegnanter Titel auf Deutsch (max 100 Zeichen)
- objective: Was soll erreicht werden? (1-3 Saetze, Deutsch)
- rationale: Warum ist das wichtig? (1-2 Saetze, Deutsch)
- requirements: Liste von konkreten Anforderungen (Strings, Deutsch)
- test_procedure: Liste von Pruefschritten (Strings, Deutsch)
- evidence: Liste von Nachweisdokumenten (Strings, Deutsch)
- severity: low/medium/high/critical
- tags: Liste von Tags
@@ -1003,13 +1024,16 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Objekten. Jedes Objekt hat di
control.customer_visible = True
control.verification_method = _detect_verification_method(chunk.text)
control.category = _detect_category(chunk.text)
same_doc = len(set(c.regulation_code for c in chunks)) == 1
control.generation_metadata = {
"processing_path": "structured_batch",
"license_rule": lic["rule"],
"source_regulation": chunk.regulation_code,
"source_article": chunk.article,
"batch_size": len(chunks),
"document_grouped": same_doc,
}
control.generation_strategy = "document_grouped" if same_doc else "ungrouped"
controls[idx] = control
return controls
@@ -1369,7 +1393,7 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Objekten. Jedes Objekt hat di
open_anchors, release_state, tags,
license_rule, source_original_text, source_citation,
customer_visible, generation_metadata,
verification_method, category
verification_method, category, generation_strategy
) VALUES (
:framework_id, :control_id, :title, :objective, :rationale,
:scope, :requirements, :test_procedure, :evidence,
@@ -1377,7 +1401,7 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Objekten. Jedes Objekt hat di
:open_anchors, :release_state, :tags,
:license_rule, :source_original_text, :source_citation,
:customer_visible, :generation_metadata,
:verification_method, :category
:verification_method, :category, :generation_strategy
)
ON CONFLICT (framework_id, control_id) DO NOTHING
RETURNING id
@@ -1405,6 +1429,7 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Objekten. Jedes Objekt hat di
"generation_metadata": json.dumps(control.generation_metadata) if control.generation_metadata else None,
"verification_method": control.verification_method,
"category": control.category,
"generation_strategy": control.generation_strategy,
},
)
self.db.commit()
@@ -1479,21 +1504,48 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Objekten. Jedes Objekt hat di
self._update_job(job_id, result)
return result
# ── Group chunks by document (regulation_code) for coherent batching ──
doc_groups: dict[str, list[RAGSearchResult]] = defaultdict(list)
for chunk in chunks:
group_key = chunk.regulation_code or "unknown"
doc_groups[group_key].append(chunk)
# Sort chunks within each group by article for sequential context
for key in doc_groups:
doc_groups[key].sort(key=lambda c: (c.article or "", c.paragraph or ""))
logger.info(
"Grouped %d chunks into %d document groups for coherent batching",
len(chunks), len(doc_groups),
)
# Flatten back: chunks from same document are now adjacent
chunks = []
for group_list in doc_groups.values():
chunks.extend(group_list)
# Process chunks — batch mode (N chunks per Anthropic API call)
BATCH_SIZE = config.batch_size or 5
controls_count = 0
chunks_skipped_prefilter = 0
pending_batch: list[tuple[RAGSearchResult, dict]] = [] # (chunk, license_info)
current_batch_regulation: Optional[str] = None # Track regulation for group-aware flushing
async def _flush_batch():
"""Send pending batch to Anthropic and process results."""
nonlocal controls_count
nonlocal controls_count, current_batch_regulation
if not pending_batch:
return
batch = pending_batch.copy()
pending_batch.clear()
current_batch_regulation = None
logger.info("Processing batch of %d chunks via single API call...", len(batch))
# Log which document this batch belongs to
regs_in_batch = set(c.regulation_code for c, _ in batch)
logger.info(
"Processing batch of %d chunks (docs: %s) via single API call...",
len(batch), ", ".join(regs_in_batch),
)
try:
batch_controls = await self._process_batch(batch, config, job_id)
except Exception as e:
@@ -1514,6 +1566,9 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Objekten. Jedes Objekt hat di
self._mark_chunk_processed(chunk, lic_info, "no_control", [], job_id)
continue
# Mark as document_grouped strategy
control.generation_strategy = "document_grouped"
# Count by state
if control.release_state == "too_close":
result.controls_too_close += 1
@@ -1567,12 +1622,18 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Objekten. Jedes Objekt hat di
# Classify license and add to batch
license_info = self._classify_license(chunk)
pending_batch.append((chunk, license_info))
chunk_regulation = chunk.regulation_code or "unknown"
# Flush when batch is full
if len(pending_batch) >= BATCH_SIZE:
# Flush when: batch is full OR regulation changes (group boundary)
if pending_batch and (
len(pending_batch) >= BATCH_SIZE
or chunk_regulation != current_batch_regulation
):
await _flush_batch()
pending_batch.append((chunk, license_info))
current_batch_regulation = chunk_regulation
except Exception as e:
error_msg = f"Error processing chunk {chunk.regulation_code}/{chunk.article}: {e}"
logger.error(error_msg)

View File

@@ -0,0 +1,23 @@
-- 057: Add batch processing paths to canonical_processed_chunks
-- New values: structured_batch, llm_reform_batch (used by batch control generation)
DO $$
BEGIN
IF EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'canonical_processed_chunks') THEN
ALTER TABLE canonical_processed_chunks
DROP CONSTRAINT IF EXISTS canonical_processed_chunks_processing_path_check;
ALTER TABLE canonical_processed_chunks
ADD CONSTRAINT canonical_processed_chunks_processing_path_check
CHECK (processing_path IN (
'structured',
'llm_reform',
'skipped',
'prefilter_skip',
'no_control',
'store_failed',
'error',
'structured_batch',
'llm_reform_batch'
));
END IF;
END $$;

View File

@@ -0,0 +1,8 @@
-- Migration 058: Add generation_strategy column to canonical_controls
-- Tracks whether a control was generated with document-grouped or ungrouped batching
ALTER TABLE canonical_controls
ADD COLUMN IF NOT EXISTS generation_strategy TEXT NOT NULL DEFAULT 'ungrouped';
COMMENT ON COLUMN canonical_controls.generation_strategy IS
'How chunks were batched during generation: ungrouped (random), document_grouped (by regulation+article)';

View File

@@ -1,17 +1,36 @@
"""Tests for Canonical Control Library routes (canonical_control_routes.py)."""
"""Tests for Canonical Control Library routes (canonical_control_routes.py).
Includes:
- Model validation tests (FrameworkResponse, ControlResponse, etc.)
- _control_row conversion tests
- Server-side pagination, sorting, search, source filter tests
- /controls-count and /controls-meta endpoint tests
"""
import pytest
from unittest.mock import MagicMock, patch
from datetime import datetime, timezone
from fastapi import FastAPI
from fastapi.testclient import TestClient
from compliance.api.canonical_control_routes import (
FrameworkResponse,
ControlResponse,
SimilarityCheckRequest,
SimilarityCheckResponse,
_control_row,
router,
)
# ---------------------------------------------------------------------------
# TestClient setup for endpoint tests
# ---------------------------------------------------------------------------
_app = FastAPI()
_app.include_router(router, prefix="/api/compliance")
_client = TestClient(_app)
class TestFrameworkResponse:
"""Tests for FrameworkResponse model."""
@@ -175,6 +194,7 @@ class TestControlRowConversion:
],
"release_state": "draft",
"tags": ["mfa"],
"generation_strategy": "ungrouped",
"created_at": now,
"updated_at": now,
}
@@ -223,3 +243,213 @@ class TestControlRowConversion:
result = _control_row(row)
assert result["created_at"] is None
assert result["updated_at"] is None
def test_generation_strategy_default(self):
row = self._make_row()
result = _control_row(row)
assert result["generation_strategy"] == "ungrouped"
def test_generation_strategy_document_grouped(self):
row = self._make_row(generation_strategy="document_grouped")
result = _control_row(row)
assert result["generation_strategy"] == "document_grouped"
# =============================================================================
# ENDPOINT TESTS — Server-Side Pagination, Sort, Search, Source Filter
# =============================================================================
def _make_mock_row(**overrides):
"""Build a mock Row with all canonical_controls columns."""
now = datetime.now(timezone.utc)
defaults = {
"id": "uuid-ctrl-1",
"framework_id": "uuid-fw-1",
"control_id": "AUTH-001",
"title": "Test Control",
"objective": "Test obj",
"rationale": "Test rat",
"scope": {},
"requirements": ["Req 1"],
"test_procedure": ["Test 1"],
"evidence": [],
"severity": "high",
"risk_score": 3.0,
"implementation_effort": "m",
"evidence_confidence": None,
"open_anchors": [],
"release_state": "draft",
"tags": [],
"license_rule": 1,
"source_original_text": None,
"source_citation": None,
"customer_visible": True,
"verification_method": "automated",
"category": "authentication",
"target_audience": "developer",
"generation_metadata": {},
"generation_strategy": "ungrouped",
"created_at": now,
"updated_at": now,
}
defaults.update(overrides)
mock = MagicMock()
for k, v in defaults.items():
setattr(mock, k, v)
return mock
def _session_returning(rows=None, scalar=None):
"""Create a mock SessionLocal that returns rows or scalar."""
db = MagicMock()
result = MagicMock()
if rows is not None:
result.fetchall.return_value = rows
if scalar is not None:
result.scalar.return_value = scalar
db.execute.return_value = result
db.__enter__ = MagicMock(return_value=db)
db.__exit__ = MagicMock(return_value=False)
return db
class TestListControlsPagination:
"""GET /controls with limit/offset."""
@patch("compliance.api.canonical_control_routes.SessionLocal")
def test_limit_param_in_sql(self, mock_cls):
mock_cls.return_value = _session_returning(rows=[_make_mock_row()])
resp = _client.get("/api/compliance/v1/canonical/controls?limit=10&offset=20")
assert resp.status_code == 200
sql = str(mock_cls.return_value.__enter__().execute.call_args[0][0].text)
assert "LIMIT" in sql
assert "OFFSET" in sql
@patch("compliance.api.canonical_control_routes.SessionLocal")
def test_no_limit_by_default(self, mock_cls):
mock_cls.return_value = _session_returning(rows=[])
resp = _client.get("/api/compliance/v1/canonical/controls")
assert resp.status_code == 200
sql = str(mock_cls.return_value.__enter__().execute.call_args[0][0].text)
assert "LIMIT" not in sql
class TestListControlsSorting:
"""GET /controls with sort/order."""
@patch("compliance.api.canonical_control_routes.SessionLocal")
def test_sort_created_at_desc(self, mock_cls):
mock_cls.return_value = _session_returning(rows=[])
resp = _client.get("/api/compliance/v1/canonical/controls?sort=created_at&order=desc")
assert resp.status_code == 200
sql = str(mock_cls.return_value.__enter__().execute.call_args[0][0].text)
assert "created_at DESC" in sql
@patch("compliance.api.canonical_control_routes.SessionLocal")
def test_default_sort_control_id_asc(self, mock_cls):
mock_cls.return_value = _session_returning(rows=[])
resp = _client.get("/api/compliance/v1/canonical/controls")
assert resp.status_code == 200
sql = str(mock_cls.return_value.__enter__().execute.call_args[0][0].text)
assert "control_id ASC" in sql
@patch("compliance.api.canonical_control_routes.SessionLocal")
def test_sql_injection_in_sort_blocked(self, mock_cls):
mock_cls.return_value = _session_returning(rows=[])
resp = _client.get("/api/compliance/v1/canonical/controls?sort=1;DROP+TABLE")
assert resp.status_code == 200
sql = str(mock_cls.return_value.__enter__().execute.call_args[0][0].text)
assert "DROP" not in sql
assert "control_id" in sql
@patch("compliance.api.canonical_control_routes.SessionLocal")
def test_sort_by_source(self, mock_cls):
mock_cls.return_value = _session_returning(rows=[])
resp = _client.get("/api/compliance/v1/canonical/controls?sort=source&order=asc")
assert resp.status_code == 200
sql = str(mock_cls.return_value.__enter__().execute.call_args[0][0].text)
assert "source_citation" in sql
assert "control_id ASC" in sql # secondary sort within source group
class TestListControlsSearch:
"""GET /controls with search."""
@patch("compliance.api.canonical_control_routes.SessionLocal")
def test_search_uses_ilike(self, mock_cls):
mock_cls.return_value = _session_returning(rows=[])
resp = _client.get("/api/compliance/v1/canonical/controls?search=encryption")
assert resp.status_code == 200
sql = str(mock_cls.return_value.__enter__().execute.call_args[0][0].text)
assert "ILIKE" in sql
params = mock_cls.return_value.__enter__().execute.call_args[0][1]
assert params["q"] == "%encryption%"
class TestListControlsSourceFilter:
"""GET /controls with source filter."""
@patch("compliance.api.canonical_control_routes.SessionLocal")
def test_specific_source(self, mock_cls):
mock_cls.return_value = _session_returning(rows=[])
resp = _client.get("/api/compliance/v1/canonical/controls?source=DSGVO")
assert resp.status_code == 200
sql = str(mock_cls.return_value.__enter__().execute.call_args[0][0].text)
assert "source_citation" in sql
params = mock_cls.return_value.__enter__().execute.call_args[0][1]
assert params["src"] == "DSGVO"
@patch("compliance.api.canonical_control_routes.SessionLocal")
def test_no_source_filter(self, mock_cls):
mock_cls.return_value = _session_returning(rows=[])
resp = _client.get("/api/compliance/v1/canonical/controls?source=__none__")
assert resp.status_code == 200
sql = str(mock_cls.return_value.__enter__().execute.call_args[0][0].text)
assert "IS NULL" in sql
class TestControlsCount:
"""GET /controls-count."""
@patch("compliance.api.canonical_control_routes.SessionLocal")
def test_returns_total(self, mock_cls):
mock_cls.return_value = _session_returning(scalar=42)
resp = _client.get("/api/compliance/v1/canonical/controls-count")
assert resp.status_code == 200
assert resp.json() == {"total": 42}
@patch("compliance.api.canonical_control_routes.SessionLocal")
def test_with_filters(self, mock_cls):
mock_cls.return_value = _session_returning(scalar=5)
resp = _client.get("/api/compliance/v1/canonical/controls-count?severity=critical&search=mfa")
assert resp.status_code == 200
assert resp.json() == {"total": 5}
sql = str(mock_cls.return_value.__enter__().execute.call_args[0][0].text)
assert "severity" in sql
assert "ILIKE" in sql
class TestControlsMeta:
"""GET /controls-meta."""
@patch("compliance.api.canonical_control_routes.SessionLocal")
def test_returns_structure(self, mock_cls):
db = MagicMock()
db.__enter__ = MagicMock(return_value=db)
db.__exit__ = MagicMock(return_value=False)
# 4 sequential execute() calls
total_r = MagicMock(); total_r.scalar.return_value = 100
domain_r = MagicMock(); domain_r.fetchall.return_value = []
source_r = MagicMock(); source_r.fetchall.return_value = []
nosrc_r = MagicMock(); nosrc_r.scalar.return_value = 20
db.execute.side_effect = [total_r, domain_r, source_r, nosrc_r]
mock_cls.return_value = db
resp = _client.get("/api/compliance/v1/canonical/controls-meta")
assert resp.status_code == 200
data = resp.json()
assert data["total"] == 100
assert data["no_source_count"] == 20
assert isinstance(data["domains"], list)
assert isinstance(data["sources"], list)