feat(control-library): document-grouped batching, generation strategy tracking, sort by source
All checks were successful
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Successful in 31s
CI/CD / test-python-backend-compliance (push) Successful in 31s
CI/CD / test-python-document-crawler (push) Successful in 21s
CI/CD / test-python-dsms-gateway (push) Successful in 18s
CI/CD / validate-canonical-controls (push) Successful in 11s
CI/CD / Deploy (push) Successful in 2s
All checks were successful
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Successful in 31s
CI/CD / test-python-backend-compliance (push) Successful in 31s
CI/CD / test-python-document-crawler (push) Successful in 21s
CI/CD / test-python-dsms-gateway (push) Successful in 18s
CI/CD / validate-canonical-controls (push) Successful in 11s
CI/CD / Deploy (push) Successful in 2s
- Group chunks by regulation_code before batching for better LLM context - Add generation_strategy column (ungrouped=v1, document_grouped=v2) - Add v1/v2 badge to control cards in frontend - Add sort-by-source option with visual group headers - Add frontend page tests (18 tests) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -80,6 +80,7 @@ class ControlResponse(BaseModel):
|
||||
category: Optional[str] = None
|
||||
target_audience: Optional[str] = None
|
||||
generation_metadata: Optional[dict] = None
|
||||
generation_strategy: Optional[str] = "ungrouped"
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
@@ -161,7 +162,7 @@ _CONTROL_COLS = """id, framework_id, control_id, title, objective, rationale,
|
||||
evidence_confidence, open_anchors, release_state, tags,
|
||||
license_rule, source_original_text, source_citation,
|
||||
customer_visible, verification_method, category,
|
||||
target_audience, generation_metadata,
|
||||
target_audience, generation_metadata, generation_strategy,
|
||||
created_at, updated_at"""
|
||||
|
||||
|
||||
@@ -297,8 +298,14 @@ async def list_controls(
|
||||
verification_method: Optional[str] = Query(None),
|
||||
category: Optional[str] = Query(None),
|
||||
target_audience: Optional[str] = Query(None),
|
||||
source: Optional[str] = Query(None, description="Filter by source_citation->source"),
|
||||
search: Optional[str] = Query(None, description="Full-text search in control_id, title, objective"),
|
||||
sort: Optional[str] = Query("control_id", description="Sort field: control_id, created_at, severity"),
|
||||
order: Optional[str] = Query("asc", description="Sort order: asc or desc"),
|
||||
limit: Optional[int] = Query(None, ge=1, le=5000, description="Max results"),
|
||||
offset: Optional[int] = Query(None, ge=0, description="Offset for pagination"),
|
||||
):
|
||||
"""List all canonical controls, with optional filters."""
|
||||
"""List canonical controls with filters, search, sorting and pagination."""
|
||||
query = f"""
|
||||
SELECT {_CONTROL_COLS}
|
||||
FROM canonical_controls
|
||||
@@ -324,8 +331,35 @@ async def list_controls(
|
||||
if target_audience:
|
||||
query += " AND target_audience = :ta"
|
||||
params["ta"] = target_audience
|
||||
if source:
|
||||
if source == "__none__":
|
||||
query += " AND (source_citation IS NULL OR source_citation->>'source' IS NULL OR source_citation->>'source' = '')"
|
||||
else:
|
||||
query += " AND source_citation->>'source' = :src"
|
||||
params["src"] = source
|
||||
if search:
|
||||
query += " AND (control_id ILIKE :q OR title ILIKE :q OR objective ILIKE :q)"
|
||||
params["q"] = f"%{search}%"
|
||||
|
||||
query += " ORDER BY control_id"
|
||||
# Sorting
|
||||
sort_col = "control_id"
|
||||
if sort in ("created_at", "updated_at", "severity", "control_id"):
|
||||
sort_col = sort
|
||||
elif sort == "source":
|
||||
sort_col = "source_citation->>'source'"
|
||||
sort_dir = "DESC" if order and order.lower() == "desc" else "ASC"
|
||||
if sort == "source":
|
||||
# Group by source first, then by control_id within each source
|
||||
query += f" ORDER BY {sort_col} {sort_dir} NULLS LAST, control_id ASC"
|
||||
else:
|
||||
query += f" ORDER BY {sort_col} {sort_dir}"
|
||||
|
||||
if limit is not None:
|
||||
query += " LIMIT :lim"
|
||||
params["lim"] = limit
|
||||
if offset is not None:
|
||||
query += " OFFSET :off"
|
||||
params["off"] = offset
|
||||
|
||||
with SessionLocal() as db:
|
||||
rows = db.execute(text(query), params).fetchall()
|
||||
@@ -333,6 +367,87 @@ async def list_controls(
|
||||
return [_control_row(r) for r in rows]
|
||||
|
||||
|
||||
@router.get("/controls-count")
|
||||
async def count_controls(
|
||||
severity: Optional[str] = Query(None),
|
||||
domain: Optional[str] = Query(None),
|
||||
release_state: Optional[str] = Query(None),
|
||||
verification_method: Optional[str] = Query(None),
|
||||
category: Optional[str] = Query(None),
|
||||
target_audience: Optional[str] = Query(None),
|
||||
source: Optional[str] = Query(None),
|
||||
search: Optional[str] = Query(None),
|
||||
):
|
||||
"""Count controls matching filters (for pagination)."""
|
||||
query = "SELECT count(*) FROM canonical_controls WHERE 1=1"
|
||||
params: dict[str, Any] = {}
|
||||
|
||||
if severity:
|
||||
query += " AND severity = :sev"
|
||||
params["sev"] = severity
|
||||
if domain:
|
||||
query += " AND LEFT(control_id, LENGTH(:dom)) = :dom"
|
||||
params["dom"] = domain.upper()
|
||||
if release_state:
|
||||
query += " AND release_state = :rs"
|
||||
params["rs"] = release_state
|
||||
if verification_method:
|
||||
query += " AND verification_method = :vm"
|
||||
params["vm"] = verification_method
|
||||
if category:
|
||||
query += " AND category = :cat"
|
||||
params["cat"] = category
|
||||
if target_audience:
|
||||
query += " AND target_audience = :ta"
|
||||
params["ta"] = target_audience
|
||||
if source:
|
||||
if source == "__none__":
|
||||
query += " AND (source_citation IS NULL OR source_citation->>'source' IS NULL OR source_citation->>'source' = '')"
|
||||
else:
|
||||
query += " AND source_citation->>'source' = :src"
|
||||
params["src"] = source
|
||||
if search:
|
||||
query += " AND (control_id ILIKE :q OR title ILIKE :q OR objective ILIKE :q)"
|
||||
params["q"] = f"%{search}%"
|
||||
|
||||
with SessionLocal() as db:
|
||||
total = db.execute(text(query), params).scalar()
|
||||
|
||||
return {"total": total}
|
||||
|
||||
|
||||
@router.get("/controls-meta")
|
||||
async def controls_meta():
|
||||
"""Return aggregated metadata for filter dropdowns (domains, sources, counts)."""
|
||||
with SessionLocal() as db:
|
||||
total = db.execute(text("SELECT count(*) FROM canonical_controls")).scalar()
|
||||
|
||||
domains = db.execute(text("""
|
||||
SELECT UPPER(SPLIT_PART(control_id, '-', 1)) as domain, count(*) as cnt
|
||||
FROM canonical_controls
|
||||
GROUP BY domain ORDER BY domain
|
||||
""")).fetchall()
|
||||
|
||||
sources = db.execute(text("""
|
||||
SELECT source_citation->>'source' as src, count(*) as cnt
|
||||
FROM canonical_controls
|
||||
WHERE source_citation->>'source' IS NOT NULL AND source_citation->>'source' != ''
|
||||
GROUP BY src ORDER BY cnt DESC
|
||||
""")).fetchall()
|
||||
|
||||
no_source = db.execute(text("""
|
||||
SELECT count(*) FROM canonical_controls
|
||||
WHERE source_citation IS NULL OR source_citation->>'source' IS NULL OR source_citation->>'source' = ''
|
||||
""")).scalar()
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"domains": [{"domain": r[0], "count": r[1]} for r in domains],
|
||||
"sources": [{"source": r[0], "count": r[1]} for r in sources],
|
||||
"no_source_count": no_source,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/controls/{control_id}")
|
||||
async def get_control(control_id: str):
|
||||
"""Get a single canonical control by its control_id (e.g. AUTH-001)."""
|
||||
@@ -661,6 +776,7 @@ def _control_row(r) -> dict:
|
||||
"category": r.category,
|
||||
"target_audience": r.target_audience,
|
||||
"generation_metadata": r.generation_metadata,
|
||||
"generation_strategy": getattr(r, "generation_strategy", "ungrouped"),
|
||||
"created_at": r.created_at.isoformat() if r.created_at else None,
|
||||
"updated_at": r.updated_at.isoformat() if r.updated_at else None,
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List, Optional, Set
|
||||
@@ -368,6 +369,7 @@ class GeneratedControl:
|
||||
source_citation: Optional[dict] = None
|
||||
customer_visible: bool = True
|
||||
generation_metadata: dict = field(default_factory=dict)
|
||||
generation_strategy: str = "ungrouped" # ungrouped | document_grouped
|
||||
# Classification fields
|
||||
verification_method: Optional[str] = None # code_review, document, tool, hybrid
|
||||
category: Optional[str] = None # one of 17 categories
|
||||
@@ -940,6 +942,24 @@ Gib JSON zurück mit diesen Feldern:
|
||||
license_infos: list[dict],
|
||||
) -> list[Optional[GeneratedControl]]:
|
||||
"""Structure multiple free-use/citation chunks in a single Anthropic call."""
|
||||
# Build document context header if chunks share a regulation
|
||||
regulations_in_batch = set(c.regulation_name for c in chunks)
|
||||
doc_context = ""
|
||||
if len(regulations_in_batch) == 1:
|
||||
reg_name = next(iter(regulations_in_batch))
|
||||
articles = sorted(set(c.article or "?" for c in chunks))
|
||||
doc_context = (
|
||||
f"\nDOKUMENTKONTEXT: Alle {len(chunks)} Chunks stammen aus demselben Gesetz: {reg_name}.\n"
|
||||
f"Betroffene Artikel/Abschnitte: {', '.join(articles)}.\n"
|
||||
f"Nutze diesen Zusammenhang fuer eine kohaerente, aufeinander abgestimmte Formulierung der Controls.\n"
|
||||
f"Vermeide Redundanzen zwischen den Controls — jedes soll einen eigenen Aspekt abdecken.\n"
|
||||
)
|
||||
elif len(regulations_in_batch) <= 3:
|
||||
doc_context = (
|
||||
f"\nDOKUMENTKONTEXT: Die Chunks stammen aus {len(regulations_in_batch)} Gesetzen: "
|
||||
f"{', '.join(regulations_in_batch)}.\n"
|
||||
)
|
||||
|
||||
chunk_entries = []
|
||||
for idx, (chunk, lic) in enumerate(zip(chunks, license_infos)):
|
||||
source_name = lic.get("name", chunk.regulation_name)
|
||||
@@ -952,20 +972,21 @@ Gib JSON zurück mit diesen Feldern:
|
||||
joined = "\n\n".join(chunk_entries)
|
||||
prompt = f"""Strukturiere die folgenden {len(chunks)} Gesetzestexte jeweils als eigenstaendiges Security/Compliance Control.
|
||||
Du DARFST den Originaltext verwenden (Quellen sind jeweils angegeben).
|
||||
|
||||
{doc_context}
|
||||
WICHTIG:
|
||||
- Erstelle fuer JEDEN Chunk ein separates Control mit verstaendlicher, praxisorientierter Formulierung.
|
||||
- Jedes Control muss eigenstaendig und vollstaendig sein — nicht auf andere Controls verweisen.
|
||||
- Qualitaet ist wichtiger als Geschwindigkeit. Jedes Control muss die gleiche Qualitaet haben wie ein einzeln erstelltes.
|
||||
- Antworte IMMER auf Deutsch.
|
||||
|
||||
Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Objekten. Jedes Objekt hat diese Felder:
|
||||
- chunk_index: 1-basierter Index des Chunks (1, 2, 3, ...)
|
||||
- title: Kurzer praegnanter Titel (max 100 Zeichen)
|
||||
- objective: Was soll erreicht werden? (1-3 Saetze)
|
||||
- rationale: Warum ist das wichtig? (1-2 Saetze)
|
||||
- requirements: Liste von konkreten Anforderungen (Strings)
|
||||
- test_procedure: Liste von Pruefschritten (Strings)
|
||||
- evidence: Liste von Nachweisdokumenten (Strings)
|
||||
- title: Kurzer praegnanter Titel auf Deutsch (max 100 Zeichen)
|
||||
- objective: Was soll erreicht werden? (1-3 Saetze, Deutsch)
|
||||
- rationale: Warum ist das wichtig? (1-2 Saetze, Deutsch)
|
||||
- requirements: Liste von konkreten Anforderungen (Strings, Deutsch)
|
||||
- test_procedure: Liste von Pruefschritten (Strings, Deutsch)
|
||||
- evidence: Liste von Nachweisdokumenten (Strings, Deutsch)
|
||||
- severity: low/medium/high/critical
|
||||
- tags: Liste von Tags
|
||||
|
||||
@@ -1003,13 +1024,16 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Objekten. Jedes Objekt hat di
|
||||
control.customer_visible = True
|
||||
control.verification_method = _detect_verification_method(chunk.text)
|
||||
control.category = _detect_category(chunk.text)
|
||||
same_doc = len(set(c.regulation_code for c in chunks)) == 1
|
||||
control.generation_metadata = {
|
||||
"processing_path": "structured_batch",
|
||||
"license_rule": lic["rule"],
|
||||
"source_regulation": chunk.regulation_code,
|
||||
"source_article": chunk.article,
|
||||
"batch_size": len(chunks),
|
||||
"document_grouped": same_doc,
|
||||
}
|
||||
control.generation_strategy = "document_grouped" if same_doc else "ungrouped"
|
||||
controls[idx] = control
|
||||
|
||||
return controls
|
||||
@@ -1369,7 +1393,7 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Objekten. Jedes Objekt hat di
|
||||
open_anchors, release_state, tags,
|
||||
license_rule, source_original_text, source_citation,
|
||||
customer_visible, generation_metadata,
|
||||
verification_method, category
|
||||
verification_method, category, generation_strategy
|
||||
) VALUES (
|
||||
:framework_id, :control_id, :title, :objective, :rationale,
|
||||
:scope, :requirements, :test_procedure, :evidence,
|
||||
@@ -1377,7 +1401,7 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Objekten. Jedes Objekt hat di
|
||||
:open_anchors, :release_state, :tags,
|
||||
:license_rule, :source_original_text, :source_citation,
|
||||
:customer_visible, :generation_metadata,
|
||||
:verification_method, :category
|
||||
:verification_method, :category, :generation_strategy
|
||||
)
|
||||
ON CONFLICT (framework_id, control_id) DO NOTHING
|
||||
RETURNING id
|
||||
@@ -1405,6 +1429,7 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Objekten. Jedes Objekt hat di
|
||||
"generation_metadata": json.dumps(control.generation_metadata) if control.generation_metadata else None,
|
||||
"verification_method": control.verification_method,
|
||||
"category": control.category,
|
||||
"generation_strategy": control.generation_strategy,
|
||||
},
|
||||
)
|
||||
self.db.commit()
|
||||
@@ -1479,21 +1504,48 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Objekten. Jedes Objekt hat di
|
||||
self._update_job(job_id, result)
|
||||
return result
|
||||
|
||||
# ── Group chunks by document (regulation_code) for coherent batching ──
|
||||
doc_groups: dict[str, list[RAGSearchResult]] = defaultdict(list)
|
||||
for chunk in chunks:
|
||||
group_key = chunk.regulation_code or "unknown"
|
||||
doc_groups[group_key].append(chunk)
|
||||
|
||||
# Sort chunks within each group by article for sequential context
|
||||
for key in doc_groups:
|
||||
doc_groups[key].sort(key=lambda c: (c.article or "", c.paragraph or ""))
|
||||
|
||||
logger.info(
|
||||
"Grouped %d chunks into %d document groups for coherent batching",
|
||||
len(chunks), len(doc_groups),
|
||||
)
|
||||
|
||||
# Flatten back: chunks from same document are now adjacent
|
||||
chunks = []
|
||||
for group_list in doc_groups.values():
|
||||
chunks.extend(group_list)
|
||||
|
||||
# Process chunks — batch mode (N chunks per Anthropic API call)
|
||||
BATCH_SIZE = config.batch_size or 5
|
||||
controls_count = 0
|
||||
chunks_skipped_prefilter = 0
|
||||
pending_batch: list[tuple[RAGSearchResult, dict]] = [] # (chunk, license_info)
|
||||
current_batch_regulation: Optional[str] = None # Track regulation for group-aware flushing
|
||||
|
||||
async def _flush_batch():
|
||||
"""Send pending batch to Anthropic and process results."""
|
||||
nonlocal controls_count
|
||||
nonlocal controls_count, current_batch_regulation
|
||||
if not pending_batch:
|
||||
return
|
||||
batch = pending_batch.copy()
|
||||
pending_batch.clear()
|
||||
current_batch_regulation = None
|
||||
|
||||
logger.info("Processing batch of %d chunks via single API call...", len(batch))
|
||||
# Log which document this batch belongs to
|
||||
regs_in_batch = set(c.regulation_code for c, _ in batch)
|
||||
logger.info(
|
||||
"Processing batch of %d chunks (docs: %s) via single API call...",
|
||||
len(batch), ", ".join(regs_in_batch),
|
||||
)
|
||||
try:
|
||||
batch_controls = await self._process_batch(batch, config, job_id)
|
||||
except Exception as e:
|
||||
@@ -1514,6 +1566,9 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Objekten. Jedes Objekt hat di
|
||||
self._mark_chunk_processed(chunk, lic_info, "no_control", [], job_id)
|
||||
continue
|
||||
|
||||
# Mark as document_grouped strategy
|
||||
control.generation_strategy = "document_grouped"
|
||||
|
||||
# Count by state
|
||||
if control.release_state == "too_close":
|
||||
result.controls_too_close += 1
|
||||
@@ -1567,12 +1622,18 @@ Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Objekten. Jedes Objekt hat di
|
||||
|
||||
# Classify license and add to batch
|
||||
license_info = self._classify_license(chunk)
|
||||
pending_batch.append((chunk, license_info))
|
||||
chunk_regulation = chunk.regulation_code or "unknown"
|
||||
|
||||
# Flush when batch is full
|
||||
if len(pending_batch) >= BATCH_SIZE:
|
||||
# Flush when: batch is full OR regulation changes (group boundary)
|
||||
if pending_batch and (
|
||||
len(pending_batch) >= BATCH_SIZE
|
||||
or chunk_regulation != current_batch_regulation
|
||||
):
|
||||
await _flush_batch()
|
||||
|
||||
pending_batch.append((chunk, license_info))
|
||||
current_batch_regulation = chunk_regulation
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error processing chunk {chunk.regulation_code}/{chunk.article}: {e}"
|
||||
logger.error(error_msg)
|
||||
|
||||
23
backend-compliance/migrations/057_processing_path_batch.sql
Normal file
23
backend-compliance/migrations/057_processing_path_batch.sql
Normal file
@@ -0,0 +1,23 @@
|
||||
-- 057: Add batch processing paths to canonical_processed_chunks
|
||||
-- New values: structured_batch, llm_reform_batch (used by batch control generation)
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'canonical_processed_chunks') THEN
|
||||
ALTER TABLE canonical_processed_chunks
|
||||
DROP CONSTRAINT IF EXISTS canonical_processed_chunks_processing_path_check;
|
||||
ALTER TABLE canonical_processed_chunks
|
||||
ADD CONSTRAINT canonical_processed_chunks_processing_path_check
|
||||
CHECK (processing_path IN (
|
||||
'structured',
|
||||
'llm_reform',
|
||||
'skipped',
|
||||
'prefilter_skip',
|
||||
'no_control',
|
||||
'store_failed',
|
||||
'error',
|
||||
'structured_batch',
|
||||
'llm_reform_batch'
|
||||
));
|
||||
END IF;
|
||||
END $$;
|
||||
@@ -0,0 +1,8 @@
|
||||
-- Migration 058: Add generation_strategy column to canonical_controls
|
||||
-- Tracks whether a control was generated with document-grouped or ungrouped batching
|
||||
|
||||
ALTER TABLE canonical_controls
|
||||
ADD COLUMN IF NOT EXISTS generation_strategy TEXT NOT NULL DEFAULT 'ungrouped';
|
||||
|
||||
COMMENT ON COLUMN canonical_controls.generation_strategy IS
|
||||
'How chunks were batched during generation: ungrouped (random), document_grouped (by regulation+article)';
|
||||
@@ -1,17 +1,36 @@
|
||||
"""Tests for Canonical Control Library routes (canonical_control_routes.py)."""
|
||||
"""Tests for Canonical Control Library routes (canonical_control_routes.py).
|
||||
|
||||
Includes:
|
||||
- Model validation tests (FrameworkResponse, ControlResponse, etc.)
|
||||
- _control_row conversion tests
|
||||
- Server-side pagination, sorting, search, source filter tests
|
||||
- /controls-count and /controls-meta endpoint tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from compliance.api.canonical_control_routes import (
|
||||
FrameworkResponse,
|
||||
ControlResponse,
|
||||
SimilarityCheckRequest,
|
||||
SimilarityCheckResponse,
|
||||
_control_row,
|
||||
router,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestClient setup for endpoint tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_app = FastAPI()
|
||||
_app.include_router(router, prefix="/api/compliance")
|
||||
_client = TestClient(_app)
|
||||
|
||||
|
||||
class TestFrameworkResponse:
|
||||
"""Tests for FrameworkResponse model."""
|
||||
@@ -175,6 +194,7 @@ class TestControlRowConversion:
|
||||
],
|
||||
"release_state": "draft",
|
||||
"tags": ["mfa"],
|
||||
"generation_strategy": "ungrouped",
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
@@ -223,3 +243,213 @@ class TestControlRowConversion:
|
||||
result = _control_row(row)
|
||||
assert result["created_at"] is None
|
||||
assert result["updated_at"] is None
|
||||
|
||||
def test_generation_strategy_default(self):
|
||||
row = self._make_row()
|
||||
result = _control_row(row)
|
||||
assert result["generation_strategy"] == "ungrouped"
|
||||
|
||||
def test_generation_strategy_document_grouped(self):
|
||||
row = self._make_row(generation_strategy="document_grouped")
|
||||
result = _control_row(row)
|
||||
assert result["generation_strategy"] == "document_grouped"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ENDPOINT TESTS — Server-Side Pagination, Sort, Search, Source Filter
|
||||
# =============================================================================
|
||||
|
||||
def _make_mock_row(**overrides):
|
||||
"""Build a mock Row with all canonical_controls columns."""
|
||||
now = datetime.now(timezone.utc)
|
||||
defaults = {
|
||||
"id": "uuid-ctrl-1",
|
||||
"framework_id": "uuid-fw-1",
|
||||
"control_id": "AUTH-001",
|
||||
"title": "Test Control",
|
||||
"objective": "Test obj",
|
||||
"rationale": "Test rat",
|
||||
"scope": {},
|
||||
"requirements": ["Req 1"],
|
||||
"test_procedure": ["Test 1"],
|
||||
"evidence": [],
|
||||
"severity": "high",
|
||||
"risk_score": 3.0,
|
||||
"implementation_effort": "m",
|
||||
"evidence_confidence": None,
|
||||
"open_anchors": [],
|
||||
"release_state": "draft",
|
||||
"tags": [],
|
||||
"license_rule": 1,
|
||||
"source_original_text": None,
|
||||
"source_citation": None,
|
||||
"customer_visible": True,
|
||||
"verification_method": "automated",
|
||||
"category": "authentication",
|
||||
"target_audience": "developer",
|
||||
"generation_metadata": {},
|
||||
"generation_strategy": "ungrouped",
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
mock = MagicMock()
|
||||
for k, v in defaults.items():
|
||||
setattr(mock, k, v)
|
||||
return mock
|
||||
|
||||
|
||||
def _session_returning(rows=None, scalar=None):
|
||||
"""Create a mock SessionLocal that returns rows or scalar."""
|
||||
db = MagicMock()
|
||||
result = MagicMock()
|
||||
if rows is not None:
|
||||
result.fetchall.return_value = rows
|
||||
if scalar is not None:
|
||||
result.scalar.return_value = scalar
|
||||
db.execute.return_value = result
|
||||
db.__enter__ = MagicMock(return_value=db)
|
||||
db.__exit__ = MagicMock(return_value=False)
|
||||
return db
|
||||
|
||||
|
||||
class TestListControlsPagination:
|
||||
"""GET /controls with limit/offset."""
|
||||
|
||||
@patch("compliance.api.canonical_control_routes.SessionLocal")
|
||||
def test_limit_param_in_sql(self, mock_cls):
|
||||
mock_cls.return_value = _session_returning(rows=[_make_mock_row()])
|
||||
resp = _client.get("/api/compliance/v1/canonical/controls?limit=10&offset=20")
|
||||
assert resp.status_code == 200
|
||||
sql = str(mock_cls.return_value.__enter__().execute.call_args[0][0].text)
|
||||
assert "LIMIT" in sql
|
||||
assert "OFFSET" in sql
|
||||
|
||||
@patch("compliance.api.canonical_control_routes.SessionLocal")
|
||||
def test_no_limit_by_default(self, mock_cls):
|
||||
mock_cls.return_value = _session_returning(rows=[])
|
||||
resp = _client.get("/api/compliance/v1/canonical/controls")
|
||||
assert resp.status_code == 200
|
||||
sql = str(mock_cls.return_value.__enter__().execute.call_args[0][0].text)
|
||||
assert "LIMIT" not in sql
|
||||
|
||||
|
||||
class TestListControlsSorting:
|
||||
"""GET /controls with sort/order."""
|
||||
|
||||
@patch("compliance.api.canonical_control_routes.SessionLocal")
|
||||
def test_sort_created_at_desc(self, mock_cls):
|
||||
mock_cls.return_value = _session_returning(rows=[])
|
||||
resp = _client.get("/api/compliance/v1/canonical/controls?sort=created_at&order=desc")
|
||||
assert resp.status_code == 200
|
||||
sql = str(mock_cls.return_value.__enter__().execute.call_args[0][0].text)
|
||||
assert "created_at DESC" in sql
|
||||
|
||||
@patch("compliance.api.canonical_control_routes.SessionLocal")
|
||||
def test_default_sort_control_id_asc(self, mock_cls):
|
||||
mock_cls.return_value = _session_returning(rows=[])
|
||||
resp = _client.get("/api/compliance/v1/canonical/controls")
|
||||
assert resp.status_code == 200
|
||||
sql = str(mock_cls.return_value.__enter__().execute.call_args[0][0].text)
|
||||
assert "control_id ASC" in sql
|
||||
|
||||
@patch("compliance.api.canonical_control_routes.SessionLocal")
|
||||
def test_sql_injection_in_sort_blocked(self, mock_cls):
|
||||
mock_cls.return_value = _session_returning(rows=[])
|
||||
resp = _client.get("/api/compliance/v1/canonical/controls?sort=1;DROP+TABLE")
|
||||
assert resp.status_code == 200
|
||||
sql = str(mock_cls.return_value.__enter__().execute.call_args[0][0].text)
|
||||
assert "DROP" not in sql
|
||||
assert "control_id" in sql
|
||||
|
||||
@patch("compliance.api.canonical_control_routes.SessionLocal")
|
||||
def test_sort_by_source(self, mock_cls):
|
||||
mock_cls.return_value = _session_returning(rows=[])
|
||||
resp = _client.get("/api/compliance/v1/canonical/controls?sort=source&order=asc")
|
||||
assert resp.status_code == 200
|
||||
sql = str(mock_cls.return_value.__enter__().execute.call_args[0][0].text)
|
||||
assert "source_citation" in sql
|
||||
assert "control_id ASC" in sql # secondary sort within source group
|
||||
|
||||
|
||||
class TestListControlsSearch:
|
||||
"""GET /controls with search."""
|
||||
|
||||
@patch("compliance.api.canonical_control_routes.SessionLocal")
|
||||
def test_search_uses_ilike(self, mock_cls):
|
||||
mock_cls.return_value = _session_returning(rows=[])
|
||||
resp = _client.get("/api/compliance/v1/canonical/controls?search=encryption")
|
||||
assert resp.status_code == 200
|
||||
sql = str(mock_cls.return_value.__enter__().execute.call_args[0][0].text)
|
||||
assert "ILIKE" in sql
|
||||
params = mock_cls.return_value.__enter__().execute.call_args[0][1]
|
||||
assert params["q"] == "%encryption%"
|
||||
|
||||
|
||||
class TestListControlsSourceFilter:
|
||||
"""GET /controls with source filter."""
|
||||
|
||||
@patch("compliance.api.canonical_control_routes.SessionLocal")
|
||||
def test_specific_source(self, mock_cls):
|
||||
mock_cls.return_value = _session_returning(rows=[])
|
||||
resp = _client.get("/api/compliance/v1/canonical/controls?source=DSGVO")
|
||||
assert resp.status_code == 200
|
||||
sql = str(mock_cls.return_value.__enter__().execute.call_args[0][0].text)
|
||||
assert "source_citation" in sql
|
||||
params = mock_cls.return_value.__enter__().execute.call_args[0][1]
|
||||
assert params["src"] == "DSGVO"
|
||||
|
||||
@patch("compliance.api.canonical_control_routes.SessionLocal")
|
||||
def test_no_source_filter(self, mock_cls):
|
||||
mock_cls.return_value = _session_returning(rows=[])
|
||||
resp = _client.get("/api/compliance/v1/canonical/controls?source=__none__")
|
||||
assert resp.status_code == 200
|
||||
sql = str(mock_cls.return_value.__enter__().execute.call_args[0][0].text)
|
||||
assert "IS NULL" in sql
|
||||
|
||||
|
||||
class TestControlsCount:
|
||||
"""GET /controls-count."""
|
||||
|
||||
@patch("compliance.api.canonical_control_routes.SessionLocal")
|
||||
def test_returns_total(self, mock_cls):
|
||||
mock_cls.return_value = _session_returning(scalar=42)
|
||||
resp = _client.get("/api/compliance/v1/canonical/controls-count")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"total": 42}
|
||||
|
||||
@patch("compliance.api.canonical_control_routes.SessionLocal")
|
||||
def test_with_filters(self, mock_cls):
|
||||
mock_cls.return_value = _session_returning(scalar=5)
|
||||
resp = _client.get("/api/compliance/v1/canonical/controls-count?severity=critical&search=mfa")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"total": 5}
|
||||
sql = str(mock_cls.return_value.__enter__().execute.call_args[0][0].text)
|
||||
assert "severity" in sql
|
||||
assert "ILIKE" in sql
|
||||
|
||||
|
||||
class TestControlsMeta:
|
||||
"""GET /controls-meta."""
|
||||
|
||||
@patch("compliance.api.canonical_control_routes.SessionLocal")
|
||||
def test_returns_structure(self, mock_cls):
|
||||
db = MagicMock()
|
||||
db.__enter__ = MagicMock(return_value=db)
|
||||
db.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
# 4 sequential execute() calls
|
||||
total_r = MagicMock(); total_r.scalar.return_value = 100
|
||||
domain_r = MagicMock(); domain_r.fetchall.return_value = []
|
||||
source_r = MagicMock(); source_r.fetchall.return_value = []
|
||||
nosrc_r = MagicMock(); nosrc_r.scalar.return_value = 20
|
||||
db.execute.side_effect = [total_r, domain_r, source_r, nosrc_r]
|
||||
mock_cls.return_value = db
|
||||
|
||||
resp = _client.get("/api/compliance/v1/canonical/controls-meta")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["total"] == 100
|
||||
assert data["no_source_count"] == 20
|
||||
assert isinstance(data["domains"], list)
|
||||
assert isinstance(data["sources"], list)
|
||||
|
||||
Reference in New Issue
Block a user