merge: sync with origin/main, take upstream on conflicts

# Conflicts:
#	admin-compliance/lib/sdk/types.ts
#	admin-compliance/lib/sdk/vendor-compliance/types.ts
This commit is contained in:
Sharang Parnerkar
2026-04-16 16:26:48 +02:00
352 changed files with 181673 additions and 2188 deletions
@@ -6,6 +6,8 @@ from .routes import router
logger = logging.getLogger(__name__)
_failed_routers: dict[str, str] = {}
def _safe_import_router(module_name: str, attr: str = "router"):
"""Import a router module safely — log error but don't crash the whole app."""
@@ -14,6 +16,7 @@ def _safe_import_router(module_name: str, attr: str = "router"):
return getattr(mod, attr)
except Exception as e:
logger.error("Failed to import %s: %s", module_name, e)
_failed_routers[module_name] = str(e)
return None
@@ -53,6 +56,13 @@ _ROUTER_MODULES = [
"wiki_routes",
"canonical_control_routes",
"control_generator_routes",
"crosswalk_routes",
"process_task_routes",
"evidence_check_routes",
"vvt_library_routes",
"tom_mapping_routes",
"llm_audit_routes",
"assertion_routes",
]
_loaded_count = 0
@@ -0,0 +1,227 @@
"""
API routes for Assertion Engine (Anti-Fake-Evidence Phase 2).
Endpoints:
- /assertions: CRUD for assertions
- /assertions/extract: Automatic extraction from entity text
- /assertions/summary: Stats (total assertions, facts, unverified)
"""
import logging
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from classroom_engine.database import get_db
from ..db.models import AssertionDB
from ..services.assertion_engine import extract_assertions
from .schemas import (
AssertionCreate,
AssertionUpdate,
AssertionResponse,
AssertionListResponse,
AssertionSummaryResponse,
AssertionExtractRequest,
)
from .audit_trail_utils import log_audit_trail, generate_id
logger = logging.getLogger(__name__)
router = APIRouter(tags=["compliance-assertions"])
def _build_assertion_response(a: AssertionDB) -> AssertionResponse:
return AssertionResponse(
id=a.id,
tenant_id=a.tenant_id,
entity_type=a.entity_type,
entity_id=a.entity_id,
sentence_text=a.sentence_text,
sentence_index=a.sentence_index,
assertion_type=a.assertion_type,
evidence_ids=a.evidence_ids or [],
confidence=a.confidence or 0.0,
normative_tier=a.normative_tier,
verified_by=a.verified_by,
verified_at=a.verified_at,
created_at=a.created_at,
updated_at=a.updated_at,
)
@router.post("/assertions", response_model=AssertionResponse)
async def create_assertion(
data: AssertionCreate,
tenant_id: Optional[str] = Query(None),
db: Session = Depends(get_db),
):
"""Create a single assertion manually."""
a = AssertionDB(
id=generate_id(),
tenant_id=tenant_id,
entity_type=data.entity_type,
entity_id=data.entity_id,
sentence_text=data.sentence_text,
assertion_type=data.assertion_type or "assertion",
evidence_ids=data.evidence_ids or [],
normative_tier=data.normative_tier,
)
db.add(a)
db.commit()
db.refresh(a)
return _build_assertion_response(a)
@router.get("/assertions", response_model=AssertionListResponse)
async def list_assertions(
entity_type: Optional[str] = Query(None),
entity_id: Optional[str] = Query(None),
assertion_type: Optional[str] = Query(None),
tenant_id: Optional[str] = Query(None),
limit: int = Query(100, ge=1, le=500),
db: Session = Depends(get_db),
):
"""List assertions with optional filters."""
query = db.query(AssertionDB)
if entity_type:
query = query.filter(AssertionDB.entity_type == entity_type)
if entity_id:
query = query.filter(AssertionDB.entity_id == entity_id)
if assertion_type:
query = query.filter(AssertionDB.assertion_type == assertion_type)
if tenant_id:
query = query.filter(AssertionDB.tenant_id == tenant_id)
total = query.count()
records = query.order_by(AssertionDB.sentence_index.asc()).limit(limit).all()
return AssertionListResponse(
assertions=[_build_assertion_response(a) for a in records],
total=total,
)
@router.get("/assertions/summary", response_model=AssertionSummaryResponse)
async def assertion_summary(
tenant_id: Optional[str] = Query(None),
entity_type: Optional[str] = Query(None),
entity_id: Optional[str] = Query(None),
db: Session = Depends(get_db),
):
"""Summary stats: total assertions, facts, rationale, unverified."""
query = db.query(AssertionDB)
if tenant_id:
query = query.filter(AssertionDB.tenant_id == tenant_id)
if entity_type:
query = query.filter(AssertionDB.entity_type == entity_type)
if entity_id:
query = query.filter(AssertionDB.entity_id == entity_id)
all_records = query.all()
total = len(all_records)
facts = sum(1 for a in all_records if a.assertion_type == "fact")
rationale = sum(1 for a in all_records if a.assertion_type == "rationale")
unverified = sum(1 for a in all_records if a.assertion_type == "assertion" and not a.verified_by)
return AssertionSummaryResponse(
total_assertions=total,
total_facts=facts,
total_rationale=rationale,
unverified_count=unverified,
)
@router.get("/assertions/{assertion_id}", response_model=AssertionResponse)
async def get_assertion(
assertion_id: str,
db: Session = Depends(get_db),
):
"""Get a single assertion by ID."""
a = db.query(AssertionDB).filter(AssertionDB.id == assertion_id).first()
if not a:
raise HTTPException(status_code=404, detail=f"Assertion {assertion_id} not found")
return _build_assertion_response(a)
@router.put("/assertions/{assertion_id}", response_model=AssertionResponse)
async def update_assertion(
assertion_id: str,
data: AssertionUpdate,
db: Session = Depends(get_db),
):
"""Update an assertion (e.g. link evidence, change type)."""
a = db.query(AssertionDB).filter(AssertionDB.id == assertion_id).first()
if not a:
raise HTTPException(status_code=404, detail=f"Assertion {assertion_id} not found")
update_fields = data.model_dump(exclude_unset=True)
for key, value in update_fields.items():
setattr(a, key, value)
a.updated_at = datetime.utcnow()
db.commit()
db.refresh(a)
return _build_assertion_response(a)
@router.post("/assertions/{assertion_id}/verify", response_model=AssertionResponse)
async def verify_assertion(
assertion_id: str,
verified_by: str = Query(...),
db: Session = Depends(get_db),
):
"""Mark an assertion as verified fact."""
a = db.query(AssertionDB).filter(AssertionDB.id == assertion_id).first()
if not a:
raise HTTPException(status_code=404, detail=f"Assertion {assertion_id} not found")
a.assertion_type = "fact"
a.verified_by = verified_by
a.verified_at = datetime.utcnow()
a.updated_at = datetime.utcnow()
db.commit()
db.refresh(a)
return _build_assertion_response(a)
@router.post("/assertions/extract", response_model=AssertionListResponse)
async def extract_assertions_endpoint(
data: AssertionExtractRequest,
tenant_id: Optional[str] = Query(None),
db: Session = Depends(get_db),
):
"""Extract assertions from free text and persist them."""
extracted = extract_assertions(
text=data.text,
entity_type=data.entity_type,
entity_id=data.entity_id,
tenant_id=tenant_id,
)
created = []
for item in extracted:
a = AssertionDB(
id=generate_id(),
tenant_id=item["tenant_id"],
entity_type=item["entity_type"],
entity_id=item["entity_id"],
sentence_text=item["sentence_text"],
sentence_index=item["sentence_index"],
assertion_type=item["assertion_type"],
evidence_ids=item["evidence_ids"],
normative_tier=item.get("normative_tier"),
confidence=item.get("confidence", 0.0),
)
db.add(a)
created.append(a)
db.commit()
for a in created:
db.refresh(a)
return AssertionListResponse(
assertions=[_build_assertion_response(a) for a in created],
total=len(created),
)
@@ -0,0 +1,53 @@
"""Shared audit trail utilities.
Extracted from isms_routes.py for reuse across evidence, control,
and assertion routes.
"""
import hashlib
import uuid
from datetime import datetime
from sqlalchemy.orm import Session
from ..db.models import AuditTrailDB
def generate_id() -> str:
"""Generate a UUID string."""
return str(uuid.uuid4())
def create_signature(data: str) -> str:
"""Create SHA-256 signature."""
return hashlib.sha256(data.encode()).hexdigest()
def log_audit_trail(
db: Session,
entity_type: str,
entity_id: str,
entity_name: str,
action: str,
performed_by: str,
field_changed: str = None,
old_value: str = None,
new_value: str = None,
change_summary: str = None,
):
"""Log an entry to the audit trail."""
trail = AuditTrailDB(
id=generate_id(),
entity_type=entity_type,
entity_id=entity_id,
entity_name=entity_name,
action=action,
field_changed=field_changed,
old_value=old_value,
new_value=new_value,
change_summary=change_summary,
performed_by=performed_by,
performed_at=datetime.utcnow(),
checksum=create_signature(f"{entity_type}|{entity_id}|{action}|{performed_by}"),
)
db.add(trail)
File diff suppressed because it is too large Load Diff
@@ -12,6 +12,7 @@ Endpoints:
POST /v1/canonical/blocked-sources/cleanup — Start cleanup workflow
"""
import asyncio
import json
import logging
from typing import Optional, List
@@ -25,7 +26,16 @@ from compliance.services.control_generator import (
ControlGeneratorPipeline,
GeneratorConfig,
ALL_COLLECTIONS,
VALID_CATEGORIES,
VALID_DOMAINS,
_classify_regulation,
_detect_category,
_detect_domain,
_llm_local,
_parse_llm_json,
CATEGORY_LIST_STR,
)
from compliance.services.citation_backfill import CitationBackfill, BackfillResult
from compliance.services.rag_client import get_rag_client
logger = logging.getLogger(__name__)
@@ -40,9 +50,12 @@ class GenerateRequest(BaseModel):
domain: Optional[str] = None
collections: Optional[List[str]] = None
max_controls: int = 50
max_chunks: int = 1000 # Default: process max 1000 chunks per job (respects document boundaries)
batch_size: int = 5
skip_web_search: bool = False
dry_run: bool = False
regulation_filter: Optional[List[str]] = None # Only process these regulation_code prefixes
skip_prefilter: bool = False # Skip local LLM pre-filter, send all chunks to API
class GenerateResponse(BaseModel):
@@ -55,6 +68,7 @@ class GenerateResponse(BaseModel):
controls_needs_review: int = 0
controls_too_close: int = 0
controls_duplicates_found: int = 0
controls_qa_fixed: int = 0
errors: list = []
controls: list = []
@@ -89,42 +103,111 @@ class BlockedSourceResponse(BaseModel):
# ENDPOINTS
# =============================================================================
async def _run_pipeline_background(config: GeneratorConfig, job_id: str):
"""Run the pipeline in the background. Uses its own DB session."""
db = SessionLocal()
try:
config.existing_job_id = job_id
pipeline = ControlGeneratorPipeline(db=db, rag_client=get_rag_client())
result = await pipeline.run(config)
logger.info(
"Background generation job %s completed: %d controls from %d chunks",
job_id, result.controls_generated, result.total_chunks_scanned,
)
except Exception as e:
logger.error("Background generation job %s failed: %s", job_id, e)
# Update job as failed
try:
db.execute(
text("""
UPDATE canonical_generation_jobs
SET status = 'failed', errors = :errors, completed_at = NOW()
WHERE id = CAST(:job_id AS uuid)
"""),
{"job_id": job_id, "errors": json.dumps([str(e)])},
)
db.commit()
except Exception:
pass
finally:
db.close()
@router.post("/generate", response_model=GenerateResponse)
async def start_generation(req: GenerateRequest):
"""Start a control generation run."""
"""Start a control generation run (runs in background).
Returns immediately with job_id. Use GET /generate/status/{job_id} to poll progress.
"""
config = GeneratorConfig(
collections=req.collections,
domain=req.domain,
batch_size=req.batch_size,
max_controls=req.max_controls,
max_chunks=req.max_chunks,
skip_web_search=req.skip_web_search,
dry_run=req.dry_run,
regulation_filter=req.regulation_filter,
skip_prefilter=req.skip_prefilter,
)
if req.dry_run:
# Dry run: execute synchronously and return controls
db = SessionLocal()
try:
pipeline = ControlGeneratorPipeline(db=db, rag_client=get_rag_client())
result = await pipeline.run(config)
return GenerateResponse(
job_id=result.job_id,
status=result.status,
message=f"Dry run: {result.controls_generated} controls from {result.total_chunks_scanned} chunks",
total_chunks_scanned=result.total_chunks_scanned,
controls_generated=result.controls_generated,
controls_verified=result.controls_verified,
controls_needs_review=result.controls_needs_review,
controls_too_close=result.controls_too_close,
controls_duplicates_found=result.controls_duplicates_found,
errors=result.errors,
controls=result.controls,
)
except Exception as e:
logger.error("Dry run failed: %s", e)
raise HTTPException(status_code=500, detail=str(e))
finally:
db.close()
# Create job record first so we can return the ID
db = SessionLocal()
try:
pipeline = ControlGeneratorPipeline(db=db, rag_client=get_rag_client())
result = await pipeline.run(config)
return GenerateResponse(
job_id=result.job_id,
status=result.status,
message=f"Generated {result.controls_generated} controls from {result.total_chunks_scanned} chunks",
total_chunks_scanned=result.total_chunks_scanned,
controls_generated=result.controls_generated,
controls_verified=result.controls_verified,
controls_needs_review=result.controls_needs_review,
controls_too_close=result.controls_too_close,
controls_duplicates_found=result.controls_duplicates_found,
errors=result.errors,
controls=result.controls if req.dry_run else [],
result = db.execute(
text("""
INSERT INTO canonical_generation_jobs (status, config)
VALUES ('running', :config)
RETURNING id
"""),
{"config": json.dumps(config.model_dump())},
)
db.commit()
row = result.fetchone()
job_id = str(row[0]) if row else None
except Exception as e:
logger.error("Generation failed: %s", e)
raise HTTPException(status_code=500, detail=str(e))
logger.error("Failed to create job: %s", e)
raise HTTPException(status_code=500, detail=f"Failed to create job: {e}")
finally:
db.close()
if not job_id:
raise HTTPException(status_code=500, detail="Failed to create job record")
# Launch pipeline in background
asyncio.create_task(_run_pipeline_background(config, job_id))
return GenerateResponse(
job_id=job_id,
status="running",
message="Generation started in background. Poll /generate/status/{job_id} for progress.",
)
@router.get("/generate/status/{job_id}")
async def get_job_status(job_id: str):
@@ -132,7 +215,7 @@ async def get_job_status(job_id: str):
db = SessionLocal()
try:
result = db.execute(
text("SELECT * FROM canonical_generation_jobs WHERE id = :id::uuid"),
text("SELECT * FROM canonical_generation_jobs WHERE id = CAST(:id AS uuid)"),
{"id": job_id},
)
row = result.fetchone()
@@ -270,6 +353,188 @@ async def review_control(control_id: str, req: ReviewRequest):
db.close()
class BulkReviewRequest(BaseModel):
release_state: str # Filter: which controls to bulk-review
action: str # "approve" or "reject"
new_state: Optional[str] = None # Override target state
@router.post("/generate/bulk-review")
async def bulk_review(req: BulkReviewRequest):
"""Bulk review all controls matching a release_state filter.
Example: reject all needs_review → sets them to deprecated.
"""
if req.release_state not in ("needs_review", "too_close", "duplicate"):
raise HTTPException(status_code=400, detail=f"Invalid filter state: {req.release_state}")
if req.action == "approve":
target = req.new_state or "draft"
elif req.action == "reject":
target = "deprecated"
else:
raise HTTPException(status_code=400, detail=f"Unknown action: {req.action}")
if target not in ("draft", "review", "approved", "deprecated", "needs_review"):
raise HTTPException(status_code=400, detail=f"Invalid target state: {target}")
db = SessionLocal()
try:
result = db.execute(
text("""
UPDATE canonical_controls
SET release_state = :target, updated_at = NOW()
WHERE release_state = :source
RETURNING control_id
"""),
{"source": req.release_state, "target": target},
)
affected = [row[0] for row in result]
db.commit()
return {
"action": req.action,
"source_state": req.release_state,
"target_state": target,
"affected_count": len(affected),
}
finally:
db.close()
class QAReclassifyRequest(BaseModel):
limit: int = 100 # How many controls to reclassify per run
dry_run: bool = True # Preview only by default
filter_category: Optional[str] = None # Only reclassify controls of this category
filter_domain_prefix: Optional[str] = None # Only reclassify controls with this prefix
@router.post("/generate/qa-reclassify")
async def qa_reclassify(req: QAReclassifyRequest):
"""Run QA reclassification on existing controls using local LLM.
Finds controls where keyword-detection disagrees with current category/domain,
then uses Ollama to determine the correct classification.
"""
db = SessionLocal()
try:
# Load controls to check
where_clauses = ["release_state NOT IN ('deprecated')"]
params = {"limit": req.limit}
if req.filter_category:
where_clauses.append("category = :cat")
params["cat"] = req.filter_category
if req.filter_domain_prefix:
where_clauses.append("control_id LIKE :prefix")
params["prefix"] = f"{req.filter_domain_prefix}-%"
where_sql = " AND ".join(where_clauses)
rows = db.execute(
text(f"""
SELECT id, control_id, title, objective, category,
COALESCE(requirements::text, '[]') as requirements,
COALESCE(source_original_text, '') as source_text
FROM canonical_controls
WHERE {where_sql}
ORDER BY created_at DESC
LIMIT :limit
"""),
params,
).fetchall()
results = {"checked": 0, "mismatches": 0, "fixes": [], "errors": []}
for row in rows:
results["checked"] += 1
control_id = row[1]
title = row[2]
objective = row[3] or ""
current_category = row[4]
source_text = row[6] or objective
# Keyword detection on source text
kw_category = _detect_category(source_text) or _detect_category(objective)
kw_domain = _detect_domain(source_text)
current_prefix = control_id.split("-")[0] if "-" in control_id else ""
# Skip if keyword detection agrees with current classification
if kw_category == current_category and kw_domain == current_prefix:
continue
results["mismatches"] += 1
# Ask Ollama to arbitrate
try:
reqs_text = ""
try:
reqs = json.loads(row[5])
if isinstance(reqs, list):
reqs_text = ", ".join(str(r) for r in reqs[:3])
except Exception:
pass
prompt = f"""Pruefe dieses Compliance-Control auf korrekte Klassifizierung.
Titel: {title[:100]}
Ziel: {objective[:200]}
Anforderungen: {reqs_text[:200]}
Aktuelle Zuordnung: domain={current_prefix}, category={current_category}
Keyword-Erkennung: domain={kw_domain}, category={kw_category}
Welche Zuordnung ist korrekt? Antworte NUR als JSON:
{{"domain": "KUERZEL", "category": "kategorie_name", "reason": "kurze Begruendung"}}
Domains: AUTH=Authentifizierung, CRYP=Kryptographie, NET=Netzwerk, DATA=Datenschutz, LOG=Logging, ACC=Zugriffskontrolle, SEC=IT-Sicherheit, INC=Vorfallmanagement, AI=KI, COMP=Compliance, GOV=Behoerden, LAB=Arbeitsrecht, FIN=Finanzregulierung, TRD=Gewerbe, ENV=Umwelt, HLT=Gesundheit
Kategorien: {CATEGORY_LIST_STR}"""
raw = await _llm_local(prompt)
data = _parse_llm_json(raw)
if not data:
continue
qa_domain = data.get("domain", "").upper()
qa_category = data.get("category", "")
reason = data.get("reason", "")
fix_entry = {
"control_id": control_id,
"title": title[:80],
"old_category": current_category,
"old_domain": current_prefix,
"new_category": qa_category if qa_category in VALID_CATEGORIES else current_category,
"new_domain": qa_domain if qa_domain in VALID_DOMAINS else current_prefix,
"reason": reason,
}
category_changed = qa_category in VALID_CATEGORIES and qa_category != current_category
if category_changed and not req.dry_run:
db.execute(
text("""
UPDATE canonical_controls
SET category = :category, updated_at = NOW()
WHERE id = :id
"""),
{"id": row[0], "category": qa_category},
)
fix_entry["applied"] = True
else:
fix_entry["applied"] = False
results["fixes"].append(fix_entry)
except Exception as e:
results["errors"].append({"control_id": control_id, "error": str(e)})
if not req.dry_run:
db.commit()
return results
finally:
db.close()
@router.get("/generate/processed-stats")
async def get_processed_stats():
"""Get processing statistics per collection."""
@@ -429,3 +694,407 @@ async def get_controls_customer_view(
return {"controls": controls, "total": len(controls)}
finally:
db.close()
# =============================================================================
# CITATION BACKFILL
# =============================================================================
class BackfillRequest(BaseModel):
dry_run: bool = True # Default to dry_run for safety
limit: int = 0 # 0 = all controls
class BackfillResponse(BaseModel):
status: str
total_controls: int = 0
matched_hash: int = 0
matched_regex: int = 0
matched_llm: int = 0
unmatched: int = 0
updated: int = 0
errors: list = []
_backfill_status: dict = {}
async def _run_backfill_background(dry_run: bool, limit: int, backfill_id: str):
"""Run backfill in background with own DB session."""
db = SessionLocal()
try:
backfill = CitationBackfill(db=db, rag_client=get_rag_client())
result = await backfill.run(dry_run=dry_run, limit=limit)
_backfill_status[backfill_id] = {
"status": "completed",
"total_controls": result.total_controls,
"matched_hash": result.matched_hash,
"matched_regex": result.matched_regex,
"matched_llm": result.matched_llm,
"unmatched": result.unmatched,
"updated": result.updated,
"errors": result.errors[:50],
}
logger.info("Backfill %s completed: %d updated", backfill_id, result.updated)
except Exception as e:
logger.error("Backfill %s failed: %s", backfill_id, e)
_backfill_status[backfill_id] = {"status": "failed", "errors": [str(e)]}
finally:
db.close()
@router.post("/generate/backfill-citations", response_model=BackfillResponse)
async def start_backfill(req: BackfillRequest):
"""Backfill article/paragraph into existing control source_citations.
Uses 3-tier matching: hash lookup → regex parse → Ollama LLM.
Default is dry_run=True (preview only, no DB changes).
"""
import uuid
backfill_id = str(uuid.uuid4())[:8]
_backfill_status[backfill_id] = {"status": "running"}
# Always run in background (RAG index build takes minutes)
asyncio.create_task(_run_backfill_background(req.dry_run, req.limit, backfill_id))
return BackfillResponse(
status=f"running (id={backfill_id})",
)
@router.get("/generate/backfill-status/{backfill_id}")
async def get_backfill_status(backfill_id: str):
"""Get status of a backfill job."""
status = _backfill_status.get(backfill_id)
if not status:
raise HTTPException(status_code=404, detail="Backfill job not found")
return status
# =============================================================================
# DOMAIN + TARGET AUDIENCE BACKFILL
# =============================================================================
class DomainBackfillRequest(BaseModel):
dry_run: bool = True
job_id: Optional[str] = None # Only backfill controls from this job
limit: int = 0 # 0 = all
_domain_backfill_status: dict = {}
async def _run_domain_backfill(req: DomainBackfillRequest, backfill_id: str):
"""Backfill domain, category, and target_audience for existing controls using Anthropic."""
import os
import httpx
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
ANTHROPIC_MODEL = os.getenv("CONTROL_GEN_ANTHROPIC_MODEL", "claude-sonnet-4-6")
if not ANTHROPIC_API_KEY:
_domain_backfill_status[backfill_id] = {
"status": "failed", "error": "ANTHROPIC_API_KEY not set"
}
return
db = SessionLocal()
try:
# Find controls needing backfill
where_clauses = ["(target_audience IS NULL OR target_audience = '[]' OR target_audience = 'null')"]
params: dict = {}
if req.job_id:
where_clauses.append("generation_metadata->>'job_id' = :job_id")
params["job_id"] = req.job_id
query = f"""
SELECT id, control_id, title, objective, category, source_original_text, tags
FROM canonical_controls
WHERE {' AND '.join(where_clauses)}
ORDER BY control_id
"""
if req.limit > 0:
query += f" LIMIT {req.limit}"
result = db.execute(text(query), params)
controls = [dict(zip(result.keys(), row)) for row in result]
total = len(controls)
updated = 0
errors = []
_domain_backfill_status[backfill_id] = {
"status": "running", "total": total, "updated": 0, "errors": []
}
# Process in batches of 10
BATCH_SIZE = 10
for batch_start in range(0, total, BATCH_SIZE):
batch = controls[batch_start:batch_start + BATCH_SIZE]
entries = []
for idx, ctrl in enumerate(batch):
text_for_analysis = ctrl.get("objective") or ctrl.get("title") or ""
original = ctrl.get("source_original_text") or ""
if original:
text_for_analysis += f"\n\nQuelltext-Auszug: {original[:500]}"
entries.append(
f"--- CONTROL {idx + 1}: {ctrl['control_id']} ---\n"
f"Titel: {ctrl.get('title', '')}\n"
f"Objective: {text_for_analysis[:800]}\n"
f"Tags: {json.dumps(ctrl.get('tags', []))}"
)
prompt = f"""Analysiere die folgenden {len(batch)} Controls und bestimme fuer jedes:
1. domain: Das Fachgebiet (AUTH, CRYP, NET, DATA, LOG, ACC, SEC, INC, AI, COMP, GOV, LAB, FIN, TRD, ENV, HLT)
2. category: Die Kategorie (encryption, authentication, network, data_protection, logging, incident, continuity, compliance, supply_chain, physical, personnel, application, system, risk, governance, hardware, identity, public_administration, labor_law, finance, trade_regulation, environmental, health)
3. target_audience: Liste der Zielgruppen (moegliche Werte: "unternehmen", "behoerden", "entwickler", "datenschutzbeauftragte", "geschaeftsfuehrung", "it-abteilung", "rechtsabteilung", "compliance-officer", "personalwesen", "einkauf", "produktion", "vertrieb", "gesundheitswesen", "finanzwesen", "oeffentlicher_dienst")
Antworte mit einem JSON-Array mit {len(batch)} Objekten. Jedes Objekt hat:
- control_index: 1-basierter Index
- domain: Fachgebiet-Kuerzel
- category: Kategorie
- target_audience: Liste der Zielgruppen
{"".join(entries)}"""
try:
headers = {
"x-api-key": ANTHROPIC_API_KEY,
"anthropic-version": "2023-06-01",
"content-type": "application/json",
}
payload = {
"model": ANTHROPIC_MODEL,
"max_tokens": 4096,
"system": "Du bist ein Compliance-Experte. Klassifiziere Controls nach Fachgebiet und Zielgruppe. Antworte NUR mit validem JSON.",
"messages": [{"role": "user", "content": prompt}],
}
async with httpx.AsyncClient(timeout=60.0) as client:
resp = await client.post(
"https://api.anthropic.com/v1/messages",
headers=headers,
json=payload,
)
if resp.status_code != 200:
errors.append(f"Anthropic API {resp.status_code} at batch {batch_start}")
continue
raw = resp.json().get("content", [{}])[0].get("text", "")
# Parse response
import re
bracket_match = re.search(r"\[.*\]", raw, re.DOTALL)
if not bracket_match:
errors.append(f"No JSON array in response at batch {batch_start}")
continue
results_list = json.loads(bracket_match.group(0))
for item in results_list:
idx = item.get("control_index", 0) - 1
if idx < 0 or idx >= len(batch):
continue
ctrl = batch[idx]
ctrl_id = str(ctrl["id"])
new_domain = item.get("domain", "")
new_category = item.get("category", "")
new_audience = item.get("target_audience", [])
if not isinstance(new_audience, list):
new_audience = []
# Build new control_id from domain if domain changed
old_prefix = ctrl["control_id"].split("-")[0] if ctrl["control_id"] else ""
new_prefix = new_domain.upper()[:4] if new_domain else old_prefix
if not req.dry_run:
update_parts = []
update_params: dict = {"ctrl_id": ctrl_id}
if new_category:
update_parts.append("category = :category")
update_params["category"] = new_category
if new_audience:
update_parts.append("target_audience = :target_audience")
update_params["target_audience"] = json.dumps(new_audience)
# Note: We do NOT rename control_ids here — that would
# break references and cause unique constraint violations.
if update_parts:
update_parts.append("updated_at = NOW()")
db.execute(
text(f"UPDATE canonical_controls SET {', '.join(update_parts)} WHERE id = CAST(:ctrl_id AS uuid)"),
update_params,
)
updated += 1
if not req.dry_run:
db.commit()
except Exception as e:
errors.append(f"Batch {batch_start}: {str(e)}")
db.rollback()
_domain_backfill_status[backfill_id] = {
"status": "running", "total": total, "updated": updated,
"progress": f"{min(batch_start + BATCH_SIZE, total)}/{total}",
"errors": errors[-10:],
}
_domain_backfill_status[backfill_id] = {
"status": "completed", "total": total, "updated": updated,
"errors": errors[-50:],
}
logger.info("Domain backfill %s completed: %d/%d updated", backfill_id, updated, total)
except Exception as e:
logger.error("Domain backfill %s failed: %s", backfill_id, e)
_domain_backfill_status[backfill_id] = {"status": "failed", "error": str(e)}
finally:
db.close()
@router.post("/generate/backfill-domain")
async def start_domain_backfill(req: DomainBackfillRequest):
"""Backfill domain, category, and target_audience for controls using Anthropic API.
Finds controls where target_audience is NULL and enriches them.
Default is dry_run=True (preview only).
"""
import uuid
backfill_id = str(uuid.uuid4())[:8]
_domain_backfill_status[backfill_id] = {"status": "starting"}
asyncio.create_task(_run_domain_backfill(req, backfill_id))
return {"status": "running", "backfill_id": backfill_id,
"message": f"Domain backfill started. Poll /generate/backfill-status/{backfill_id}"}
@router.get("/generate/domain-backfill-status/{backfill_id}")
async def get_domain_backfill_status(backfill_id: str):
"""Get status of a domain backfill job."""
status = _domain_backfill_status.get(backfill_id)
if not status:
raise HTTPException(status_code=404, detail="Domain backfill job not found")
return status
# ---------------------------------------------------------------------------
# Source-Type Backfill — Classify law vs guideline vs standard vs restricted
# ---------------------------------------------------------------------------
class SourceTypeBackfillRequest(BaseModel):
dry_run: bool = True
_source_type_backfill_status: dict = {}
async def _run_source_type_backfill(dry_run: bool, backfill_id: str):
"""Backfill source_type into source_citation JSONB for all controls."""
db = SessionLocal()
try:
# Find controls with source_citation that lack source_type
rows = db.execute(text("""
SELECT control_id, source_citation, generation_metadata
FROM compliance.canonical_controls
WHERE source_citation IS NOT NULL
AND (source_citation->>'source_type' IS NULL
OR source_citation->>'source_type' = '')
""")).fetchall()
total = len(rows)
updated = 0
already_correct = 0
errors = []
_source_type_backfill_status[backfill_id] = {
"status": "running", "total": total, "updated": 0, "dry_run": dry_run,
}
for row in rows:
cid = row[0]
citation = row[1] if isinstance(row[1], dict) else json.loads(row[1] or "{}")
metadata = row[2] if isinstance(row[2], dict) else json.loads(row[2] or "{}")
# Get regulation_code from metadata
reg_code = metadata.get("source_regulation", "")
if not reg_code:
# Try to infer from source name
errors.append(f"{cid}: no source_regulation in metadata")
continue
# Classify
license_info = _classify_regulation(reg_code)
source_type = license_info.get("source_type", "restricted")
# Update citation
citation["source_type"] = source_type
if not dry_run:
db.execute(text("""
UPDATE compliance.canonical_controls
SET source_citation = :citation
WHERE control_id = :cid
"""), {"citation": json.dumps(citation), "cid": cid})
if updated % 100 == 0:
db.commit()
updated += 1
if not dry_run:
db.commit()
# Count distribution
dist_query = db.execute(text("""
SELECT source_citation->>'source_type' as st, COUNT(*)
FROM compliance.canonical_controls
WHERE source_citation IS NOT NULL
AND source_citation->>'source_type' IS NOT NULL
GROUP BY st
""")).fetchall() if not dry_run else []
distribution = {r[0]: r[1] for r in dist_query}
_source_type_backfill_status[backfill_id] = {
"status": "completed", "total": total, "updated": updated,
"dry_run": dry_run, "distribution": distribution,
"errors": errors[:50],
}
logger.info("Source-type backfill %s completed: %d/%d updated (dry_run=%s)",
backfill_id, updated, total, dry_run)
except Exception as e:
logger.error("Source-type backfill %s failed: %s", backfill_id, e)
_source_type_backfill_status[backfill_id] = {"status": "failed", "error": str(e)}
finally:
db.close()
@router.post("/generate/backfill-source-type")
async def start_source_type_backfill(req: SourceTypeBackfillRequest):
"""Backfill source_type (law/guideline/standard/restricted) into source_citation JSONB.
Classifies each control's source as binding law, authority guideline,
voluntary standard, or restricted norm based on regulation_code.
Default is dry_run=True (preview only).
"""
import uuid
backfill_id = str(uuid.uuid4())[:8]
_source_type_backfill_status[backfill_id] = {"status": "starting"}
asyncio.create_task(_run_source_type_backfill(req.dry_run, backfill_id))
return {
"status": "running",
"backfill_id": backfill_id,
"message": f"Source-type backfill started. Poll /generate/source-type-backfill-status/{backfill_id}",
}
@router.get("/generate/source-type-backfill-status/{backfill_id}")
async def get_source_type_backfill_status(backfill_id: str):
"""Get status of a source-type backfill job."""
status = _source_type_backfill_status.get(backfill_id)
if not status:
raise HTTPException(status_code=404, detail="Source-type backfill job not found")
return status
@@ -0,0 +1,856 @@
"""
FastAPI routes for the Multi-Layer Control Architecture.
Pattern Library, Obligation Extraction, Crosswalk Matrix, and Migration endpoints.
Endpoints:
GET /v1/canonical/patterns — All patterns (with filters)
GET /v1/canonical/patterns/{pattern_id} — Single pattern
GET /v1/canonical/patterns/{pattern_id}/controls — Controls for a pattern
POST /v1/canonical/obligations/extract — Extract obligations from text
GET /v1/canonical/crosswalk — Query crosswalk matrix
GET /v1/canonical/crosswalk/stats — Coverage statistics
POST /v1/canonical/migrate/decompose — Pass 0a: Obligation extraction
POST /v1/canonical/migrate/merge-obligations — Merge implementation-level dupes
POST /v1/canonical/migrate/enrich-obligations — Add trigger_type, impl metadata
POST /v1/canonical/migrate/compose-atomic — Pass 0b: Atomic control composition
POST /v1/canonical/migrate/link-obligations — Pass 1: Obligation linkage
POST /v1/canonical/migrate/classify-patterns — Pass 2: Pattern classification
POST /v1/canonical/migrate/triage — Pass 3: Quality triage
POST /v1/canonical/migrate/backfill-crosswalk — Pass 4: Crosswalk backfill
POST /v1/canonical/migrate/deduplicate — Pass 5: Deduplication
GET /v1/canonical/migrate/status — Migration progress
GET /v1/canonical/migrate/decomposition-status — Decomposition progress
"""
import json
import logging
from typing import Optional, List
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy import text
from database import SessionLocal
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/v1/canonical", tags=["crosswalk"])
# =============================================================================
# REQUEST / RESPONSE MODELS
# =============================================================================
class PatternResponse(BaseModel):
id: str
name: str
name_de: str
domain: str
category: str
description: str
objective_template: str
severity_default: str
implementation_effort_default: str = "m"
tags: list = []
composable_with: list = []
open_anchor_refs: list = []
controls_count: int = 0
class PatternListResponse(BaseModel):
patterns: List[PatternResponse]
total: int
class PatternDetailResponse(PatternResponse):
rationale_template: str = ""
requirements_template: list = []
test_procedure_template: list = []
evidence_template: list = []
obligation_match_keywords: list = []
class ObligationExtractRequest(BaseModel):
text: str
regulation_code: Optional[str] = None
article: Optional[str] = None
paragraph: Optional[str] = None
class ObligationExtractResponse(BaseModel):
obligation_id: Optional[str] = None
obligation_title: Optional[str] = None
obligation_text: Optional[str] = None
method: str = "none"
confidence: float = 0.0
regulation_id: Optional[str] = None
pattern_id: Optional[str] = None
pattern_confidence: float = 0.0
class CrosswalkRow(BaseModel):
regulation_code: str = ""
article: Optional[str] = None
obligation_id: Optional[str] = None
pattern_id: Optional[str] = None
master_control_id: Optional[str] = None
confidence: float = 0.0
source: str = "auto"
class CrosswalkQueryResponse(BaseModel):
rows: List[CrosswalkRow]
total: int
class CrosswalkStatsResponse(BaseModel):
total_rows: int = 0
regulations_covered: int = 0
obligations_linked: int = 0
patterns_used: int = 0
controls_linked: int = 0
coverage_by_regulation: dict = {}
class MigrationRequest(BaseModel):
limit: int = 0 # 0 = no limit
batch_size: int = 0 # 0 = auto (5 for Anthropic, 1 for Ollama)
use_anthropic: bool = False # Use Anthropic API instead of Ollama
category_filter: Optional[str] = None # Comma-separated categories
source_filter: Optional[str] = None # Comma-separated source regulations (ILIKE match)
class BatchSubmitRequest(BaseModel):
limit: int = 0
batch_size: int = 5
category_filter: Optional[str] = None
source_filter: Optional[str] = None
class BatchProcessRequest(BaseModel):
batch_id: str
pass_type: str = "0a" # "0a" or "0b"
class MigrationResponse(BaseModel):
status: str = "completed"
stats: dict = {}
class MigrationStatusResponse(BaseModel):
total_controls: int = 0
has_obligation: int = 0
has_pattern: int = 0
fully_linked: int = 0
deprecated: int = 0
coverage_obligation_pct: float = 0.0
coverage_pattern_pct: float = 0.0
coverage_full_pct: float = 0.0
class DecompositionStatusResponse(BaseModel):
rich_controls: int = 0
decomposed_controls: int = 0
total_candidates: int = 0
validated: int = 0
rejected: int = 0
composed: int = 0
atomic_controls: int = 0
merged: int = 0
enriched: int = 0
ready_for_pass0b: int = 0
decomposition_pct: float = 0.0
composition_pct: float = 0.0
# =============================================================================
# PATTERN LIBRARY ENDPOINTS
# =============================================================================
@router.get("/patterns", response_model=PatternListResponse)
async def list_patterns(
domain: Optional[str] = Query(None, description="Filter by domain (e.g. AUTH, CRYP)"),
category: Optional[str] = Query(None, description="Filter by category"),
tag: Optional[str] = Query(None, description="Filter by tag"),
):
"""List all control patterns with optional filters."""
from compliance.services.pattern_matcher import PatternMatcher
matcher = PatternMatcher()
matcher._load_patterns()
matcher._build_keyword_index()
patterns = matcher._patterns
if domain:
patterns = [p for p in patterns if p.domain == domain.upper()]
if category:
patterns = [p for p in patterns if p.category == category.lower()]
if tag:
patterns = [p for p in patterns if tag.lower() in [t.lower() for t in p.tags]]
# Count controls per pattern from DB
control_counts = _get_pattern_control_counts()
response_patterns = []
for p in patterns:
response_patterns.append(PatternResponse(
id=p.id,
name=p.name,
name_de=p.name_de,
domain=p.domain,
category=p.category,
description=p.description,
objective_template=p.objective_template,
severity_default=p.severity_default,
implementation_effort_default=p.implementation_effort_default,
tags=p.tags,
composable_with=p.composable_with,
open_anchor_refs=p.open_anchor_refs,
controls_count=control_counts.get(p.id, 0),
))
return PatternListResponse(patterns=response_patterns, total=len(response_patterns))
@router.get("/patterns/{pattern_id}", response_model=PatternDetailResponse)
async def get_pattern(pattern_id: str):
"""Get a single control pattern by ID."""
from compliance.services.pattern_matcher import PatternMatcher
matcher = PatternMatcher()
matcher._load_patterns()
pattern = matcher.get_pattern(pattern_id)
if not pattern:
raise HTTPException(status_code=404, detail=f"Pattern {pattern_id} not found")
control_counts = _get_pattern_control_counts()
return PatternDetailResponse(
id=pattern.id,
name=pattern.name,
name_de=pattern.name_de,
domain=pattern.domain,
category=pattern.category,
description=pattern.description,
objective_template=pattern.objective_template,
rationale_template=pattern.rationale_template,
requirements_template=pattern.requirements_template,
test_procedure_template=pattern.test_procedure_template,
evidence_template=pattern.evidence_template,
severity_default=pattern.severity_default,
implementation_effort_default=pattern.implementation_effort_default,
tags=pattern.tags,
composable_with=pattern.composable_with,
open_anchor_refs=pattern.open_anchor_refs,
obligation_match_keywords=pattern.obligation_match_keywords,
controls_count=control_counts.get(pattern.id, 0),
)
@router.get("/patterns/{pattern_id}/controls")
async def get_pattern_controls(
pattern_id: str,
limit: int = Query(50, ge=1, le=500),
offset: int = Query(0, ge=0),
):
"""Get controls generated from a specific pattern."""
db = SessionLocal()
try:
result = db.execute(
text("""
SELECT id, control_id, title, objective, severity,
release_state, category, obligation_ids
FROM canonical_controls
WHERE pattern_id = :pattern_id
AND release_state NOT IN ('deprecated')
ORDER BY control_id
LIMIT :limit OFFSET :offset
"""),
{"pattern_id": pattern_id.upper(), "limit": limit, "offset": offset},
)
rows = result.fetchall()
count_result = db.execute(
text("""
SELECT count(*) FROM canonical_controls
WHERE pattern_id = :pattern_id
AND release_state NOT IN ('deprecated')
"""),
{"pattern_id": pattern_id.upper()},
)
total = count_result.fetchone()[0]
controls = []
for row in rows:
obl_ids = row[7]
if isinstance(obl_ids, str):
try:
obl_ids = json.loads(obl_ids)
except (json.JSONDecodeError, TypeError):
obl_ids = []
controls.append({
"id": str(row[0]),
"control_id": row[1],
"title": row[2],
"objective": row[3],
"severity": row[4],
"release_state": row[5],
"category": row[6],
"obligation_ids": obl_ids or [],
})
return {"controls": controls, "total": total}
finally:
db.close()
# =============================================================================
# OBLIGATION EXTRACTION ENDPOINT
# =============================================================================
@router.post("/obligations/extract", response_model=ObligationExtractResponse)
async def extract_obligation(req: ObligationExtractRequest):
"""Extract obligation from text using 3-tier strategy, then match to pattern."""
from compliance.services.obligation_extractor import ObligationExtractor
from compliance.services.pattern_matcher import PatternMatcher
extractor = ObligationExtractor()
await extractor.initialize()
obligation = await extractor.extract(
chunk_text=req.text,
regulation_code=req.regulation_code or "",
article=req.article,
paragraph=req.paragraph,
)
# Also match to pattern
matcher = PatternMatcher()
matcher._load_patterns()
matcher._build_keyword_index()
pattern_text = obligation.obligation_text or obligation.obligation_title or req.text[:500]
pattern_result = matcher._tier1_keyword(pattern_text, obligation.regulation_id)
return ObligationExtractResponse(
obligation_id=obligation.obligation_id,
obligation_title=obligation.obligation_title,
obligation_text=obligation.obligation_text,
method=obligation.method,
confidence=obligation.confidence,
regulation_id=obligation.regulation_id,
pattern_id=pattern_result.pattern_id if pattern_result else None,
pattern_confidence=pattern_result.confidence if pattern_result else 0,
)
# =============================================================================
# CROSSWALK MATRIX ENDPOINTS
# =============================================================================
@router.get("/crosswalk", response_model=CrosswalkQueryResponse)
async def query_crosswalk(
regulation_code: Optional[str] = Query(None),
article: Optional[str] = Query(None),
obligation_id: Optional[str] = Query(None),
pattern_id: Optional[str] = Query(None),
limit: int = Query(100, ge=1, le=1000),
offset: int = Query(0, ge=0),
):
"""Query the crosswalk matrix with filters."""
db = SessionLocal()
try:
conditions = ["1=1"]
params = {"limit": limit, "offset": offset}
if regulation_code:
conditions.append("regulation_code = :reg")
params["reg"] = regulation_code
if article:
conditions.append("article = :art")
params["art"] = article
if obligation_id:
conditions.append("obligation_id = :obl")
params["obl"] = obligation_id
if pattern_id:
conditions.append("pattern_id = :pat")
params["pat"] = pattern_id
where = " AND ".join(conditions)
result = db.execute(
text(f"""
SELECT regulation_code, article, obligation_id,
pattern_id, master_control_id, confidence, source
FROM crosswalk_matrix
WHERE {where}
ORDER BY regulation_code, article
LIMIT :limit OFFSET :offset
"""),
params,
)
rows = result.fetchall()
count_result = db.execute(
text(f"SELECT count(*) FROM crosswalk_matrix WHERE {where}"),
params,
)
total = count_result.fetchone()[0]
crosswalk_rows = [
CrosswalkRow(
regulation_code=r[0] or "",
article=r[1],
obligation_id=r[2],
pattern_id=r[3],
master_control_id=r[4],
confidence=float(r[5] or 0),
source=r[6] or "auto",
)
for r in rows
]
return CrosswalkQueryResponse(rows=crosswalk_rows, total=total)
finally:
db.close()
@router.get("/crosswalk/stats", response_model=CrosswalkStatsResponse)
async def crosswalk_stats():
"""Get crosswalk coverage statistics."""
db = SessionLocal()
try:
row = db.execute(text("""
SELECT
count(*) AS total,
count(DISTINCT regulation_code) FILTER (WHERE regulation_code != '') AS regs,
count(DISTINCT obligation_id) FILTER (WHERE obligation_id IS NOT NULL) AS obls,
count(DISTINCT pattern_id) FILTER (WHERE pattern_id IS NOT NULL) AS pats,
count(DISTINCT master_control_id) FILTER (WHERE master_control_id IS NOT NULL) AS ctrls
FROM crosswalk_matrix
""")).fetchone()
# Coverage by regulation
reg_rows = db.execute(text("""
SELECT regulation_code, count(*) AS cnt
FROM crosswalk_matrix
WHERE regulation_code != ''
GROUP BY regulation_code
ORDER BY cnt DESC
""")).fetchall()
coverage = {r[0]: r[1] for r in reg_rows}
return CrosswalkStatsResponse(
total_rows=row[0],
regulations_covered=row[1],
obligations_linked=row[2],
patterns_used=row[3],
controls_linked=row[4],
coverage_by_regulation=coverage,
)
finally:
db.close()
# =============================================================================
# MIGRATION ENDPOINTS
# =============================================================================
@router.post("/migrate/decompose", response_model=MigrationResponse)
async def migrate_decompose(req: MigrationRequest):
"""Pass 0a: Extract obligation candidates from rich controls.
With use_anthropic=true, uses Anthropic API with prompt caching
and content batching (multiple controls per API call).
"""
from compliance.services.decomposition_pass import DecompositionPass
db = SessionLocal()
try:
decomp = DecompositionPass(db=db)
stats = await decomp.run_pass0a(
limit=req.limit,
batch_size=req.batch_size,
use_anthropic=req.use_anthropic,
category_filter=req.category_filter,
source_filter=req.source_filter,
)
return MigrationResponse(status="completed", stats=stats)
except Exception as e:
logger.error("Decomposition pass 0a failed: %s", e)
raise HTTPException(status_code=500, detail=str(e))
finally:
db.close()
@router.post("/migrate/merge-obligations", response_model=MigrationResponse)
async def migrate_merge_obligations():
"""Merge implementation-level duplicate obligations within each parent.
Run AFTER Pass 0a, BEFORE Pass 0b. No LLM calls — rule-based.
Merges obligations that share similar action+object into the more
abstract survivor, marking the concrete duplicate as 'merged'.
"""
from compliance.services.decomposition_pass import DecompositionPass
db = SessionLocal()
try:
decomp = DecompositionPass(db=db)
stats = decomp.run_merge_pass()
return MigrationResponse(status="completed", stats=stats)
except Exception as e:
logger.error("Merge pass failed: %s", e)
raise HTTPException(status_code=500, detail=str(e))
finally:
db.close()
@router.post("/migrate/enrich-obligations", response_model=MigrationResponse)
async def migrate_enrich_obligations():
"""Add trigger_type and is_implementation_specific metadata.
Run AFTER merge pass, BEFORE Pass 0b. No LLM calls — rule-based.
Classifies trigger_type (event/periodic/continuous) from obligation text
and detects implementation-specific obligations (concrete tools/protocols).
"""
from compliance.services.decomposition_pass import DecompositionPass
db = SessionLocal()
try:
decomp = DecompositionPass(db=db)
stats = decomp.enrich_obligations()
return MigrationResponse(status="completed", stats=stats)
except Exception as e:
logger.error("Enrich pass failed: %s", e)
raise HTTPException(status_code=500, detail=str(e))
finally:
db.close()
@router.post("/migrate/compose-atomic", response_model=MigrationResponse)
async def migrate_compose_atomic(req: MigrationRequest):
"""Pass 0b: Compose atomic controls from obligation candidates.
With use_anthropic=true, uses Anthropic API with prompt caching
and content batching (multiple obligations per API call).
"""
from compliance.services.decomposition_pass import DecompositionPass
db = SessionLocal()
try:
decomp = DecompositionPass(db=db)
stats = await decomp.run_pass0b(
limit=req.limit,
batch_size=req.batch_size,
use_anthropic=req.use_anthropic,
)
return MigrationResponse(status="completed", stats=stats)
except Exception as e:
logger.error("Decomposition pass 0b failed: %s", e)
raise HTTPException(status_code=500, detail=str(e))
finally:
db.close()
@router.post("/migrate/batch-submit-0a", response_model=MigrationResponse)
async def batch_submit_pass0a(req: BatchSubmitRequest):
"""Submit Pass 0a as Anthropic Batch API job (50% cost reduction).
Returns a batch_id for polling. Results are processed asynchronously
within 24 hours by Anthropic.
"""
from compliance.services.decomposition_pass import DecompositionPass
db = SessionLocal()
try:
decomp = DecompositionPass(db=db)
result = await decomp.submit_batch_pass0a(
limit=req.limit,
batch_size=req.batch_size,
category_filter=req.category_filter,
source_filter=req.source_filter,
)
return MigrationResponse(status=result.pop("status", "submitted"), stats=result)
except Exception as e:
logger.error("Batch submit 0a failed: %s", e)
raise HTTPException(status_code=500, detail=str(e))
finally:
db.close()
@router.post("/migrate/batch-submit-0b", response_model=MigrationResponse)
async def batch_submit_pass0b(req: BatchSubmitRequest):
"""Submit Pass 0b as Anthropic Batch API job (50% cost reduction)."""
from compliance.services.decomposition_pass import DecompositionPass
db = SessionLocal()
try:
decomp = DecompositionPass(db=db)
result = await decomp.submit_batch_pass0b(
limit=req.limit,
batch_size=req.batch_size,
)
return MigrationResponse(status=result.pop("status", "submitted"), stats=result)
except Exception as e:
logger.error("Batch submit 0b failed: %s", e)
raise HTTPException(status_code=500, detail=str(e))
finally:
db.close()
@router.get("/migrate/batch-status/{batch_id}")
async def batch_check_status(batch_id: str):
"""Check processing status of an Anthropic batch job."""
from compliance.services.decomposition_pass import check_batch_status
try:
status = await check_batch_status(batch_id)
return status
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/migrate/batch-process", response_model=MigrationResponse)
async def batch_process_results(req: BatchProcessRequest):
"""Fetch and process results from a completed Anthropic batch.
Call this after batch-status shows processing_status='ended'.
"""
from compliance.services.decomposition_pass import DecompositionPass
db = SessionLocal()
try:
decomp = DecompositionPass(db=db)
stats = await decomp.process_batch_results(
batch_id=req.batch_id,
pass_type=req.pass_type,
)
return MigrationResponse(status=stats.pop("status", "completed"), stats=stats)
except Exception as e:
logger.error("Batch process failed: %s", e)
raise HTTPException(status_code=500, detail=str(e))
finally:
db.close()
@router.post("/migrate/link-obligations", response_model=MigrationResponse)
async def migrate_link_obligations(req: MigrationRequest):
"""Pass 1: Link controls to obligations via source_citation article."""
from compliance.services.pipeline_adapter import MigrationPasses
db = SessionLocal()
try:
migration = MigrationPasses(db=db)
await migration.initialize()
stats = await migration.run_pass1_obligation_linkage(limit=req.limit)
return MigrationResponse(status="completed", stats=stats)
except Exception as e:
logger.error("Migration pass 1 failed: %s", e)
raise HTTPException(status_code=500, detail=str(e))
finally:
db.close()
@router.post("/migrate/classify-patterns", response_model=MigrationResponse)
async def migrate_classify_patterns(req: MigrationRequest):
"""Pass 2: Classify controls into patterns via keyword matching."""
from compliance.services.pipeline_adapter import MigrationPasses
db = SessionLocal()
try:
migration = MigrationPasses(db=db)
await migration.initialize()
stats = await migration.run_pass2_pattern_classification(limit=req.limit)
return MigrationResponse(status="completed", stats=stats)
except Exception as e:
logger.error("Migration pass 2 failed: %s", e)
raise HTTPException(status_code=500, detail=str(e))
finally:
db.close()
@router.post("/migrate/triage", response_model=MigrationResponse)
async def migrate_triage():
"""Pass 3: Quality triage — categorize by linkage completeness."""
from compliance.services.pipeline_adapter import MigrationPasses
db = SessionLocal()
try:
migration = MigrationPasses(db=db)
stats = migration.run_pass3_quality_triage()
return MigrationResponse(status="completed", stats=stats)
except Exception as e:
logger.error("Migration pass 3 failed: %s", e)
raise HTTPException(status_code=500, detail=str(e))
finally:
db.close()
@router.post("/migrate/backfill-crosswalk", response_model=MigrationResponse)
async def migrate_backfill_crosswalk():
"""Pass 4: Create crosswalk rows for linked controls."""
from compliance.services.pipeline_adapter import MigrationPasses
db = SessionLocal()
try:
migration = MigrationPasses(db=db)
stats = migration.run_pass4_crosswalk_backfill()
return MigrationResponse(status="completed", stats=stats)
except Exception as e:
logger.error("Migration pass 4 failed: %s", e)
raise HTTPException(status_code=500, detail=str(e))
finally:
db.close()
@router.post("/migrate/deduplicate", response_model=MigrationResponse)
async def migrate_deduplicate():
"""Pass 5: Mark duplicate controls (same obligation + pattern)."""
from compliance.services.pipeline_adapter import MigrationPasses
db = SessionLocal()
try:
migration = MigrationPasses(db=db)
stats = migration.run_pass5_deduplication()
return MigrationResponse(status="completed", stats=stats)
except Exception as e:
logger.error("Migration pass 5 failed: %s", e)
raise HTTPException(status_code=500, detail=str(e))
finally:
db.close()
@router.get("/migrate/status", response_model=MigrationStatusResponse)
async def migration_status():
"""Get overall migration progress."""
from compliance.services.pipeline_adapter import MigrationPasses
db = SessionLocal()
try:
migration = MigrationPasses(db=db)
status = migration.migration_status()
return MigrationStatusResponse(**status)
except Exception as e:
logger.error("Migration status failed: %s", e)
raise HTTPException(status_code=500, detail=str(e))
finally:
db.close()
@router.get("/migrate/decomposition-status", response_model=DecompositionStatusResponse)
async def decomposition_status():
"""Get decomposition progress (Pass 0a/0b)."""
from compliance.services.decomposition_pass import DecompositionPass
db = SessionLocal()
try:
decomp = DecompositionPass(db=db)
status = decomp.decomposition_status()
return DecompositionStatusResponse(**status)
except Exception as e:
logger.error("Decomposition status failed: %s", e)
raise HTTPException(status_code=500, detail=str(e))
finally:
db.close()
# =============================================================================
# BATCH DEDUP ENDPOINTS
# =============================================================================
# Module-level runner reference for status polling
_batch_dedup_runner = None
@router.post("/migrate/batch-dedup", response_model=MigrationResponse)
async def migrate_batch_dedup(
dry_run: bool = Query(False, description="Preview mode — no DB changes"),
hint_filter: Optional[str] = Query(None, description="Only process hints matching this prefix"),
):
"""Batch dedup: reduce ~85k Pass 0b controls to ~18-25k masters.
Phase 1: Groups by merge_group_hint, picks best quality master, links rest.
Phase 2: Cross-group embedding search for semantically similar masters.
"""
global _batch_dedup_runner
from compliance.services.batch_dedup_runner import BatchDedupRunner
db = SessionLocal()
try:
runner = BatchDedupRunner(db=db)
_batch_dedup_runner = runner
stats = await runner.run(dry_run=dry_run, hint_filter=hint_filter)
return MigrationResponse(status="completed", stats=stats)
except Exception as e:
logger.error("Batch dedup failed: %s", e)
raise HTTPException(status_code=500, detail=str(e))
finally:
_batch_dedup_runner = None
db.close()
@router.get("/migrate/batch-dedup/status")
async def batch_dedup_status():
"""Get current batch dedup progress (while running)."""
if _batch_dedup_runner is not None:
return {"running": True, **_batch_dedup_runner.get_status()}
# Not running — show DB stats
db = SessionLocal()
try:
row = db.execute(text("""
SELECT
count(*) FILTER (WHERE decomposition_method = 'pass0b') AS total_pass0b,
count(*) FILTER (WHERE decomposition_method = 'pass0b'
AND release_state = 'duplicate') AS duplicates,
count(*) FILTER (WHERE decomposition_method = 'pass0b'
AND release_state != 'duplicate'
AND release_state != 'deprecated') AS masters
FROM canonical_controls
""")).fetchone()
review_count = db.execute(text(
"SELECT count(*) FROM control_dedup_reviews WHERE review_status = 'pending'"
)).fetchone()[0]
return {
"running": False,
"total_pass0b": row[0],
"duplicates": row[1],
"masters": row[2],
"pending_reviews": review_count,
}
finally:
db.close()
# =============================================================================
# HELPERS
# =============================================================================
def _get_pattern_control_counts() -> dict[str, int]:
"""Get count of controls per pattern_id from DB."""
db = SessionLocal()
try:
result = db.execute(text("""
SELECT pattern_id, count(*) AS cnt
FROM canonical_controls
WHERE pattern_id IS NOT NULL AND pattern_id != ''
AND release_state NOT IN ('deprecated')
GROUP BY pattern_id
"""))
return {row[0]: row[1] for row in result.fetchall()}
except Exception:
return {}
finally:
db.close()
@@ -5,16 +5,23 @@ Endpoints:
- /dashboard: Main compliance dashboard
- /dashboard/executive: Executive summary for managers
- /dashboard/trend: Compliance score trend over time
- /dashboard/roadmap: Prioritised controls in 4 buckets
- /dashboard/module-status: Completion status of each SDK module
- /dashboard/next-actions: Top 5 most important actions
- /dashboard/snapshot: Save / query compliance score snapshots
- /score: Quick compliance score
- /reports: Report generation
"""
import logging
from datetime import datetime, timedelta, timezone
from datetime import datetime, date, timedelta
from calendar import month_abbr
from typing import Optional
from typing import Optional, Dict, Any, List
from decimal import Decimal
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.orm import Session
from classroom_engine.database import get_db
@@ -25,15 +32,24 @@ from ..db import (
ControlRepository,
EvidenceRepository,
RiskRepository,
AssertionDB,
)
from .schemas import (
DashboardResponse,
MultiDimensionalScore,
ExecutiveDashboardResponse,
TrendDataPoint,
RiskSummary,
DeadlineItem,
TeamWorkloadItem,
TraceabilityAssertion,
TraceabilityEvidence,
TraceabilityCoverage,
TraceabilityControl,
TraceabilityMatrixResponse,
)
from .tenant_utils import get_tenant_id as _get_tenant_id
from .db_utils import row_to_dict as _row_to_dict
logger = logging.getLogger(__name__)
router = APIRouter(tags=["compliance-dashboard"])
@@ -86,6 +102,14 @@ async def get_dashboard(db: Session = Depends(get_db)):
# or compute from by_status dict
score = ctrl_stats.get("compliance_score", 0.0)
# Multi-dimensional score (Anti-Fake-Evidence)
try:
ms = ctrl_repo.get_multi_dimensional_score()
multi_score = MultiDimensionalScore(**ms)
except Exception as e:
logger.warning(f"Failed to compute multi-dimensional score: {e}")
multi_score = None
return DashboardResponse(
compliance_score=round(score, 1),
total_regulations=len(regulations),
@@ -98,6 +122,7 @@ async def get_dashboard(db: Session = Depends(get_db)):
total_risks=len(risks),
risks_by_level=risks_by_level,
recent_activity=[],
multi_score=multi_score,
)
@@ -116,11 +141,18 @@ async def get_compliance_score(db: Session = Depends(get_db)):
else:
score = 0
# Multi-dimensional score (Anti-Fake-Evidence)
try:
multi_score = ctrl_repo.get_multi_dimensional_score()
except Exception:
multi_score = None
return {
"score": round(score, 1),
"total_controls": total,
"passing_controls": passing,
"partial_controls": partial,
"multi_score": multi_score,
}
@@ -322,6 +354,424 @@ async def get_compliance_trend(
}
# ============================================================================
# Dashboard Extended — Roadmap, Module-Status, Next-Actions, Snapshots
# ============================================================================
# Weight map for control prioritisation
_PRIORITY_WEIGHTS = {"legal": 5, "security": 3, "best_practice": 1, "operational": 2}
# SDK module definitions → DB table used for counting completion
_MODULE_DEFS: List[Dict[str, str]] = [
{"key": "vvt", "label": "VVT", "table": "compliance_vvt_activities"},
{"key": "tom", "label": "TOM", "table": "compliance_toms"},
{"key": "dsfa", "label": "DSFA", "table": "compliance_dsfa_assessments"},
{"key": "loeschfristen", "label": "Loeschfristen", "table": "compliance_loeschfristen"},
{"key": "risks", "label": "Risiken", "table": "compliance_risks"},
{"key": "controls", "label": "Controls", "table": "compliance_controls"},
{"key": "evidence", "label": "Nachweise", "table": "compliance_evidence"},
{"key": "obligations", "label": "Pflichten", "table": "compliance_obligations"},
{"key": "incidents", "label": "Vorfaelle", "table": "compliance_notfallplan_incidents"},
{"key": "vendor", "label": "Auftragsverarbeiter", "table": "compliance_vendor_assessments"},
{"key": "legal_templates", "label": "Rechtl. Dokumente", "table": "compliance_legal_templates"},
{"key": "training", "label": "Schulungen", "table": "training_modules"},
{"key": "audit", "label": "Audit", "table": "compliance_audit_sessions"},
{"key": "security_backlog", "label": "Security-Backlog", "table": "compliance_security_backlog"},
{"key": "quality", "label": "Qualitaet", "table": "compliance_quality_items"},
]
@router.get("/dashboard/roadmap")
async def get_dashboard_roadmap(
db: Session = Depends(get_db),
tenant_id: str = Depends(_get_tenant_id),
):
"""Prioritised controls in 4 buckets: Quick Wins, Must Have, Should Have, Nice to Have."""
ctrl_repo = ControlRepository(db)
controls = ctrl_repo.get_all()
today = datetime.utcnow().date()
buckets: Dict[str, list] = {
"quick_wins": [],
"must_have": [],
"should_have": [],
"nice_to_have": [],
}
for ctrl in controls:
status = ctrl.status.value if ctrl.status else "planned"
if status == "pass":
continue # already done
weight = _PRIORITY_WEIGHTS.get(ctrl.category if hasattr(ctrl, "category") else "best_practice", 1)
days_overdue = 0
if ctrl.next_review_at:
review_date = ctrl.next_review_at.date() if hasattr(ctrl.next_review_at, "date") else ctrl.next_review_at
days_overdue = (today - review_date).days
urgency = weight * 2 + (1 if days_overdue > 0 else 0)
item = {
"id": str(ctrl.id),
"control_id": ctrl.control_id,
"title": ctrl.title,
"status": status,
"domain": ctrl.domain.value if ctrl.domain else "unknown",
"owner": ctrl.owner,
"next_review_at": ctrl.next_review_at.isoformat() if ctrl.next_review_at else None,
"days_overdue": max(0, days_overdue),
"weight": weight,
}
if weight >= 5 and days_overdue > 0:
buckets["quick_wins"].append(item)
elif weight >= 4:
buckets["must_have"].append(item)
elif weight >= 2:
buckets["should_have"].append(item)
else:
buckets["nice_to_have"].append(item)
# Sort each bucket by urgency desc
for key in buckets:
buckets[key].sort(key=lambda x: x["days_overdue"], reverse=True)
return {
"buckets": buckets,
"counts": {k: len(v) for k, v in buckets.items()},
"generated_at": datetime.utcnow().isoformat(),
}
@router.get("/dashboard/module-status")
async def get_module_status(
db: Session = Depends(get_db),
tenant_id: str = Depends(_get_tenant_id),
):
"""Completion status for each SDK module based on DB record counts."""
modules = []
for mod in _MODULE_DEFS:
try:
row = db.execute(
text(f"SELECT COUNT(*) FROM {mod['table']} WHERE tenant_id = :tid"),
{"tid": tenant_id},
).fetchone()
count = int(row[0]) if row else 0
except Exception:
count = 0
# Simple heuristic: 0 = not started, 1-2 = in progress, 3+ = complete
if count == 0:
status = "not_started"
progress = 0
elif count < 3:
status = "in_progress"
progress = min(60, count * 30)
else:
status = "complete"
progress = 100
modules.append({
"key": mod["key"],
"label": mod["label"],
"count": count,
"status": status,
"progress": progress,
})
started = sum(1 for m in modules if m["status"] != "not_started")
complete = sum(1 for m in modules if m["status"] == "complete")
return {
"modules": modules,
"total": len(modules),
"started": started,
"complete": complete,
"overall_progress": round((complete / len(modules)) * 100, 1) if modules else 0,
}
@router.get("/dashboard/next-actions")
async def get_next_actions(
limit: int = Query(5, ge=1, le=20),
db: Session = Depends(get_db),
tenant_id: str = Depends(_get_tenant_id),
):
"""Top N most important actions sorted by urgency*impact."""
ctrl_repo = ControlRepository(db)
controls = ctrl_repo.get_all()
today = datetime.utcnow().date()
actions = []
for ctrl in controls:
status = ctrl.status.value if ctrl.status else "planned"
if status == "pass":
continue
days_overdue = 0
if ctrl.next_review_at:
review_date = ctrl.next_review_at.date() if hasattr(ctrl.next_review_at, "date") else ctrl.next_review_at
days_overdue = max(0, (today - review_date).days)
weight = _PRIORITY_WEIGHTS.get(ctrl.category if hasattr(ctrl, "category") else "best_practice", 1)
urgency_score = weight * 10 + days_overdue
actions.append({
"id": str(ctrl.id),
"control_id": ctrl.control_id,
"title": ctrl.title,
"status": status,
"domain": ctrl.domain.value if ctrl.domain else "unknown",
"owner": ctrl.owner,
"days_overdue": days_overdue,
"urgency_score": urgency_score,
"reason": "Ueberfaellig" if days_overdue > 0 else "Offen",
})
actions.sort(key=lambda x: x["urgency_score"], reverse=True)
return {"actions": actions[:limit]}
@router.post("/dashboard/snapshot")
async def create_score_snapshot(
db: Session = Depends(get_db),
tenant_id: str = Depends(_get_tenant_id),
):
"""Save current compliance score as a historical snapshot."""
ctrl_repo = ControlRepository(db)
evidence_repo = EvidenceRepository(db)
risk_repo = RiskRepository(db)
ctrl_stats = ctrl_repo.get_statistics()
evidence_stats = evidence_repo.get_statistics()
risks = risk_repo.get_all()
total = ctrl_stats.get("total", 0)
passing = ctrl_stats.get("pass", 0)
partial = ctrl_stats.get("partial", 0)
score = round(((passing + partial * 0.5) / total) * 100, 2) if total > 0 else 0
risks_high = sum(1 for r in risks if (r.inherent_risk.value if r.inherent_risk else "low") in ("high", "critical"))
today = date.today()
row = db.execute(text("""
INSERT INTO compliance_score_snapshots (
tenant_id, score, controls_total, controls_pass, controls_partial,
evidence_total, evidence_valid, risks_total, risks_high, snapshot_date
) VALUES (
:tenant_id, :score, :controls_total, :controls_pass, :controls_partial,
:evidence_total, :evidence_valid, :risks_total, :risks_high, :snapshot_date
)
ON CONFLICT (tenant_id, project_id, snapshot_date) DO UPDATE SET
score = EXCLUDED.score,
controls_total = EXCLUDED.controls_total,
controls_pass = EXCLUDED.controls_pass,
controls_partial = EXCLUDED.controls_partial,
evidence_total = EXCLUDED.evidence_total,
evidence_valid = EXCLUDED.evidence_valid,
risks_total = EXCLUDED.risks_total,
risks_high = EXCLUDED.risks_high
RETURNING *
"""), {
"tenant_id": tenant_id,
"score": score,
"controls_total": total,
"controls_pass": passing,
"controls_partial": partial,
"evidence_total": evidence_stats.get("total", 0),
"evidence_valid": evidence_stats.get("by_status", {}).get("valid", 0),
"risks_total": len(risks),
"risks_high": risks_high,
"snapshot_date": today,
}).fetchone()
db.commit()
return _row_to_dict(row)
@router.get("/dashboard/score-history")
async def get_score_history(
months: int = Query(12, ge=1, le=36),
db: Session = Depends(get_db),
tenant_id: str = Depends(_get_tenant_id),
):
"""Get compliance score history from snapshots."""
since = date.today() - timedelta(days=months * 30)
rows = db.execute(text("""
SELECT * FROM compliance_score_snapshots
WHERE tenant_id = :tenant_id AND snapshot_date >= :since
ORDER BY snapshot_date ASC
"""), {"tenant_id": tenant_id, "since": since}).fetchall()
snapshots = []
for r in rows:
d = _row_to_dict(r)
# Convert Decimal to float for JSON
if isinstance(d.get("score"), Decimal):
d["score"] = float(d["score"])
snapshots.append(d)
return {
"snapshots": snapshots,
"total": len(snapshots),
"period_months": months,
}
# ============================================================================
# Evidence Distribution (Anti-Fake-Evidence Phase 3)
# ============================================================================
@router.get("/dashboard/evidence-distribution")
async def get_evidence_distribution(
db: Session = Depends(get_db),
tenant_id: str = Depends(_get_tenant_id),
):
"""Evidence counts by confidence level and four-eyes status."""
evidence_repo = EvidenceRepository(db)
all_evidence = evidence_repo.get_all()
by_confidence = {"E0": 0, "E1": 0, "E2": 0, "E3": 0, "E4": 0}
four_eyes_pending = 0
for e in all_evidence:
level = e.confidence_level.value if e.confidence_level else "E1"
if level in by_confidence:
by_confidence[level] += 1
if e.requires_four_eyes and e.approval_status not in ("approved", "rejected"):
four_eyes_pending += 1
return {
"by_confidence": by_confidence,
"four_eyes_pending": four_eyes_pending,
"total": len(all_evidence),
}
# ============================================================================
# Traceability Matrix (Anti-Fake-Evidence Phase 4a)
# ============================================================================
@router.get("/dashboard/traceability-matrix", response_model=TraceabilityMatrixResponse)
async def get_traceability_matrix(
db: Session = Depends(get_db),
tenant_id: str = Depends(_get_tenant_id),
):
"""
Full traceability chain: Control → Evidence → Assertions.
Loads each entity set once, builds in-memory indices, and nests
the result so the frontend can render a matrix view.
"""
ctrl_repo = ControlRepository(db)
evidence_repo = EvidenceRepository(db)
# 1. Load all three entity sets
controls = ctrl_repo.get_all()
all_evidence = evidence_repo.get_all()
all_assertions = db.query(AssertionDB).filter(
AssertionDB.entity_type == "evidence",
).all()
# 2. Index assertions by evidence_id (entity_id)
assertions_by_evidence: Dict[str, list] = {}
for a in all_assertions:
assertions_by_evidence.setdefault(a.entity_id, []).append(a)
# 3. Index evidence by control_id
evidence_by_control: Dict[str, list] = {}
for e in all_evidence:
evidence_by_control.setdefault(str(e.control_id), []).append(e)
# 4. Build nested response
result_controls: list = []
total_controls = 0
covered_controls = 0
fully_verified = 0
for ctrl in controls:
total_controls += 1
ctrl_id = str(ctrl.id)
ctrl_evidence = evidence_by_control.get(ctrl_id, [])
nested_evidence: list = []
has_evidence = len(ctrl_evidence) > 0
has_assertions = False
all_verified = True
min_conf: Optional[str] = None
conf_order = {"E0": 0, "E1": 1, "E2": 2, "E3": 3, "E4": 4}
for e in ctrl_evidence:
ev_id = str(e.id)
ev_assertions = assertions_by_evidence.get(ev_id, [])
nested_assertions = [
TraceabilityAssertion(
id=str(a.id),
sentence_text=a.sentence_text,
assertion_type=a.assertion_type or "assertion",
confidence=a.confidence or 0.0,
verified=a.verified_by is not None,
)
for a in ev_assertions
]
if nested_assertions:
has_assertions = True
for na in nested_assertions:
if not na.verified:
all_verified = False
conf = e.confidence_level.value if e.confidence_level else "E1"
if min_conf is None or conf_order.get(conf, 1) < conf_order.get(min_conf, 1):
min_conf = conf
nested_evidence.append(TraceabilityEvidence(
id=ev_id,
title=e.title,
evidence_type=e.evidence_type,
confidence_level=conf,
status=e.status.value if e.status else "valid",
assertions=nested_assertions,
))
if not has_assertions:
all_verified = False
if has_evidence:
covered_controls += 1
if has_evidence and has_assertions and all_verified:
fully_verified += 1
coverage = TraceabilityCoverage(
has_evidence=has_evidence,
has_assertions=has_assertions,
all_assertions_verified=all_verified,
min_confidence_level=min_conf,
)
result_controls.append(TraceabilityControl(
id=ctrl_id,
control_id=ctrl.control_id,
title=ctrl.title,
status=ctrl.status.value if ctrl.status else "planned",
domain=ctrl.domain.value if ctrl.domain else "unknown",
evidence=nested_evidence,
coverage=coverage,
))
summary = {
"total_controls": total_controls,
"covered_controls": covered_controls,
"fully_verified": fully_verified,
"uncovered_controls": total_controls - covered_controls,
}
return TraceabilityMatrixResponse(controls=result_controls, summary=summary)
# ============================================================================
# Reports
# ============================================================================
@@ -60,10 +60,314 @@ def get_dsfa_service(db: Session = Depends(get_db)) -> DSFAService:
return DSFAService(db)
def get_workflow_service(
db: Session = Depends(get_db),
) -> DSFAWorkflowService:
return DSFAWorkflowService(db)
# =============================================================================
# Pydantic Schemas
# =============================================================================
class DSFACreate(BaseModel):
title: str
description: str = ""
status: str = "draft"
risk_level: str = "low"
processing_activity: str = ""
data_categories: List[str] = []
recipients: List[str] = []
measures: List[str] = []
created_by: str = "system"
# Section 1
processing_description: Optional[str] = None
processing_purpose: Optional[str] = None
legal_basis: Optional[str] = None
legal_basis_details: Optional[str] = None
# Section 2
necessity_assessment: Optional[str] = None
proportionality_assessment: Optional[str] = None
data_minimization: Optional[str] = None
alternatives_considered: Optional[str] = None
retention_justification: Optional[str] = None
# Section 3
involves_ai: Optional[bool] = None
overall_risk_level: Optional[str] = None
risk_score: Optional[int] = None
# Section 6
dpo_consulted: Optional[bool] = None
dpo_name: Optional[str] = None
dpo_opinion: Optional[str] = None
dpo_approved: Optional[bool] = None
authority_consulted: Optional[bool] = None
authority_reference: Optional[str] = None
authority_decision: Optional[str] = None
# Metadata
version: Optional[int] = None
conclusion: Optional[str] = None
federal_state: Optional[str] = None
authority_resource_id: Optional[str] = None
submitted_by: Optional[str] = None
# JSONB Arrays
data_subjects: Optional[List[str]] = None
affected_rights: Optional[List[str]] = None
triggered_rule_codes: Optional[List[str]] = None
ai_trigger_ids: Optional[List[str]] = None
wp248_criteria_met: Optional[List[str]] = None
art35_abs3_triggered: Optional[List[str]] = None
tom_references: Optional[List[str]] = None
risks: Optional[List[dict]] = None
mitigations: Optional[List[dict]] = None
stakeholder_consultations: Optional[List[dict]] = None
review_triggers: Optional[List[dict]] = None
review_comments: Optional[List[dict]] = None
ai_use_case_modules: Optional[List[dict]] = None
section_8_complete: Optional[bool] = None
# JSONB Objects
threshold_analysis: Optional[dict] = None
consultation_requirement: Optional[dict] = None
review_schedule: Optional[dict] = None
section_progress: Optional[dict] = None
metadata: Optional[dict] = None
class DSFAUpdate(BaseModel):
title: Optional[str] = None
description: Optional[str] = None
status: Optional[str] = None
risk_level: Optional[str] = None
processing_activity: Optional[str] = None
data_categories: Optional[List[str]] = None
recipients: Optional[List[str]] = None
measures: Optional[List[str]] = None
approved_by: Optional[str] = None
# Section 1
processing_description: Optional[str] = None
processing_purpose: Optional[str] = None
legal_basis: Optional[str] = None
legal_basis_details: Optional[str] = None
# Section 2
necessity_assessment: Optional[str] = None
proportionality_assessment: Optional[str] = None
data_minimization: Optional[str] = None
alternatives_considered: Optional[str] = None
retention_justification: Optional[str] = None
# Section 3
involves_ai: Optional[bool] = None
overall_risk_level: Optional[str] = None
risk_score: Optional[int] = None
# Section 6
dpo_consulted: Optional[bool] = None
dpo_name: Optional[str] = None
dpo_opinion: Optional[str] = None
dpo_approved: Optional[bool] = None
authority_consulted: Optional[bool] = None
authority_reference: Optional[str] = None
authority_decision: Optional[str] = None
# Metadata
version: Optional[int] = None
conclusion: Optional[str] = None
federal_state: Optional[str] = None
authority_resource_id: Optional[str] = None
submitted_by: Optional[str] = None
# JSONB Arrays
data_subjects: Optional[List[str]] = None
affected_rights: Optional[List[str]] = None
triggered_rule_codes: Optional[List[str]] = None
ai_trigger_ids: Optional[List[str]] = None
wp248_criteria_met: Optional[List[str]] = None
art35_abs3_triggered: Optional[List[str]] = None
tom_references: Optional[List[str]] = None
risks: Optional[List[dict]] = None
mitigations: Optional[List[dict]] = None
stakeholder_consultations: Optional[List[dict]] = None
review_triggers: Optional[List[dict]] = None
review_comments: Optional[List[dict]] = None
ai_use_case_modules: Optional[List[dict]] = None
section_8_complete: Optional[bool] = None
# JSONB Objects
threshold_analysis: Optional[dict] = None
consultation_requirement: Optional[dict] = None
review_schedule: Optional[dict] = None
section_progress: Optional[dict] = None
metadata: Optional[dict] = None
class DSFAStatusUpdate(BaseModel):
status: str
approved_by: Optional[str] = None
class DSFASectionUpdate(BaseModel):
"""Body for PUT /dsfa/{id}/sections/{section_number}."""
content: Optional[str] = None
# Allow arbitrary extra fields so the frontend can send any section-specific data
extra: Optional[dict] = None
class DSFAApproveRequest(BaseModel):
"""Body for POST /dsfa/{id}/approve."""
approved: bool
comments: Optional[str] = None
approved_by: Optional[str] = None
# =============================================================================
# Helpers
# =============================================================================
def _get_tenant_id(tenant_id: Optional[str]) -> str:
return tenant_id or DEFAULT_TENANT_ID
def _dsfa_to_response(row) -> dict:
"""Convert a DB row to a JSON-serializable dict."""
import json
# SQLAlchemy 2.0: Row objects need ._mapping for string-key access
if hasattr(row, "_mapping"):
row = row._mapping
def _parse_arr(val):
"""Parse a JSONB array field → list."""
if val is None:
return []
if isinstance(val, list):
return val
if isinstance(val, str):
try:
parsed = json.loads(val)
return parsed if isinstance(parsed, list) else []
except Exception:
return []
return val
def _parse_obj(val):
"""Parse a JSONB object field → dict."""
if val is None:
return {}
if isinstance(val, dict):
return val
if isinstance(val, str):
try:
parsed = json.loads(val)
return parsed if isinstance(parsed, dict) else {}
except Exception:
return {}
return val
def _ts(val):
"""Timestamp → ISO string or None."""
if not val:
return None
if isinstance(val, str):
return val
return val.isoformat()
def _get(key, default=None):
"""Safe row access — returns default if key missing (handles old rows)."""
try:
v = row[key]
return default if v is None and default is not None else v
except (KeyError, IndexError):
return default
return {
# Core fields (always present since Migration 024)
"id": str(row["id"]),
"tenant_id": row["tenant_id"],
"title": row["title"],
"description": row["description"] or "",
"status": row["status"] or "draft",
"risk_level": row["risk_level"] or "low",
"processing_activity": row["processing_activity"] or "",
"data_categories": _parse_arr(row["data_categories"]),
"recipients": _parse_arr(row["recipients"]),
"measures": _parse_arr(row["measures"]),
"approved_by": row["approved_by"],
"approved_at": _ts(row["approved_at"]),
"created_by": row["created_by"] or "system",
"created_at": _ts(row["created_at"]),
"updated_at": _ts(row["updated_at"]),
# Section 1 (Migration 030)
"processing_description": _get("processing_description"),
"processing_purpose": _get("processing_purpose"),
"legal_basis": _get("legal_basis"),
"legal_basis_details": _get("legal_basis_details"),
# Section 2
"necessity_assessment": _get("necessity_assessment"),
"proportionality_assessment": _get("proportionality_assessment"),
"data_minimization": _get("data_minimization"),
"alternatives_considered": _get("alternatives_considered"),
"retention_justification": _get("retention_justification"),
# Section 3
"involves_ai": _get("involves_ai", False),
"overall_risk_level": _get("overall_risk_level"),
"risk_score": _get("risk_score", 0),
# Section 6
"dpo_consulted": _get("dpo_consulted", False),
"dpo_consulted_at": _ts(_get("dpo_consulted_at")),
"dpo_name": _get("dpo_name"),
"dpo_opinion": _get("dpo_opinion"),
"dpo_approved": _get("dpo_approved"),
"authority_consulted": _get("authority_consulted", False),
"authority_consulted_at": _ts(_get("authority_consulted_at")),
"authority_reference": _get("authority_reference"),
"authority_decision": _get("authority_decision"),
# Metadata / Versioning
"version": _get("version", 1),
"previous_version_id": str(_get("previous_version_id")) if _get("previous_version_id") else None,
"conclusion": _get("conclusion"),
"federal_state": _get("federal_state"),
"authority_resource_id": _get("authority_resource_id"),
"submitted_for_review_at": _ts(_get("submitted_for_review_at")),
"submitted_by": _get("submitted_by"),
# JSONB Arrays
"data_subjects": _parse_arr(_get("data_subjects")),
"affected_rights": _parse_arr(_get("affected_rights")),
"triggered_rule_codes": _parse_arr(_get("triggered_rule_codes")),
"ai_trigger_ids": _parse_arr(_get("ai_trigger_ids")),
"wp248_criteria_met": _parse_arr(_get("wp248_criteria_met")),
"art35_abs3_triggered": _parse_arr(_get("art35_abs3_triggered")),
"tom_references": _parse_arr(_get("tom_references")),
"risks": _parse_arr(_get("risks")),
"mitigations": _parse_arr(_get("mitigations")),
"stakeholder_consultations": _parse_arr(_get("stakeholder_consultations")),
"review_triggers": _parse_arr(_get("review_triggers")),
"review_comments": _parse_arr(_get("review_comments")),
# Section 8 / AI (Migration 028)
"ai_use_case_modules": _parse_arr(_get("ai_use_case_modules")),
"section_8_complete": _get("section_8_complete", False),
# JSONB Objects
"threshold_analysis": _parse_obj(_get("threshold_analysis")),
"consultation_requirement": _parse_obj(_get("consultation_requirement")),
"review_schedule": _parse_obj(_get("review_schedule")),
"section_progress": _parse_obj(_get("section_progress")),
"metadata": _parse_obj(_get("metadata")),
}
def _log_audit(
db: Session,
tenant_id: str,
dsfa_id,
action: str,
changed_by: str = "system",
old_values=None,
new_values=None,
):
import json
db.execute(
text("""
INSERT INTO compliance_dsfa_audit_log
(tenant_id, dsfa_id, action, changed_by, old_values, new_values)
VALUES
(:tenant_id, :dsfa_id, :action, :changed_by,
CAST(:old_values AS jsonb), CAST(:new_values AS jsonb))
"""),
{
"tenant_id": tenant_id,
"dsfa_id": str(dsfa_id) if dsfa_id else None,
"action": action,
"changed_by": changed_by,
"old_values": json.dumps(old_values) if old_values else None,
"new_values": json.dumps(new_values) if new_values else None,
},
)
# =============================================================================
@@ -177,8 +481,51 @@ async def create_dsfa(
service: DSFAService = Depends(get_dsfa_service),
) -> dict[str, Any]:
"""Neue DSFA erstellen."""
with translate_domain_errors():
return service.create(tenant_id, request)
import json
if request.status not in VALID_STATUSES:
raise HTTPException(status_code=422, detail=f"Ungültiger Status: {request.status}")
if request.risk_level not in VALID_RISK_LEVELS:
raise HTTPException(status_code=422, detail=f"Ungültiges Risiko-Level: {request.risk_level}")
tid = _get_tenant_id(tenant_id)
row = db.execute(
text("""
INSERT INTO compliance_dsfas
(tenant_id, title, description, status, risk_level,
processing_activity, data_categories, recipients, measures, created_by)
VALUES
(:tenant_id, :title, :description, :status, :risk_level,
:processing_activity,
CAST(:data_categories AS jsonb),
CAST(:recipients AS jsonb),
CAST(:measures AS jsonb),
:created_by)
RETURNING *
"""),
{
"tenant_id": tid,
"title": request.title,
"description": request.description,
"status": request.status,
"risk_level": request.risk_level,
"processing_activity": request.processing_activity,
"data_categories": json.dumps(request.data_categories),
"recipients": json.dumps(request.recipients),
"measures": json.dumps(request.measures),
"created_by": request.created_by,
},
).fetchone()
db.flush()
row_id = row._mapping["id"] if hasattr(row, "_mapping") else row[0]
_log_audit(
db, tid, row_id, "CREATE", request.created_by,
new_values={"title": request.title, "status": request.status},
)
db.commit()
return _dsfa_to_response(row)
# =============================================================================
File diff suppressed because it is too large Load Diff
@@ -22,23 +22,21 @@ from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile
from sqlalchemy.orm import Session
from classroom_engine.database import get_db
from compliance.api._http_errors import translate_domain_errors
from compliance.db import ControlRepository, EvidenceRepository
from compliance.schemas.evidence import (
EvidenceCreate,
EvidenceListResponse,
EvidenceResponse,
from ..db import (
ControlRepository,
EvidenceRepository,
EvidenceStatusEnum,
EvidenceConfidenceEnum,
EvidenceTruthStatusEnum,
)
from compliance.services.auto_risk_updater import AutoRiskUpdater
from compliance.domain import NotFoundError, ValidationError
from compliance.services.evidence_service import (
SOURCE_CONTROL_MAP,
EvidenceService,
_extract_findings_detail, # re-exported for legacy test imports
_parse_ci_evidence, # re-exported for legacy test imports
_store_evidence, # re-exported for legacy test imports
_update_risks as _update_risks_impl,
from ..db.models import EvidenceDB, ControlDB, AuditTrailDB
from ..services.auto_risk_updater import AutoRiskUpdater
from .schemas import (
EvidenceCreate, EvidenceResponse, EvidenceListResponse,
EvidenceRejectRequest,
)
from .audit_trail_utils import log_audit_trail
logger = logging.getLogger(__name__)
router = APIRouter(tags=["compliance-evidence"])
@@ -56,7 +54,88 @@ def get_evidence_service(db: Session = Depends(get_db)) -> EvidenceService:
# ============================================================================
# Evidence CRUD
# Anti-Fake-Evidence: Four-Eyes Domain Check
# ============================================================================
FOUR_EYES_DOMAINS = {"gov", "priv"}
def _requires_four_eyes(control_domain: str) -> bool:
"""Controls in governance/privacy domains require two independent reviewers."""
return control_domain in FOUR_EYES_DOMAINS
# ============================================================================
# Anti-Fake-Evidence: Auto-Classification Helpers
# ============================================================================
def _classify_confidence(source: Optional[str], evidence_type: Optional[str] = None, artifact_hash: Optional[str] = None) -> EvidenceConfidenceEnum:
"""Classify evidence confidence level based on source and metadata."""
if source == "ci_pipeline":
return EvidenceConfidenceEnum.E3
if source == "api" and artifact_hash:
return EvidenceConfidenceEnum.E3
if source == "api":
return EvidenceConfidenceEnum.E3
if source in ("manual", "upload"):
return EvidenceConfidenceEnum.E1
if source == "generated":
return EvidenceConfidenceEnum.E0
# Default for unknown sources
return EvidenceConfidenceEnum.E1
def _classify_truth_status(source: Optional[str]) -> EvidenceTruthStatusEnum:
"""Classify evidence truth status based on source."""
if source == "ci_pipeline":
return EvidenceTruthStatusEnum.OBSERVED
if source in ("manual", "upload"):
return EvidenceTruthStatusEnum.UPLOADED
if source == "generated":
return EvidenceTruthStatusEnum.GENERATED
if source == "api":
return EvidenceTruthStatusEnum.OBSERVED
return EvidenceTruthStatusEnum.UPLOADED
def _build_evidence_response(e: EvidenceDB) -> EvidenceResponse:
"""Build an EvidenceResponse from an EvidenceDB, including anti-fake fields."""
return EvidenceResponse(
id=e.id,
control_id=e.control_id,
evidence_type=e.evidence_type,
title=e.title,
description=e.description,
artifact_path=e.artifact_path,
artifact_url=e.artifact_url,
artifact_hash=e.artifact_hash,
file_size_bytes=e.file_size_bytes,
mime_type=e.mime_type,
valid_from=e.valid_from,
valid_until=e.valid_until,
status=e.status.value if e.status else None,
source=e.source,
ci_job_id=e.ci_job_id,
uploaded_by=e.uploaded_by,
collected_at=e.collected_at,
created_at=e.created_at,
confidence_level=e.confidence_level.value if e.confidence_level else None,
truth_status=e.truth_status.value if e.truth_status else None,
generation_mode=e.generation_mode,
may_be_used_as_evidence=e.may_be_used_as_evidence,
reviewed_by=e.reviewed_by,
reviewed_at=e.reviewed_at,
approval_status=e.approval_status,
first_reviewer=e.first_reviewer,
first_reviewed_at=e.first_reviewed_at,
second_reviewer=e.second_reviewer,
second_reviewed_at=e.second_reviewed_at,
requires_four_eyes=e.requires_four_eyes,
)
# ============================================================================
# Evidence
# ============================================================================
@router.get("/evidence", response_model=EvidenceListResponse)
@@ -69,8 +148,38 @@ async def list_evidence(
service: EvidenceService = Depends(get_evidence_service),
) -> EvidenceListResponse:
"""List evidence with optional filters and pagination."""
with translate_domain_errors():
return service.list_evidence(control_id, evidence_type, status, page, limit)
repo = EvidenceRepository(db)
if control_id:
# First get the control UUID
ctrl_repo = ControlRepository(db)
control = ctrl_repo.get_by_control_id(control_id)
if not control:
raise HTTPException(status_code=404, detail=f"Control {control_id} not found")
evidence = repo.get_by_control(control.id)
else:
evidence = repo.get_all()
if evidence_type:
evidence = [e for e in evidence if e.evidence_type == evidence_type]
if status:
try:
status_enum = EvidenceStatusEnum(status)
evidence = [e for e in evidence if e.status == status_enum]
except ValueError:
pass
total = len(evidence)
# Apply pagination if requested
if page is not None and limit is not None:
offset = (page - 1) * limit
evidence = evidence[offset:offset + limit]
results = [_build_evidence_response(e) for e in evidence]
return EvidenceListResponse(evidence=results, total=total)
@router.post("/evidence", response_model=EvidenceResponse)
@@ -79,8 +188,66 @@ async def create_evidence(
service: EvidenceService = Depends(get_evidence_service),
) -> EvidenceResponse:
"""Create new evidence record."""
with translate_domain_errors():
return service.create_evidence(evidence_data)
repo = EvidenceRepository(db)
# Get control UUID
ctrl_repo = ControlRepository(db)
control = ctrl_repo.get_by_control_id(evidence_data.control_id)
if not control:
raise HTTPException(status_code=404, detail=f"Control {evidence_data.control_id} not found")
source = evidence_data.source or "api"
confidence = _classify_confidence(source, evidence_data.evidence_type)
truth = _classify_truth_status(source)
# Allow explicit override from request
if evidence_data.confidence_level:
try:
confidence = EvidenceConfidenceEnum(evidence_data.confidence_level)
except ValueError:
pass
if evidence_data.truth_status:
try:
truth = EvidenceTruthStatusEnum(evidence_data.truth_status)
except ValueError:
pass
evidence = repo.create(
control_id=control.id,
evidence_type=evidence_data.evidence_type,
title=evidence_data.title,
description=evidence_data.description,
artifact_url=evidence_data.artifact_url,
valid_from=evidence_data.valid_from,
valid_until=evidence_data.valid_until,
source=source,
ci_job_id=evidence_data.ci_job_id,
)
# Set anti-fake-evidence fields
evidence.confidence_level = confidence
evidence.truth_status = truth
# Generated evidence should not be used as evidence by default
if truth == EvidenceTruthStatusEnum.GENERATED:
evidence.may_be_used_as_evidence = False
# Four-Eyes: check if the linked control's domain requires it
control_domain = control.domain.value if control.domain else ""
if _requires_four_eyes(control_domain):
evidence.requires_four_eyes = True
evidence.approval_status = "pending_first"
db.commit()
# Audit trail
log_audit_trail(
db, "evidence", evidence.id, evidence.title, "create",
performed_by=evidence_data.source or "api",
change_summary=f"Evidence created with confidence={confidence.value}, truth={truth.value}",
)
db.commit()
return _build_evidence_response(evidence)
@router.delete("/evidence/{evidence_id}")
@@ -107,9 +274,271 @@ async def upload_evidence(
service: EvidenceService = Depends(get_evidence_service),
) -> EvidenceResponse:
"""Upload evidence file."""
with translate_domain_errors():
return await service.upload_evidence(
control_id, evidence_type, title, file, description
# Get control UUID
ctrl_repo = ControlRepository(db)
control = ctrl_repo.get_by_control_id(control_id)
if not control:
raise HTTPException(status_code=404, detail=f"Control {control_id} not found")
# Create upload directory
upload_dir = f"/tmp/compliance_evidence/{control_id}"
os.makedirs(upload_dir, exist_ok=True)
# Save file
file_path = os.path.join(upload_dir, file.filename)
content = await file.read()
with open(file_path, "wb") as f:
f.write(content)
# Calculate hash
file_hash = hashlib.sha256(content).hexdigest()
# Create evidence record
repo = EvidenceRepository(db)
evidence = repo.create(
control_id=control.id,
evidence_type=evidence_type,
title=title,
description=description,
artifact_path=file_path,
artifact_hash=file_hash,
file_size_bytes=len(content),
mime_type=file.content_type,
source="upload",
)
# Upload evidence → E1 + uploaded
evidence.confidence_level = EvidenceConfidenceEnum.E1
evidence.truth_status = EvidenceTruthStatusEnum.UPLOADED
# Four-Eyes: check if the linked control's domain requires it
control_domain = control.domain.value if control.domain else ""
if _requires_four_eyes(control_domain):
evidence.requires_four_eyes = True
evidence.approval_status = "pending_first"
db.commit()
return _build_evidence_response(evidence)
# ============================================================================
# CI/CD Evidence Collection — helpers
# ============================================================================
# Map CI source names to the corresponding control IDs
SOURCE_CONTROL_MAP = {
"sast": "SDLC-001",
"dependency_scan": "SDLC-002",
"secret_scan": "SDLC-003",
"code_review": "SDLC-004",
"sbom": "SDLC-005",
"container_scan": "SDLC-006",
"test_results": "AUD-001",
}
def _parse_ci_evidence(data: dict) -> dict:
"""
Parse and validate incoming CI evidence data.
Returns a dict with:
- report_json: str (serialised JSON)
- report_hash: str (SHA-256 hex digest)
- evidence_status: str ("valid" or "failed")
- findings_count: int
- critical_findings: int
"""
report_json = json.dumps(data) if data else "{}"
report_hash = hashlib.sha256(report_json.encode()).hexdigest()
findings_count = 0
critical_findings = 0
if data and isinstance(data, dict):
# Semgrep format
if "results" in data:
findings_count = len(data.get("results", []))
critical_findings = len([
r for r in data.get("results", [])
if r.get("extra", {}).get("severity", "").upper() in ["CRITICAL", "HIGH"]
])
# Trivy format
elif "Results" in data:
for result in data.get("Results", []):
vulns = result.get("Vulnerabilities", [])
findings_count += len(vulns)
critical_findings += len([
v for v in vulns
if v.get("Severity", "").upper() in ["CRITICAL", "HIGH"]
])
# Generic findings array
elif "findings" in data:
findings_count = len(data.get("findings", []))
# SBOM format - just count components
elif "components" in data:
findings_count = len(data.get("components", []))
evidence_status = "failed" if critical_findings > 0 else "valid"
return {
"report_json": report_json,
"report_hash": report_hash,
"evidence_status": evidence_status,
"findings_count": findings_count,
"critical_findings": critical_findings,
}
def _store_evidence(
db: Session,
*,
control_db_id: str,
source: str,
parsed: dict,
ci_job_id: str,
ci_job_url: str,
report_data: dict,
) -> EvidenceDB:
"""
Persist a CI evidence item to the database and write the report file.
Returns the created EvidenceDB instance (already committed).
"""
findings_count = parsed["findings_count"]
critical_findings = parsed["critical_findings"]
# Build title and description
title = f"{source.upper()} Report - {datetime.now().strftime('%Y-%m-%d %H:%M')}"
description = "Automatically collected from CI/CD pipeline"
if findings_count > 0:
description += f"\n- Total findings: {findings_count}"
if critical_findings > 0:
description += f"\n- Critical/High findings: {critical_findings}"
if ci_job_id:
description += f"\n- CI Job ID: {ci_job_id}"
if ci_job_url:
description += f"\n- CI Job URL: {ci_job_url}"
# Store report file
upload_dir = f"/tmp/compliance_evidence/ci/{source}"
os.makedirs(upload_dir, exist_ok=True)
file_name = f"{source}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{parsed['report_hash'][:8]}.json"
file_path = os.path.join(upload_dir, file_name)
with open(file_path, "w") as f:
json.dump(report_data or {}, f, indent=2)
# Create evidence record with anti-fake-evidence classification
evidence = EvidenceDB(
id=str(uuid_module.uuid4()),
control_id=control_db_id,
evidence_type=f"ci_{source}",
title=title,
description=description,
artifact_path=file_path,
artifact_hash=parsed["report_hash"],
file_size_bytes=len(parsed["report_json"]),
mime_type="application/json",
source="ci_pipeline",
ci_job_id=ci_job_id,
valid_from=datetime.utcnow(),
valid_until=datetime.utcnow() + timedelta(days=90),
status=EvidenceStatusEnum(parsed["evidence_status"]),
# CI pipeline evidence → E3 observed (system-observed, hash-verified)
confidence_level=EvidenceConfidenceEnum.E3,
truth_status=EvidenceTruthStatusEnum.OBSERVED,
may_be_used_as_evidence=True,
)
db.add(evidence)
db.commit()
db.refresh(evidence)
return evidence
def _extract_findings_detail(report_data: dict) -> dict:
"""
Extract severity-bucketed finding counts from report data.
Returns dict with keys: critical, high, medium, low.
"""
findings_detail = {
"critical": 0,
"high": 0,
"medium": 0,
"low": 0,
}
if not report_data:
return findings_detail
# Semgrep format
if "results" in report_data:
for r in report_data.get("results", []):
severity = r.get("extra", {}).get("severity", "").upper()
if severity == "CRITICAL":
findings_detail["critical"] += 1
elif severity == "HIGH":
findings_detail["high"] += 1
elif severity == "MEDIUM":
findings_detail["medium"] += 1
elif severity in ["LOW", "INFO"]:
findings_detail["low"] += 1
# Trivy format
elif "Results" in report_data:
for result in report_data.get("Results", []):
for v in result.get("Vulnerabilities", []):
severity = v.get("Severity", "").upper()
if severity == "CRITICAL":
findings_detail["critical"] += 1
elif severity == "HIGH":
findings_detail["high"] += 1
elif severity == "MEDIUM":
findings_detail["medium"] += 1
elif severity == "LOW":
findings_detail["low"] += 1
# Generic findings with severity
elif "findings" in report_data:
for f in report_data.get("findings", []):
severity = f.get("severity", "").upper()
if severity == "CRITICAL":
findings_detail["critical"] += 1
elif severity == "HIGH":
findings_detail["high"] += 1
elif severity == "MEDIUM":
findings_detail["medium"] += 1
else:
findings_detail["low"] += 1
return findings_detail
def _update_risks(db: Session, *, source: str, control_id: str, ci_job_id: str, report_data: dict):
"""
Update risk status based on new evidence.
Uses AutoRiskUpdater to update Control status and linked Risks based on
severity-bucketed findings. Returns the update result or None on error.
"""
findings_detail = _extract_findings_detail(report_data)
try:
auto_updater = AutoRiskUpdater(db)
risk_update_result = auto_updater.process_evidence_collect_request(
tool=source,
control_id=control_id,
evidence_type=f"ci_{source}",
timestamp=datetime.utcnow().isoformat(),
commit_sha=report_data.get("commit_sha", "unknown") if report_data else "unknown",
ci_job_id=ci_job_id,
findings=findings_detail,
)
@@ -227,14 +656,229 @@ async def get_ci_evidence_status(
# Legacy re-exports for tests that import helpers directly.
# ----------------------------------------------------------------------------
__all__ = [
"router",
"SOURCE_CONTROL_MAP",
"EvidenceRepository",
"ControlRepository",
"AutoRiskUpdater",
"_parse_ci_evidence",
"_extract_findings_detail",
"_store_evidence",
"_update_risks",
]
if control_id:
ctrl_repo = ControlRepository(db)
control = ctrl_repo.get_by_control_id(control_id)
if control:
query = query.filter(EvidenceDB.control_id == control.id)
evidence_list = query.order_by(EvidenceDB.collected_at.desc()).limit(100).all()
# Group by control and calculate stats
control_stats = defaultdict(lambda: {
"total": 0,
"valid": 0,
"failed": 0,
"last_collected": None,
"evidence": [],
})
for e in evidence_list:
# Get control_id string
control = db.query(ControlDB).filter(ControlDB.id == e.control_id).first()
ctrl_id = control.control_id if control else "unknown"
stats = control_stats[ctrl_id]
stats["total"] += 1
if e.status:
if e.status.value == "valid":
stats["valid"] += 1
elif e.status.value == "failed":
stats["failed"] += 1
if not stats["last_collected"] or e.collected_at > stats["last_collected"]:
stats["last_collected"] = e.collected_at
# Add evidence summary
stats["evidence"].append({
"id": e.id,
"type": e.evidence_type,
"status": e.status.value if e.status else None,
"collected_at": e.collected_at.isoformat() if e.collected_at else None,
"ci_job_id": e.ci_job_id,
})
# Convert to list and sort
result = []
for ctrl_id, stats in control_stats.items():
result.append({
"control_id": ctrl_id,
"total_evidence": stats["total"],
"valid_count": stats["valid"],
"failed_count": stats["failed"],
"last_collected": stats["last_collected"].isoformat() if stats["last_collected"] else None,
"recent_evidence": stats["evidence"][:5],
})
result.sort(key=lambda x: x["last_collected"] or "", reverse=True)
return {
"period_days": days,
"total_evidence": len(evidence_list),
"controls": result,
}
# ============================================================================
# Evidence Review (Anti-Fake-Evidence)
# ============================================================================
from pydantic import BaseModel as _BaseModel
class _EvidenceReviewRequest(_BaseModel):
confidence_level: Optional[str] = None
truth_status: Optional[str] = None
reviewed_by: str
@router.patch("/evidence/{evidence_id}/review", response_model=EvidenceResponse)
async def review_evidence(
evidence_id: str,
review: _EvidenceReviewRequest,
db: Session = Depends(get_db),
):
"""
Review evidence: upgrade confidence level and/or change truth status.
For Four-Eyes evidence, the first reviewer sets first_reviewer and
approval_status='first_approved'. A second (different) reviewer then
sets second_reviewer and approval_status='approved'.
"""
evidence = db.query(EvidenceDB).filter(EvidenceDB.id == evidence_id).first()
if not evidence:
raise HTTPException(status_code=404, detail=f"Evidence {evidence_id} not found")
old_confidence = evidence.confidence_level.value if evidence.confidence_level else None
old_truth = evidence.truth_status.value if evidence.truth_status else None
if review.confidence_level:
try:
evidence.confidence_level = EvidenceConfidenceEnum(review.confidence_level)
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid confidence_level: {review.confidence_level}")
if review.truth_status:
try:
evidence.truth_status = EvidenceTruthStatusEnum(review.truth_status)
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid truth_status: {review.truth_status}")
# Four-Eyes branching
if evidence.requires_four_eyes:
status = evidence.approval_status or "none"
if status in ("none", "pending_first"):
evidence.first_reviewer = review.reviewed_by
evidence.first_reviewed_at = datetime.utcnow()
evidence.approval_status = "first_approved"
elif status == "first_approved":
if review.reviewed_by == evidence.first_reviewer:
raise HTTPException(
status_code=400,
detail="Four-Eyes: second reviewer must be different from first reviewer",
)
evidence.second_reviewer = review.reviewed_by
evidence.second_reviewed_at = datetime.utcnow()
evidence.approval_status = "approved"
elif status == "approved":
raise HTTPException(status_code=400, detail="Evidence already approved")
elif status == "rejected":
raise HTTPException(status_code=400, detail="Evidence was rejected — create new evidence instead")
evidence.reviewed_by = review.reviewed_by
evidence.reviewed_at = datetime.utcnow()
db.commit()
# Audit trail
new_confidence = evidence.confidence_level.value if evidence.confidence_level else None
if old_confidence != new_confidence:
log_audit_trail(
db, "evidence", evidence_id, evidence.title, "review",
performed_by=review.reviewed_by,
field_changed="confidence_level",
old_value=old_confidence,
new_value=new_confidence,
)
new_truth = evidence.truth_status.value if evidence.truth_status else None
if old_truth != new_truth:
log_audit_trail(
db, "evidence", evidence_id, evidence.title, "review",
performed_by=review.reviewed_by,
field_changed="truth_status",
old_value=old_truth,
new_value=new_truth,
)
db.commit()
db.refresh(evidence)
return _build_evidence_response(evidence)
@router.patch("/evidence/{evidence_id}/reject", response_model=EvidenceResponse)
async def reject_evidence(
evidence_id: str,
body: EvidenceRejectRequest,
db: Session = Depends(get_db),
):
"""Reject evidence (sets approval_status='rejected')."""
evidence = db.query(EvidenceDB).filter(EvidenceDB.id == evidence_id).first()
if not evidence:
raise HTTPException(status_code=404, detail=f"Evidence {evidence_id} not found")
evidence.approval_status = "rejected"
evidence.reviewed_by = body.reviewed_by
evidence.reviewed_at = datetime.utcnow()
db.commit()
log_audit_trail(
db, "evidence", evidence_id, evidence.title, "reject",
performed_by=body.reviewed_by,
change_summary=body.rejection_reason or "Evidence rejected",
)
db.commit()
db.refresh(evidence)
return _build_evidence_response(evidence)
# ============================================================================
# Audit Trail Query
# ============================================================================
@router.get("/audit-trail")
async def get_audit_trail(
entity_type: Optional[str] = Query(None),
entity_id: Optional[str] = Query(None),
action: Optional[str] = Query(None),
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
):
"""Query audit trail entries for an entity."""
query = db.query(AuditTrailDB)
if entity_type:
query = query.filter(AuditTrailDB.entity_type == entity_type)
if entity_id:
query = query.filter(AuditTrailDB.entity_id == entity_id)
if action:
query = query.filter(AuditTrailDB.action == action)
records = query.order_by(AuditTrailDB.performed_at.desc()).limit(limit).all()
return {
"entries": [
{
"id": r.id,
"entity_type": r.entity_type,
"entity_id": r.entity_id,
"entity_name": r.entity_name,
"action": r.action,
"field_changed": r.field_changed,
"old_value": r.old_value,
"new_value": r.new_value,
"change_summary": r.change_summary,
"performed_by": r.performed_by,
"performed_at": r.performed_at.isoformat() if r.performed_at else None,
"checksum": r.checksum,
}
for r in records
],
"total": len(records),
}
@@ -39,7 +39,6 @@ router = APIRouter(tags=["extraction"])
ALL_COLLECTIONS = [
"bp_compliance_ce", # BSI-TR documents — primary Prüfaspekte source
"bp_compliance_recht", # Legal texts (GDPR, AI Act, ...)
"bp_compliance_gesetze", # German laws
"bp_compliance_datenschutz", # Data protection documents
"bp_dsfa_corpus", # DSFA corpus
@@ -80,9 +80,13 @@ def _handle(func, *args, **kwargs): # type: ignore[no-untyped-def]
raise HTTPException(status_code=400, detail=str(exc))
# ============================================================================
# ISMS Scope (ISO 27001 4.3)
# ============================================================================
# Shared audit trail utilities — canonical implementation in audit_trail_utils.py
from .audit_trail_utils import log_audit_trail, create_signature # noqa: E402
# =============================================================================
# ISMS SCOPE (ISO 27001 4.3)
# =============================================================================
@router.get("/scope", response_model=ISMSScopeResponse)
async def get_isms_scope(db: Session = Depends(get_db)):
@@ -50,6 +50,57 @@ VALID_DOCUMENT_TYPES = {
"cookie_banner",
"agb",
"clause",
# Security document templates (Migration 051)
"it_security_concept",
"data_protection_concept",
"backup_recovery_concept",
"logging_concept",
"incident_response_plan",
"access_control_concept",
"risk_management_concept",
# Policy templates — IT Security (Migration 054)
"information_security_policy",
"access_control_policy",
"password_policy",
"encryption_policy",
"logging_policy",
"backup_policy",
"incident_response_policy",
"change_management_policy",
"patch_management_policy",
"asset_management_policy",
"cloud_security_policy",
"devsecops_policy",
"secrets_management_policy",
"vulnerability_management_policy",
# Policy templates — Data (Migration 054)
"data_protection_policy",
"data_classification_policy",
"data_retention_policy",
"data_transfer_policy",
"privacy_incident_policy",
# Policy templates — Personnel (Migration 054)
"employee_security_policy",
"security_awareness_policy",
"remote_work_policy",
"offboarding_policy",
# Policy templates — Vendor/Supply Chain (Migration 054)
"vendor_risk_management_policy",
"third_party_security_policy",
"supplier_security_policy",
# Policy templates — BCM (Migration 054)
"business_continuity_policy",
"disaster_recovery_policy",
"crisis_management_policy",
# CRA Cybersecurity (Migration 056)
"cybersecurity_policy",
# DSFA template
"dsfa",
# Module document templates (Migration 073)
"vvt_register",
"tom_documentation",
"loeschkonzept",
"pflichtenregister",
}
VALID_STATUSES = {"published", "draft", "archived"}
@@ -0,0 +1,162 @@
"""
FastAPI routes for LLM Generation Audit Trail.
Endpoints:
- POST /llm-audit: Record an LLM generation event
- GET /llm-audit: List audit records with filters
"""
import logging
import uuid as uuid_module
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from classroom_engine.database import get_db
from ..db.models import LLMGenerationAuditDB
logger = logging.getLogger(__name__)
router = APIRouter(tags=["compliance-llm-audit"])
# ============================================================================
# Schemas
# ============================================================================
class LLMAuditCreate(BaseModel):
entity_type: str
entity_id: Optional[str] = None
generation_mode: str
truth_status: str = "generated"
may_be_used_as_evidence: bool = False
llm_model: Optional[str] = None
llm_provider: Optional[str] = None
prompt_hash: Optional[str] = None
input_summary: Optional[str] = None
output_summary: Optional[str] = None
metadata: Optional[dict] = None
tenant_id: Optional[str] = None
class LLMAuditResponse(BaseModel):
id: str
tenant_id: Optional[str] = None
entity_type: str
entity_id: Optional[str] = None
generation_mode: str
truth_status: str
may_be_used_as_evidence: bool
llm_model: Optional[str] = None
llm_provider: Optional[str] = None
prompt_hash: Optional[str] = None
input_summary: Optional[str] = None
output_summary: Optional[str] = None
metadata: Optional[dict] = None
created_at: datetime
class Config:
from_attributes = True
# ============================================================================
# Routes
# ============================================================================
@router.post("/llm-audit", response_model=LLMAuditResponse)
async def create_llm_audit(
data: LLMAuditCreate,
db: Session = Depends(get_db),
):
"""Record an LLM generation event for audit trail."""
from ..db.models import EvidenceTruthStatusEnum
# Validate truth_status
try:
truth_enum = EvidenceTruthStatusEnum(data.truth_status)
except ValueError:
truth_enum = EvidenceTruthStatusEnum.GENERATED
record = LLMGenerationAuditDB(
id=str(uuid_module.uuid4()),
tenant_id=data.tenant_id,
entity_type=data.entity_type,
entity_id=data.entity_id,
generation_mode=data.generation_mode,
truth_status=truth_enum,
may_be_used_as_evidence=data.may_be_used_as_evidence,
llm_model=data.llm_model,
llm_provider=data.llm_provider,
prompt_hash=data.prompt_hash,
input_summary=data.input_summary[:500] if data.input_summary else None,
output_summary=data.output_summary[:500] if data.output_summary else None,
extra_metadata=data.metadata or {},
)
db.add(record)
db.commit()
db.refresh(record)
return LLMAuditResponse(
id=record.id,
tenant_id=record.tenant_id,
entity_type=record.entity_type,
entity_id=record.entity_id,
generation_mode=record.generation_mode,
truth_status=record.truth_status.value if record.truth_status else "generated",
may_be_used_as_evidence=record.may_be_used_as_evidence,
llm_model=record.llm_model,
llm_provider=record.llm_provider,
prompt_hash=record.prompt_hash,
input_summary=record.input_summary,
output_summary=record.output_summary,
metadata=record.extra_metadata,
created_at=record.created_at,
)
@router.get("/llm-audit")
async def list_llm_audit(
entity_type: Optional[str] = Query(None),
entity_id: Optional[str] = Query(None),
page: int = Query(1, ge=1),
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
):
"""List LLM generation audit records with optional filters."""
query = db.query(LLMGenerationAuditDB)
if entity_type:
query = query.filter(LLMGenerationAuditDB.entity_type == entity_type)
if entity_id:
query = query.filter(LLMGenerationAuditDB.entity_id == entity_id)
total = query.count()
offset = (page - 1) * limit
records = query.order_by(LLMGenerationAuditDB.created_at.desc()).offset(offset).limit(limit).all()
return {
"records": [
LLMAuditResponse(
id=r.id,
tenant_id=r.tenant_id,
entity_type=r.entity_type,
entity_id=r.entity_id,
generation_mode=r.generation_mode,
truth_status=r.truth_status.value if r.truth_status else "generated",
may_be_used_as_evidence=r.may_be_used_as_evidence,
llm_model=r.llm_model,
llm_provider=r.llm_provider,
prompt_hash=r.prompt_hash,
input_summary=r.input_summary,
output_summary=r.output_summary,
metadata=r.extra_metadata,
created_at=r.created_at,
)
for r in records
],
"total": total,
"page": page,
"limit": limit,
}
@@ -56,6 +56,7 @@ class LoeschfristCreate(BaseModel):
responsible_person: Optional[str] = None
release_process: Optional[str] = None
linked_vvt_activity_ids: Optional[List[Any]] = None
linked_vendor_ids: Optional[List[Any]] = None
status: str = "DRAFT"
last_review_date: Optional[datetime] = None
next_review_date: Optional[datetime] = None
@@ -86,6 +87,7 @@ class LoeschfristUpdate(BaseModel):
responsible_person: Optional[str] = None
release_process: Optional[str] = None
linked_vvt_activity_ids: Optional[List[Any]] = None
linked_vendor_ids: Optional[List[Any]] = None
status: Optional[str] = None
last_review_date: Optional[datetime] = None
next_review_date: Optional[datetime] = None
@@ -100,7 +102,7 @@ class StatusUpdate(BaseModel):
# JSONB fields that need CAST
JSONB_FIELDS = {
"affected_groups", "data_categories", "legal_holds",
"storage_locations", "linked_vvt_activity_ids", "tags"
"storage_locations", "linked_vvt_activity_ids", "linked_vendor_ids", "tags"
}
@@ -42,6 +42,7 @@ class ObligationCreate(BaseModel):
priority: str = "medium"
responsible: Optional[str] = None
linked_systems: Optional[List[str]] = None
linked_vendor_ids: Optional[List[str]] = None
assessment_id: Optional[str] = None
rule_code: Optional[str] = None
notes: Optional[str] = None
@@ -57,6 +58,7 @@ class ObligationUpdate(BaseModel):
priority: Optional[str] = None
responsible: Optional[str] = None
linked_systems: Optional[List[str]] = None
linked_vendor_ids: Optional[List[str]] = None
notes: Optional[str] = None
@@ -173,14 +175,15 @@ async def create_obligation(
import json
linked_systems = json.dumps(payload.linked_systems or [])
linked_vendor_ids = json.dumps(payload.linked_vendor_ids or [])
row = db.execute(text("""
INSERT INTO compliance_obligations
(tenant_id, title, description, source, source_article, deadline,
status, priority, responsible, linked_systems, assessment_id, rule_code, notes)
status, priority, responsible, linked_systems, linked_vendor_ids, assessment_id, rule_code, notes)
VALUES
(:tenant_id, :title, :description, :source, :source_article, :deadline,
:status, :priority, :responsible, CAST(:linked_systems AS jsonb), :assessment_id, :rule_code, :notes)
:status, :priority, :responsible, CAST(:linked_systems AS jsonb), CAST(:linked_vendor_ids AS jsonb), :assessment_id, :rule_code, :notes)
RETURNING *
"""), {
"tenant_id": tenant_id,
@@ -193,6 +196,7 @@ async def create_obligation(
"priority": payload.priority,
"responsible": payload.responsible,
"linked_systems": linked_systems,
"linked_vendor_ids": linked_vendor_ids,
"assessment_id": payload.assessment_id,
"rule_code": payload.rule_code,
"notes": payload.notes,
@@ -235,6 +239,9 @@ async def update_obligation(
if field == "linked_systems":
updates["linked_systems"] = json.dumps(value or [])
set_clauses.append("linked_systems = CAST(:linked_systems AS jsonb)")
elif field == "linked_vendor_ids":
updates["linked_vendor_ids"] = json.dumps(value or [])
set_clauses.append("linked_vendor_ids = CAST(:linked_vendor_ids AS jsonb)")
else:
updates[field] = value
set_clauses.append(f"{field} = :{field}")
File diff suppressed because it is too large Load Diff
+148 -6
View File
@@ -25,6 +25,7 @@ from sqlalchemy.orm import Session
from classroom_engine.database import get_db
from .audit_trail_utils import log_audit_trail
from ..db import (
ControlDomainEnum,
ControlRepository,
@@ -312,8 +313,39 @@ async def get_control(
svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ControlResponse:
"""Get a specific control by control_id."""
with translate_domain_errors():
return svc.get_control(control_id)
repo = ControlRepository(db)
control = repo.get_by_control_id(control_id)
if not control:
raise HTTPException(status_code=404, detail=f"Control {control_id} not found")
evidence_repo = EvidenceRepository(db)
evidence = evidence_repo.get_by_control(control.id)
return ControlResponse(
id=control.id,
control_id=control.control_id,
domain=control.domain.value if control.domain else None,
control_type=control.control_type.value if control.control_type else None,
title=control.title,
description=control.description,
pass_criteria=control.pass_criteria,
implementation_guidance=control.implementation_guidance,
code_reference=control.code_reference,
documentation_url=control.documentation_url,
is_automated=control.is_automated,
automation_tool=control.automation_tool,
automation_config=control.automation_config,
owner=control.owner,
review_frequency_days=control.review_frequency_days,
status=control.status.value if control.status else None,
status_notes=control.status_notes,
status_justification=control.status_justification,
last_reviewed_at=control.last_reviewed_at,
next_review_at=control.next_review_at,
created_at=control.created_at,
updated_at=control.updated_at,
evidence_count=len(evidence),
)
@router.put(
@@ -325,8 +357,83 @@ async def update_control(
svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ControlResponse:
"""Update a control."""
with translate_domain_errors():
return svc.update_control(control_id, update)
repo = ControlRepository(db)
control = repo.get_by_control_id(control_id)
if not control:
raise HTTPException(status_code=404, detail=f"Control {control_id} not found")
update_data = update.model_dump(exclude_unset=True)
# Convert status string to enum and validate transition
if "status" in update_data:
try:
new_status_enum = ControlStatusEnum(update_data["status"])
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid status: {update_data['status']}")
# Validate status transition (Anti-Fake-Evidence)
from ..services.control_status_machine import validate_transition
current_status = control.status.value if control.status else "planned"
evidence_list = db.query(EvidenceDB).filter(EvidenceDB.control_id == control.id).all()
allowed, violations = validate_transition(
current_status=current_status,
new_status=update_data["status"],
evidence_list=evidence_list,
status_justification=update_data.get("status_justification") or update_data.get("status_notes"),
)
if not allowed:
raise HTTPException(
status_code=409,
detail={
"error": "Status transition not allowed",
"current_status": current_status,
"requested_status": update_data["status"],
"violations": violations,
}
)
update_data["status"] = new_status_enum
updated = repo.update(control.id, **update_data)
db.commit()
# Audit trail for status changes
new_status = updated.status.value if updated.status else None
if "status" in update.model_dump(exclude_unset=True) and current_status != new_status:
log_audit_trail(
db, "control", control.id, updated.control_id or updated.title,
"status_change",
performed_by=update.owner or "system",
field_changed="status",
old_value=current_status,
new_value=new_status,
)
db.commit()
return ControlResponse(
id=updated.id,
control_id=updated.control_id,
domain=updated.domain.value if updated.domain else None,
control_type=updated.control_type.value if updated.control_type else None,
title=updated.title,
description=updated.description,
pass_criteria=updated.pass_criteria,
implementation_guidance=updated.implementation_guidance,
code_reference=updated.code_reference,
documentation_url=updated.documentation_url,
is_automated=updated.is_automated,
automation_tool=updated.automation_tool,
automation_config=updated.automation_config,
owner=updated.owner,
review_frequency_days=updated.review_frequency_days,
status=updated.status.value if updated.status else None,
status_notes=updated.status_notes,
status_justification=updated.status_justification,
last_reviewed_at=updated.last_reviewed_at,
next_review_at=updated.next_review_at,
created_at=updated.created_at,
updated_at=updated.updated_at,
)
@router.put(
@@ -339,8 +446,43 @@ async def review_control(
svc: ControlExportService = Depends(get_ctrl_export_service),
) -> ControlResponse:
"""Mark a control as reviewed with new status."""
with translate_domain_errors():
return svc.review_control(control_id, review)
repo = ControlRepository(db)
control = repo.get_by_control_id(control_id)
if not control:
raise HTTPException(status_code=404, detail=f"Control {control_id} not found")
try:
status_enum = ControlStatusEnum(review.status)
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid status: {review.status}")
updated = repo.mark_reviewed(control.id, status_enum, review.status_notes)
db.commit()
return ControlResponse(
id=updated.id,
control_id=updated.control_id,
domain=updated.domain.value if updated.domain else None,
control_type=updated.control_type.value if updated.control_type else None,
title=updated.title,
description=updated.description,
pass_criteria=updated.pass_criteria,
implementation_guidance=updated.implementation_guidance,
code_reference=updated.code_reference,
documentation_url=updated.documentation_url,
is_automated=updated.is_automated,
automation_tool=updated.automation_tool,
automation_config=updated.automation_config,
owner=updated.owner,
review_frequency_days=updated.review_frequency_days,
status=updated.status.value if updated.status else None,
status_notes=updated.status_notes,
status_justification=updated.status_justification,
last_reviewed_at=updated.last_reviewed_at,
next_review_at=updated.next_review_at,
created_at=updated.created_at,
updated_at=updated.updated_at,
)
@router.get(
File diff suppressed because it is too large Load Diff
@@ -22,7 +22,9 @@ import uuid
from datetime import datetime, timezone
from typing import Any
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
import httpx
from fastapi import APIRouter, File, Form, UploadFile, HTTPException
from pydantic import BaseModel
from sqlalchemy import text
from database import SessionLocal # re-exported below for legacy test patches
@@ -96,15 +98,13 @@ async def scan_dependencies(
db = SessionLocal()
try:
db.execute(
text(
"INSERT INTO compliance_screenings "
"(id, tenant_id, status, sbom_format, sbom_version, "
"total_components, total_issues, critical_issues, high_issues, "
"medium_issues, low_issues, sbom_data, started_at, completed_at) "
"VALUES (:id, :tenant_id, 'completed', 'CycloneDX', '1.5', "
":total_components, :total_issues, :critical, :high, :medium, :low, "
":sbom_data::jsonb, :started_at, :completed_at)"
),
text("""INSERT INTO compliance_screenings
(id, tenant_id, status, sbom_format, sbom_version,
total_components, total_issues, critical_issues, high_issues, medium_issues, low_issues,
sbom_data, started_at, completed_at)
VALUES (:id, :tenant_id, 'completed', 'CycloneDX', '1.5',
:total_components, :total_issues, :critical, :high, :medium, :low,
:sbom_data::jsonb, :started_at, :completed_at)"""),
{
"id": screening_id,
"tenant_id": tenant_id,
@@ -121,13 +121,11 @@ async def scan_dependencies(
)
for issue in issues:
db.execute(
text(
"INSERT INTO compliance_security_issues "
"(id, screening_id, severity, title, description, cve, cvss, "
"affected_component, affected_version, fixed_in, remediation, status) "
"VALUES (:id, :screening_id, :severity, :title, :description, :cve, :cvss, "
":component, :version, :fixed_in, :remediation, :status)"
),
text("""INSERT INTO compliance_security_issues
(id, screening_id, severity, title, description, cve, cvss,
affected_component, affected_version, fixed_in, remediation, status)
VALUES (:id, :screening_id, :severity, :title, :description, :cve, :cvss,
:component, :version, :fixed_in, :remediation, :status)"""),
{
"id": issue["id"],
"screening_id": screening_id,
@@ -214,8 +212,77 @@ async def get_screening(screening_id: str) -> ScreeningResponse:
"""Get a screening result by ID."""
db = SessionLocal()
try:
with translate_domain_errors():
return ScreeningService(db).get_screening(screening_id)
result = db.execute(
text("""SELECT id, status, sbom_format, sbom_version,
total_components, total_issues, critical_issues, high_issues,
medium_issues, low_issues, sbom_data, started_at, completed_at
FROM compliance_screenings WHERE id = :id"""),
{"id": screening_id},
)
row = result.fetchone()
if not row:
raise HTTPException(status_code=404, detail="Screening not found")
# Fetch issues
issues_result = db.execute(
text("""SELECT id, severity, title, description, cve, cvss,
affected_component, affected_version, fixed_in, remediation, status
FROM compliance_security_issues WHERE screening_id = :id"""),
{"id": screening_id},
)
issues_rows = issues_result.fetchall()
issues = [
SecurityIssueResponse(
id=str(r[0]), severity=r[1], title=r[2], description=r[3],
cve=r[4], cvss=r[5], affected_component=r[6],
affected_version=r[7], fixed_in=r[8], remediation=r[9], status=r[10],
)
for r in issues_rows
]
# Reconstruct components from SBOM data
sbom_data = row[10] or {}
components = []
comp_vulns: dict[str, list[dict]] = {}
for issue in issues:
if issue.affected_component not in comp_vulns:
comp_vulns[issue.affected_component] = []
comp_vulns[issue.affected_component].append({
"id": issue.cve or issue.id,
"cve": issue.cve,
"severity": issue.severity,
"title": issue.title,
"cvss": issue.cvss,
"fixedIn": issue.fixed_in,
})
for sc in sbom_data.get("components", []):
components.append(SBOMComponentResponse(
name=sc["name"],
version=sc["version"],
type=sc.get("type", "library"),
purl=sc.get("purl", ""),
licenses=sc.get("licenses", []),
vulnerabilities=comp_vulns.get(sc["name"], []),
))
return ScreeningResponse(
id=str(row[0]),
status=row[1],
sbom_format=row[2] or "CycloneDX",
sbom_version=row[3] or "1.5",
total_components=row[4] or 0,
total_issues=row[5] or 0,
critical_issues=row[6] or 0,
high_issues=row[7] or 0,
medium_issues=row[8] or 0,
low_issues=row[9] or 0,
components=components,
issues=issues,
started_at=str(row[11]) if row[11] else None,
completed_at=str(row[12]) if row[12] else None,
)
finally:
db.close()
@@ -225,8 +292,33 @@ async def list_screenings(tenant_id: str = "default") -> ScreeningListResponse:
"""List all screenings for a tenant."""
db = SessionLocal()
try:
with translate_domain_errors():
return ScreeningService(db).list_screenings(tenant_id)
result = db.execute(
text("""SELECT id, status, total_components, total_issues,
critical_issues, high_issues, medium_issues, low_issues,
started_at, completed_at, created_at
FROM compliance_screenings
WHERE tenant_id = :tenant_id
ORDER BY created_at DESC"""),
{"tenant_id": tenant_id},
)
rows = result.fetchall()
screenings = [
{
"id": str(r[0]),
"status": r[1],
"total_components": r[2],
"total_issues": r[3],
"critical_issues": r[4],
"high_issues": r[5],
"medium_issues": r[6],
"low_issues": r[7],
"started_at": str(r[8]) if r[8] else None,
"completed_at": str(r[9]) if r[9] else None,
"created_at": str(r[10]),
}
for r in rows
]
return ScreeningListResponse(screenings=screenings, total=len(screenings))
finally:
db.close()
@@ -0,0 +1,537 @@
"""
TOM Canonical Control Mapping Routes.
Three-layer architecture:
TOM Measures (~88, audit-level) Mapping Bridge Canonical Controls (10,000+)
Endpoints:
POST /v1/tom-mappings/sync Sync canonical controls for company profile
GET /v1/tom-mappings List all mappings for tenant/project
GET /v1/tom-mappings/by-tom/{code} Mappings for a specific TOM control
GET /v1/tom-mappings/stats Coverage statistics
POST /v1/tom-mappings/manual Manually add a mapping
DELETE /v1/tom-mappings/{id} Remove a mapping
"""
from __future__ import annotations
import hashlib
import json
import logging
from typing import Any, Optional
from fastapi import APIRouter, HTTPException, Query, Header
from pydantic import BaseModel
from sqlalchemy import text
from database import SessionLocal
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/tom-mappings", tags=["tom-control-mappings"])
# =============================================================================
# TOM CATEGORY → CANONICAL CATEGORY MAPPING
# =============================================================================
# Maps 13 TOM control categories to canonical_control_categories
# Each TOM category maps to 1-3 canonical categories for broad coverage
TOM_TO_CANONICAL_CATEGORIES: dict[str, list[str]] = {
"ACCESS_CONTROL": ["authentication", "identity", "physical"],
"ADMISSION_CONTROL": ["authentication", "identity", "system"],
"ACCESS_AUTHORIZATION": ["authentication", "identity"],
"TRANSFER_CONTROL": ["network", "data_protection", "encryption"],
"INPUT_CONTROL": ["application", "data_protection"],
"ORDER_CONTROL": ["supply_chain", "compliance"],
"AVAILABILITY": ["continuity", "system"],
"SEPARATION": ["network", "data_protection"],
"ENCRYPTION": ["encryption"],
"PSEUDONYMIZATION": ["data_protection", "encryption"],
"RESILIENCE": ["continuity", "system"],
"RECOVERY": ["continuity"],
"REVIEW": ["compliance", "governance", "risk"],
}
# =============================================================================
# REQUEST / RESPONSE MODELS
# =============================================================================
class SyncRequest(BaseModel):
"""Trigger a sync of canonical controls to TOM measures."""
industry: Optional[str] = None
company_size: Optional[str] = None
force: bool = False
class ManualMappingRequest(BaseModel):
"""Manually add a canonical control to a TOM measure."""
tom_control_code: str
tom_category: str
canonical_control_id: str
canonical_control_code: str
canonical_category: Optional[str] = None
relevance_score: float = 1.0
# =============================================================================
# HELPERS
# =============================================================================
def _get_tenant_id(x_tenant_id: Optional[str]) -> str:
"""Extract tenant ID from header."""
if not x_tenant_id:
raise HTTPException(status_code=400, detail="X-Tenant-ID header required")
return x_tenant_id
def _compute_profile_hash(industry: Optional[str], company_size: Optional[str]) -> str:
"""Compute a hash from profile parameters for change detection."""
data = json.dumps({"industry": industry, "company_size": company_size}, sort_keys=True)
return hashlib.sha256(data.encode()).hexdigest()[:16]
def _mapping_row_to_dict(r) -> dict[str, Any]:
"""Convert a mapping row to API response dict."""
return {
"id": str(r.id),
"tenant_id": str(r.tenant_id),
"project_id": str(r.project_id) if r.project_id else None,
"tom_control_code": r.tom_control_code,
"tom_category": r.tom_category,
"canonical_control_id": str(r.canonical_control_id),
"canonical_control_code": r.canonical_control_code,
"canonical_category": r.canonical_category,
"mapping_type": r.mapping_type,
"relevance_score": float(r.relevance_score) if r.relevance_score else 1.0,
"created_at": r.created_at.isoformat() if r.created_at else None,
}
# =============================================================================
# SYNC ENDPOINT
# =============================================================================
@router.post("/sync")
async def sync_mappings(
body: SyncRequest,
x_tenant_id: Optional[str] = Header(None, alias="X-Tenant-ID"),
project_id: Optional[str] = Query(None),
):
"""
Sync canonical controls to TOM measures based on company profile.
Algorithm:
1. Compute profile hash skip if unchanged (unless force=True)
2. For each TOM category, find matching canonical controls by:
- Category mapping (TOM category canonical categories)
- Industry filter (applicable_industries JSONB containment)
- Company size filter (applicable_company_size JSONB containment)
- Only approved + customer_visible controls
3. Delete old auto-mappings, insert new ones
4. Update sync state
"""
tenant_id = _get_tenant_id(x_tenant_id)
profile_hash = _compute_profile_hash(body.industry, body.company_size)
with SessionLocal() as db:
# Check if sync is needed (profile unchanged)
if not body.force:
existing = db.execute(
text("""
SELECT profile_hash FROM tom_control_sync_state
WHERE tenant_id = :tid AND (project_id = :pid OR (project_id IS NULL AND :pid IS NULL))
"""),
{"tid": tenant_id, "pid": project_id},
).fetchone()
if existing and existing.profile_hash == profile_hash:
return {
"status": "unchanged",
"message": "Profile unchanged since last sync",
"profile_hash": profile_hash,
}
# Delete old auto-mappings for this tenant+project
db.execute(
text("""
DELETE FROM tom_control_mappings
WHERE tenant_id = :tid
AND (project_id = :pid OR (project_id IS NULL AND :pid IS NULL))
AND mapping_type = 'auto'
"""),
{"tid": tenant_id, "pid": project_id},
)
total_mappings = 0
canonical_ids_matched = set()
tom_codes_covered = set()
# For each TOM category, find matching canonical controls
for tom_category, canonical_categories in TOM_TO_CANONICAL_CATEGORIES.items():
# Build JSONB containment query for categories
cat_conditions = " OR ".join(
f"category = :cat_{i}" for i in range(len(canonical_categories))
)
cat_params = {f"cat_{i}": c for i, c in enumerate(canonical_categories)}
# Build industry filter
industry_filter = ""
if body.industry:
industry_filter = """
AND (
applicable_industries IS NULL
OR applicable_industries @> '"all"'::jsonb
OR applicable_industries @> (:industry)::jsonb
)
"""
cat_params["industry"] = json.dumps([body.industry])
# Build company size filter
size_filter = ""
if body.company_size:
size_filter = """
AND (
applicable_company_size IS NULL
OR applicable_company_size @> '"all"'::jsonb
OR applicable_company_size @> (:csize)::jsonb
)
"""
cat_params["csize"] = json.dumps([body.company_size])
query = f"""
SELECT id, control_id, category
FROM canonical_controls
WHERE ({cat_conditions})
AND release_state = 'approved'
AND customer_visible = true
{industry_filter}
{size_filter}
ORDER BY control_id
"""
rows = db.execute(text(query), cat_params).fetchall()
# Find TOM control codes in this category (query the frontend library
# codes; we use the category prefix pattern from the loader)
# TOM codes follow pattern: TOM-XX-NN where XX is category abbreviation
# We insert one mapping per canonical control per TOM category
for row in rows:
db.execute(
text("""
INSERT INTO tom_control_mappings (
tenant_id, project_id, tom_control_code, tom_category,
canonical_control_id, canonical_control_code, canonical_category,
mapping_type, relevance_score
) VALUES (
:tid, :pid, :tom_cat, :tom_cat,
:cc_id, :cc_code, :cc_category,
'auto', 1.00
)
ON CONFLICT (tenant_id, project_id, tom_control_code, canonical_control_id)
DO NOTHING
"""),
{
"tid": tenant_id,
"pid": project_id,
"tom_cat": tom_category,
"cc_id": str(row.id),
"cc_code": row.control_id,
"cc_category": row.category,
},
)
total_mappings += 1
canonical_ids_matched.add(str(row.id))
tom_codes_covered.add(tom_category)
# Upsert sync state
db.execute(
text("""
INSERT INTO tom_control_sync_state (
tenant_id, project_id, profile_hash,
total_mappings, canonical_controls_matched, tom_controls_covered,
last_synced_at
) VALUES (
:tid, :pid, :hash,
:total, :matched, :covered,
NOW()
)
ON CONFLICT (tenant_id, project_id)
DO UPDATE SET
profile_hash = :hash,
total_mappings = :total,
canonical_controls_matched = :matched,
tom_controls_covered = :covered,
last_synced_at = NOW()
"""),
{
"tid": tenant_id,
"pid": project_id,
"hash": profile_hash,
"total": total_mappings,
"matched": len(canonical_ids_matched),
"covered": len(tom_codes_covered),
},
)
db.commit()
return {
"status": "synced",
"profile_hash": profile_hash,
"total_mappings": total_mappings,
"canonical_controls_matched": len(canonical_ids_matched),
"tom_categories_covered": len(tom_codes_covered),
}
# =============================================================================
# LIST MAPPINGS
# =============================================================================
@router.get("")
async def list_mappings(
x_tenant_id: Optional[str] = Header(None, alias="X-Tenant-ID"),
project_id: Optional[str] = Query(None),
tom_category: Optional[str] = Query(None),
mapping_type: Optional[str] = Query(None),
limit: int = Query(500, ge=1, le=5000),
offset: int = Query(0, ge=0),
):
"""List all TOM ↔ canonical control mappings for tenant/project."""
tenant_id = _get_tenant_id(x_tenant_id)
query = """
SELECT m.*, cc.title as canonical_title, cc.severity as canonical_severity
FROM tom_control_mappings m
LEFT JOIN canonical_controls cc ON cc.id = m.canonical_control_id
WHERE m.tenant_id = :tid
AND (m.project_id = :pid OR (m.project_id IS NULL AND :pid IS NULL))
"""
params: dict[str, Any] = {"tid": tenant_id, "pid": project_id}
if tom_category:
query += " AND m.tom_category = :tcat"
params["tcat"] = tom_category
if mapping_type:
query += " AND m.mapping_type = :mtype"
params["mtype"] = mapping_type
query += " ORDER BY m.tom_category, m.canonical_control_code"
query += " LIMIT :lim OFFSET :off"
params["lim"] = limit
params["off"] = offset
count_query = """
SELECT count(*) FROM tom_control_mappings
WHERE tenant_id = :tid
AND (project_id = :pid OR (project_id IS NULL AND :pid IS NULL))
"""
count_params: dict[str, Any] = {"tid": tenant_id, "pid": project_id}
if tom_category:
count_query += " AND tom_category = :tcat"
count_params["tcat"] = tom_category
with SessionLocal() as db:
rows = db.execute(text(query), params).fetchall()
total = db.execute(text(count_query), count_params).scalar()
mappings = []
for r in rows:
d = _mapping_row_to_dict(r)
d["canonical_title"] = getattr(r, "canonical_title", None)
d["canonical_severity"] = getattr(r, "canonical_severity", None)
mappings.append(d)
return {"mappings": mappings, "total": total}
# =============================================================================
# MAPPINGS BY TOM CONTROL
# =============================================================================
@router.get("/by-tom/{tom_code}")
async def get_mappings_by_tom(
tom_code: str,
x_tenant_id: Optional[str] = Header(None, alias="X-Tenant-ID"),
project_id: Optional[str] = Query(None),
):
"""Get all canonical controls mapped to a specific TOM control code or category."""
tenant_id = _get_tenant_id(x_tenant_id)
with SessionLocal() as db:
rows = db.execute(
text("""
SELECT m.*, cc.title as canonical_title, cc.severity as canonical_severity,
cc.objective as canonical_objective
FROM tom_control_mappings m
LEFT JOIN canonical_controls cc ON cc.id = m.canonical_control_id
WHERE m.tenant_id = :tid
AND (m.project_id = :pid OR (m.project_id IS NULL AND :pid IS NULL))
AND (m.tom_control_code = :code OR m.tom_category = :code)
ORDER BY m.canonical_control_code
"""),
{"tid": tenant_id, "pid": project_id, "code": tom_code},
).fetchall()
mappings = []
for r in rows:
d = _mapping_row_to_dict(r)
d["canonical_title"] = getattr(r, "canonical_title", None)
d["canonical_severity"] = getattr(r, "canonical_severity", None)
d["canonical_objective"] = getattr(r, "canonical_objective", None)
mappings.append(d)
return {"tom_code": tom_code, "mappings": mappings, "total": len(mappings)}
# =============================================================================
# STATS
# =============================================================================
@router.get("/stats")
async def get_mapping_stats(
x_tenant_id: Optional[str] = Header(None, alias="X-Tenant-ID"),
project_id: Optional[str] = Query(None),
):
"""Coverage statistics for TOM ↔ canonical control mappings."""
tenant_id = _get_tenant_id(x_tenant_id)
with SessionLocal() as db:
# Sync state
sync_state = db.execute(
text("""
SELECT * FROM tom_control_sync_state
WHERE tenant_id = :tid
AND (project_id = :pid OR (project_id IS NULL AND :pid IS NULL))
"""),
{"tid": tenant_id, "pid": project_id},
).fetchone()
# Per-category breakdown
category_stats = db.execute(
text("""
SELECT tom_category,
count(*) as total_mappings,
count(DISTINCT canonical_control_id) as unique_controls,
count(*) FILTER (WHERE mapping_type = 'auto') as auto_count,
count(*) FILTER (WHERE mapping_type = 'manual') as manual_count
FROM tom_control_mappings
WHERE tenant_id = :tid
AND (project_id = :pid OR (project_id IS NULL AND :pid IS NULL))
GROUP BY tom_category
ORDER BY tom_category
"""),
{"tid": tenant_id, "pid": project_id},
).fetchall()
# Total canonical controls in DB (approved + visible)
total_canonical = db.execute(
text("""
SELECT count(*) FROM canonical_controls
WHERE release_state = 'approved' AND customer_visible = true
""")
).scalar()
return {
"sync_state": {
"profile_hash": sync_state.profile_hash if sync_state else None,
"total_mappings": sync_state.total_mappings if sync_state else 0,
"canonical_controls_matched": sync_state.canonical_controls_matched if sync_state else 0,
"tom_controls_covered": sync_state.tom_controls_covered if sync_state else 0,
"last_synced_at": sync_state.last_synced_at.isoformat() if sync_state and sync_state.last_synced_at else None,
},
"category_breakdown": [
{
"tom_category": r.tom_category,
"total_mappings": r.total_mappings,
"unique_controls": r.unique_controls,
"auto_count": r.auto_count,
"manual_count": r.manual_count,
}
for r in category_stats
],
"total_canonical_controls_available": total_canonical or 0,
}
# =============================================================================
# MANUAL MAPPING
# =============================================================================
@router.post("/manual", status_code=201)
async def add_manual_mapping(
body: ManualMappingRequest,
x_tenant_id: Optional[str] = Header(None, alias="X-Tenant-ID"),
project_id: Optional[str] = Query(None),
):
"""Manually add a canonical control to a TOM measure."""
tenant_id = _get_tenant_id(x_tenant_id)
with SessionLocal() as db:
# Verify canonical control exists
cc = db.execute(
text("SELECT id, control_id, category FROM canonical_controls WHERE id = CAST(:cid AS uuid)"),
{"cid": body.canonical_control_id},
).fetchone()
if not cc:
raise HTTPException(status_code=404, detail="Canonical control not found")
try:
row = db.execute(
text("""
INSERT INTO tom_control_mappings (
tenant_id, project_id, tom_control_code, tom_category,
canonical_control_id, canonical_control_code, canonical_category,
mapping_type, relevance_score
) VALUES (
:tid, :pid, :tom_code, :tom_cat,
CAST(:cc_id AS uuid), :cc_code, :cc_category,
'manual', :score
)
RETURNING *
"""),
{
"tid": tenant_id,
"pid": project_id,
"tom_code": body.tom_control_code,
"tom_cat": body.tom_category,
"cc_id": body.canonical_control_id,
"cc_code": body.canonical_control_code,
"cc_category": body.canonical_category or cc.category,
"score": body.relevance_score,
},
).fetchone()
db.commit()
except Exception as e:
if "unique" in str(e).lower() or "duplicate" in str(e).lower():
raise HTTPException(status_code=409, detail="Mapping already exists")
raise
return _mapping_row_to_dict(row)
# =============================================================================
# DELETE MAPPING
# =============================================================================
@router.delete("/{mapping_id}", status_code=204)
async def delete_mapping(
mapping_id: str,
x_tenant_id: Optional[str] = Header(None, alias="X-Tenant-ID"),
):
"""Remove a mapping (manual or auto)."""
tenant_id = _get_tenant_id(x_tenant_id)
with SessionLocal() as db:
result = db.execute(
text("""
DELETE FROM tom_control_mappings
WHERE id = CAST(:mid AS uuid) AND tenant_id = :tid
"""),
{"mid": mapping_id, "tid": tenant_id},
)
if result.rowcount == 0:
raise HTTPException(status_code=404, detail="Mapping not found")
db.commit()
return None
@@ -0,0 +1,427 @@
"""
FastAPI routes for VVT Master Libraries + Process Templates.
Library endpoints (read-only, global):
GET /vvt/libraries Overview: all library types + counts
GET /vvt/libraries/data-subjects Data subjects (filter: typical_for)
GET /vvt/libraries/data-categories Hierarchical (filter: parent_id, is_art9, flat)
GET /vvt/libraries/recipients Recipients (filter: type)
GET /vvt/libraries/legal-bases Legal bases (filter: is_art9, type)
GET /vvt/libraries/retention-rules Retention rules
GET /vvt/libraries/transfer-mechanisms Transfer mechanisms
GET /vvt/libraries/purposes Purposes (filter: typical_for)
GET /vvt/libraries/toms TOMs (filter: category)
Template endpoints:
GET /vvt/templates List templates (filter: business_function, search)
GET /vvt/templates/{id} Single template with resolved labels
POST /vvt/templates/{id}/instantiate Create VVT activity from template
"""
import logging
import uuid
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy.orm import Session
from classroom_engine.database import get_db
from ..db.vvt_library_models import (
VVTLibDataSubjectDB,
VVTLibDataCategoryDB,
VVTLibRecipientDB,
VVTLibLegalBasisDB,
VVTLibRetentionRuleDB,
VVTLibTransferMechanismDB,
VVTLibPurposeDB,
VVTLibTomDB,
VVTProcessTemplateDB,
)
from ..db.vvt_models import VVTActivityDB, VVTAuditLogDB
from .tenant_utils import get_tenant_id
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/vvt", tags=["compliance-vvt-libraries"])
# ============================================================================
# Helper: row → dict
# ============================================================================
def _row_to_dict(row, extra_fields=None):
"""Generic row → dict for library items."""
d = {
"id": row.id,
"label_de": row.label_de,
}
if hasattr(row, 'description_de') and row.description_de:
d["description_de"] = row.description_de
if hasattr(row, 'sort_order'):
d["sort_order"] = row.sort_order
if extra_fields:
for f in extra_fields:
if hasattr(row, f):
val = getattr(row, f)
if val is not None:
d[f] = val
return d
# ============================================================================
# Library Overview
# ============================================================================
@router.get("/libraries")
async def get_libraries_overview(db: Session = Depends(get_db)):
"""Overview of all library types with item counts."""
return {
"libraries": [
{"type": "data-subjects", "count": db.query(VVTLibDataSubjectDB).count()},
{"type": "data-categories", "count": db.query(VVTLibDataCategoryDB).count()},
{"type": "recipients", "count": db.query(VVTLibRecipientDB).count()},
{"type": "legal-bases", "count": db.query(VVTLibLegalBasisDB).count()},
{"type": "retention-rules", "count": db.query(VVTLibRetentionRuleDB).count()},
{"type": "transfer-mechanisms", "count": db.query(VVTLibTransferMechanismDB).count()},
{"type": "purposes", "count": db.query(VVTLibPurposeDB).count()},
{"type": "toms", "count": db.query(VVTLibTomDB).count()},
]
}
# ============================================================================
# Data Subjects
# ============================================================================
@router.get("/libraries/data-subjects")
async def list_data_subjects(
typical_for: Optional[str] = Query(None, description="Filter by business function"),
db: Session = Depends(get_db),
):
query = db.query(VVTLibDataSubjectDB).order_by(VVTLibDataSubjectDB.sort_order)
rows = query.all()
items = [_row_to_dict(r, ["art9_relevant", "typical_for"]) for r in rows]
if typical_for:
items = [i for i in items if typical_for in (i.get("typical_for") or [])]
return items
# ============================================================================
# Data Categories (hierarchical)
# ============================================================================
@router.get("/libraries/data-categories")
async def list_data_categories(
flat: Optional[bool] = Query(False, description="Return flat list instead of tree"),
parent_id: Optional[str] = Query(None),
is_art9: Optional[bool] = Query(None),
db: Session = Depends(get_db),
):
query = db.query(VVTLibDataCategoryDB).order_by(VVTLibDataCategoryDB.sort_order)
if parent_id is not None:
query = query.filter(VVTLibDataCategoryDB.parent_id == parent_id)
if is_art9 is not None:
query = query.filter(VVTLibDataCategoryDB.is_art9 == is_art9)
rows = query.all()
extra = ["parent_id", "is_art9", "is_art10", "risk_weight", "default_retention_rule", "default_legal_basis"]
items = [_row_to_dict(r, extra) for r in rows]
if flat or parent_id is not None or is_art9 is not None:
return items
# Build tree
by_parent: dict = {}
for item in items:
pid = item.get("parent_id")
by_parent.setdefault(pid, []).append(item)
tree = []
for item in by_parent.get(None, []):
children = by_parent.get(item["id"], [])
if children:
item["children"] = children
tree.append(item)
return tree
# ============================================================================
# Recipients
# ============================================================================
@router.get("/libraries/recipients")
async def list_recipients(
type: Optional[str] = Query(None, description="INTERNAL, PROCESSOR, CONTROLLER, AUTHORITY"),
db: Session = Depends(get_db),
):
query = db.query(VVTLibRecipientDB).order_by(VVTLibRecipientDB.sort_order)
if type:
query = query.filter(VVTLibRecipientDB.type == type)
rows = query.all()
return [_row_to_dict(r, ["type", "is_third_country", "country"]) for r in rows]
# ============================================================================
# Legal Bases
# ============================================================================
@router.get("/libraries/legal-bases")
async def list_legal_bases(
is_art9: Optional[bool] = Query(None),
type: Optional[str] = Query(None),
db: Session = Depends(get_db),
):
query = db.query(VVTLibLegalBasisDB).order_by(VVTLibLegalBasisDB.sort_order)
if is_art9 is not None:
query = query.filter(VVTLibLegalBasisDB.is_art9 == is_art9)
if type:
query = query.filter(VVTLibLegalBasisDB.type == type)
rows = query.all()
return [_row_to_dict(r, ["article", "type", "is_art9", "typical_national_law"]) for r in rows]
# ============================================================================
# Retention Rules
# ============================================================================
@router.get("/libraries/retention-rules")
async def list_retention_rules(db: Session = Depends(get_db)):
rows = db.query(VVTLibRetentionRuleDB).order_by(VVTLibRetentionRuleDB.sort_order).all()
return [_row_to_dict(r, ["legal_basis", "duration", "duration_unit", "start_event", "deletion_procedure"]) for r in rows]
# ============================================================================
# Transfer Mechanisms
# ============================================================================
@router.get("/libraries/transfer-mechanisms")
async def list_transfer_mechanisms(db: Session = Depends(get_db)):
rows = db.query(VVTLibTransferMechanismDB).order_by(VVTLibTransferMechanismDB.sort_order).all()
return [_row_to_dict(r, ["article", "requires_tia"]) for r in rows]
# ============================================================================
# Purposes
# ============================================================================
@router.get("/libraries/purposes")
async def list_purposes(
typical_for: Optional[str] = Query(None),
db: Session = Depends(get_db),
):
rows = db.query(VVTLibPurposeDB).order_by(VVTLibPurposeDB.sort_order).all()
items = [_row_to_dict(r, ["typical_legal_basis", "typical_for"]) for r in rows]
if typical_for:
items = [i for i in items if typical_for in (i.get("typical_for") or [])]
return items
# ============================================================================
# TOMs
# ============================================================================
@router.get("/libraries/toms")
async def list_toms(
category: Optional[str] = Query(None),
db: Session = Depends(get_db),
):
query = db.query(VVTLibTomDB).order_by(VVTLibTomDB.sort_order)
if category:
query = query.filter(VVTLibTomDB.category == category)
rows = query.all()
return [_row_to_dict(r, ["category", "art32_reference"]) for r in rows]
# ============================================================================
# Process Templates
# ============================================================================
def _template_to_dict(t: VVTProcessTemplateDB) -> dict:
return {
"id": t.id,
"name": t.name,
"description": t.description,
"business_function": t.business_function,
"purpose_refs": t.purpose_refs or [],
"legal_basis_refs": t.legal_basis_refs or [],
"data_subject_refs": t.data_subject_refs or [],
"data_category_refs": t.data_category_refs or [],
"recipient_refs": t.recipient_refs or [],
"tom_refs": t.tom_refs or [],
"transfer_mechanism_refs": t.transfer_mechanism_refs or [],
"retention_rule_ref": t.retention_rule_ref,
"typical_systems": t.typical_systems or [],
"protection_level": t.protection_level or "MEDIUM",
"dpia_required": t.dpia_required or False,
"risk_score": t.risk_score,
"tags": t.tags or [],
"is_system": t.is_system,
"sort_order": t.sort_order,
}
def _resolve_labels(template_dict: dict, db: Session) -> dict:
"""Resolve library IDs to labels within the template dict."""
resolvers = {
"purpose_refs": (VVTLibPurposeDB, "purpose_labels"),
"legal_basis_refs": (VVTLibLegalBasisDB, "legal_basis_labels"),
"data_subject_refs": (VVTLibDataSubjectDB, "data_subject_labels"),
"data_category_refs": (VVTLibDataCategoryDB, "data_category_labels"),
"recipient_refs": (VVTLibRecipientDB, "recipient_labels"),
"tom_refs": (VVTLibTomDB, "tom_labels"),
"transfer_mechanism_refs": (VVTLibTransferMechanismDB, "transfer_mechanism_labels"),
}
for refs_key, (model, labels_key) in resolvers.items():
ids = template_dict.get(refs_key) or []
if ids:
rows = db.query(model).filter(model.id.in_(ids)).all()
label_map = {r.id: r.label_de for r in rows}
template_dict[labels_key] = {rid: label_map.get(rid, rid) for rid in ids}
# Resolve single retention rule
rr = template_dict.get("retention_rule_ref")
if rr:
row = db.query(VVTLibRetentionRuleDB).filter(VVTLibRetentionRuleDB.id == rr).first()
if row:
template_dict["retention_rule_label"] = row.label_de
return template_dict
@router.get("/templates")
async def list_templates(
business_function: Optional[str] = Query(None),
search: Optional[str] = Query(None),
db: Session = Depends(get_db),
):
"""List process templates (system + tenant)."""
query = db.query(VVTProcessTemplateDB).order_by(VVTProcessTemplateDB.sort_order)
if business_function:
query = query.filter(VVTProcessTemplateDB.business_function == business_function)
if search:
term = f"%{search}%"
query = query.filter(
(VVTProcessTemplateDB.name.ilike(term)) |
(VVTProcessTemplateDB.description.ilike(term))
)
templates = query.all()
return [_template_to_dict(t) for t in templates]
@router.get("/templates/{template_id}")
async def get_template(
template_id: str,
db: Session = Depends(get_db),
):
"""Get a single template with resolved library labels."""
t = db.query(VVTProcessTemplateDB).filter(VVTProcessTemplateDB.id == template_id).first()
if not t:
raise HTTPException(status_code=404, detail=f"Template '{template_id}' not found")
result = _template_to_dict(t)
return _resolve_labels(result, db)
@router.post("/templates/{template_id}/instantiate", status_code=201)
async def instantiate_template(
template_id: str,
http_request: Request,
tid: str = Depends(get_tenant_id),
db: Session = Depends(get_db),
):
"""Create a new VVT activity from a process template."""
t = db.query(VVTProcessTemplateDB).filter(VVTProcessTemplateDB.id == template_id).first()
if not t:
raise HTTPException(status_code=404, detail=f"Template '{template_id}' not found")
# Generate unique VVT-ID
count = db.query(VVTActivityDB).filter(VVTActivityDB.tenant_id == tid).count()
vvt_id = f"VVT-{count + 1:04d}"
# Resolve library IDs to freetext labels for backward-compat fields
purpose_labels = _resolve_ids(db, VVTLibPurposeDB, t.purpose_refs or [])
legal_labels = _resolve_ids(db, VVTLibLegalBasisDB, t.legal_basis_refs or [])
subject_labels = _resolve_ids(db, VVTLibDataSubjectDB, t.data_subject_refs or [])
category_labels = _resolve_ids(db, VVTLibDataCategoryDB, t.data_category_refs or [])
recipient_labels = _resolve_ids(db, VVTLibRecipientDB, t.recipient_refs or [])
# Resolve retention rule
retention_period = {}
if t.retention_rule_ref:
rr = db.query(VVTLibRetentionRuleDB).filter(VVTLibRetentionRuleDB.id == t.retention_rule_ref).first()
if rr:
retention_period = {
"description": rr.label_de,
"legalBasis": rr.legal_basis or "",
"deletionProcedure": rr.deletion_procedure or "",
"duration": rr.duration,
"durationUnit": rr.duration_unit,
}
# Build structured TOMs from tom_refs
structured_toms = {"accessControl": [], "confidentiality": [], "integrity": [], "availability": [], "separation": []}
if t.tom_refs:
tom_rows = db.query(VVTLibTomDB).filter(VVTLibTomDB.id.in_(t.tom_refs)).all()
for tr in tom_rows:
cat = tr.category
if cat in structured_toms:
structured_toms[cat].append(tr.label_de)
act = VVTActivityDB(
tenant_id=tid,
vvt_id=vvt_id,
name=t.name,
description=t.description or "",
purposes=purpose_labels,
legal_bases=[{"type": lid, "description": lbl} for lid, lbl in zip(t.legal_basis_refs or [], legal_labels)],
data_subject_categories=subject_labels,
personal_data_categories=category_labels,
recipient_categories=[{"type": "unknown", "name": lbl} for lbl in recipient_labels],
retention_period=retention_period,
business_function=t.business_function,
systems=[{"systemId": s, "name": s} for s in (t.typical_systems or [])],
protection_level=t.protection_level or "MEDIUM",
dpia_required=t.dpia_required or False,
structured_toms=structured_toms,
status="DRAFT",
created_by=http_request.headers.get("X-User-ID", "system"),
# Library refs
purpose_refs=t.purpose_refs,
legal_basis_refs=t.legal_basis_refs,
data_subject_refs=t.data_subject_refs,
data_category_refs=t.data_category_refs,
recipient_refs=t.recipient_refs,
retention_rule_ref=t.retention_rule_ref,
transfer_mechanism_refs=t.transfer_mechanism_refs,
tom_refs=t.tom_refs,
source_template_id=t.id,
risk_score=t.risk_score,
)
db.add(act)
db.flush()
# Audit log
audit = VVTAuditLogDB(
tenant_id=tid,
action="CREATE",
entity_type="activity",
entity_id=act.id,
changed_by=http_request.headers.get("X-User-ID", "system"),
new_values={"vvt_id": vvt_id, "source_template_id": t.id, "name": t.name},
)
db.add(audit)
db.commit()
db.refresh(act)
# Return full response
from .vvt_routes import _activity_to_response
return _activity_to_response(act)
def _resolve_ids(db: Session, model, ids: list) -> list:
"""Resolve list of library IDs to list of label_de strings."""
if not ids:
return []
rows = db.query(model).filter(model.id.in_(ids)).all()
label_map = {r.id: r.label_de for r in rows}
return [label_map.get(i, i) for i in ids]
@@ -81,6 +81,54 @@ async def upsert_organization(
# Activities
# ============================================================================
def _activity_to_response(act: VVTActivityDB) -> VVTActivityResponse:
return VVTActivityResponse(
id=str(act.id),
vvt_id=act.vvt_id,
name=act.name,
description=act.description,
purposes=act.purposes or [],
legal_bases=act.legal_bases or [],
data_subject_categories=act.data_subject_categories or [],
personal_data_categories=act.personal_data_categories or [],
recipient_categories=act.recipient_categories or [],
third_country_transfers=act.third_country_transfers or [],
retention_period=act.retention_period or {},
tom_description=act.tom_description,
business_function=act.business_function,
systems=act.systems or [],
deployment_model=act.deployment_model,
data_sources=act.data_sources or [],
data_flows=act.data_flows or [],
protection_level=act.protection_level or 'MEDIUM',
dpia_required=act.dpia_required or False,
structured_toms=act.structured_toms or {},
status=act.status or 'DRAFT',
responsible=act.responsible,
owner=act.owner,
last_reviewed_at=act.last_reviewed_at,
next_review_at=act.next_review_at,
created_by=act.created_by,
dsfa_id=str(act.dsfa_id) if act.dsfa_id else None,
# Library refs
purpose_refs=act.purpose_refs,
legal_basis_refs=act.legal_basis_refs,
data_subject_refs=act.data_subject_refs,
data_category_refs=act.data_category_refs,
recipient_refs=act.recipient_refs,
retention_rule_ref=act.retention_rule_ref,
transfer_mechanism_refs=act.transfer_mechanism_refs,
tom_refs=act.tom_refs,
source_template_id=act.source_template_id,
risk_score=act.risk_score,
linked_loeschfristen_ids=act.linked_loeschfristen_ids,
linked_tom_measure_ids=act.linked_tom_measure_ids,
art30_completeness=act.art30_completeness,
created_at=act.created_at,
updated_at=act.updated_at,
)
@router.get("/activities", response_model=List[VVTActivityResponse])
async def list_activities(
status: Optional[str] = Query(None),
@@ -145,6 +193,107 @@ async def delete_activity(
return service.delete_activity(tid, activity_id)
# ============================================================================
# Art. 30 Completeness Check
# ============================================================================
@router.get("/activities/{activity_id}/completeness")
async def get_activity_completeness(
activity_id: str,
tid: str = Depends(get_tenant_id),
db: Session = Depends(get_db),
):
"""Calculate Art. 30 completeness score for a VVT activity."""
act = db.query(VVTActivityDB).filter(
VVTActivityDB.id == activity_id,
VVTActivityDB.tenant_id == tid,
).first()
if not act:
raise HTTPException(status_code=404, detail=f"Activity {activity_id} not found")
return _calculate_completeness(act)
def _calculate_completeness(act: VVTActivityDB) -> dict:
"""Calculate Art. 30 completeness — required fields per DSGVO Art. 30 Abs. 1."""
missing = []
warnings = []
total_checks = 10
passed = 0
# 1. Name/Zweck
if act.name:
passed += 1
else:
missing.append("name")
# 2. Verarbeitungszwecke
has_purposes = bool(act.purposes) or bool(act.purpose_refs)
if has_purposes:
passed += 1
else:
missing.append("purposes")
# 3. Rechtsgrundlage
has_legal = bool(act.legal_bases) or bool(act.legal_basis_refs)
if has_legal:
passed += 1
else:
missing.append("legal_bases")
# 4. Betroffenenkategorien
has_subjects = bool(act.data_subject_categories) or bool(act.data_subject_refs)
if has_subjects:
passed += 1
else:
missing.append("data_subjects")
# 5. Datenkategorien
has_categories = bool(act.personal_data_categories) or bool(act.data_category_refs)
if has_categories:
passed += 1
else:
missing.append("data_categories")
# 6. Empfaenger
has_recipients = bool(act.recipient_categories) or bool(act.recipient_refs)
if has_recipients:
passed += 1
else:
missing.append("recipients")
# 7. Drittland-Uebermittlung (checked but not strictly required)
passed += 1 # always passes — no transfer is valid state
# 8. Loeschfristen
has_retention = bool(act.retention_period and act.retention_period.get('description')) or bool(act.retention_rule_ref)
if has_retention:
passed += 1
else:
missing.append("retention_period")
# 9. TOM-Beschreibung
has_tom = bool(act.tom_description) or bool(act.tom_refs) or bool(act.structured_toms)
if has_tom:
passed += 1
else:
missing.append("tom_description")
# 10. Verantwortlicher
if act.responsible:
passed += 1
else:
missing.append("responsible")
# Warnings
if act.dpia_required and not act.dsfa_id:
warnings.append("dpia_required_but_no_dsfa_linked")
if act.third_country_transfers and not act.transfer_mechanism_refs:
warnings.append("third_country_transfer_without_mechanism")
score = int((passed / total_checks) * 100)
return {"score": score, "missing": missing, "warnings": warnings, "passed": passed, "total": total_checks}
# ============================================================================
# Audit Log
# ============================================================================