""" Control Dependency Engine — evaluates control statuses considering inter-control dependencies. Pure functions (no DB coupling) for: - Generic condition evaluation (JSONB rules -> bool) - Effect application (modifies target status) - Cycle detection (DFS-based) - Topological sort (evaluation order) - Full evaluation resolution with priority-based conflict handling DB interaction is in separate load/store functions at the bottom. """ from __future__ import annotations import json import logging import uuid from collections import defaultdict from dataclasses import dataclass, field from enum import Enum from typing import Any, Optional from sqlalchemy import text logger = logging.getLogger(__name__) # ============================================================================ # ENUMS # ============================================================================ class DependencyType(str, Enum): PREREQUISITE = "prerequisite" CONDITIONAL_REQUIREMENT = "conditional_requirement" SUPERSEDES = "supersedes" COMPENSATING_CONTROL = "compensating_control" SCOPE_EXCLUSION = "scope_exclusion" class EvaluationStatus(str, Enum): PASS = "pass" FAIL = "fail" NOT_APPLICABLE = "not_applicable" PARTIALLY_SATISFIED = "partially_satisfied" COMPENSATED_FAIL = "compensated_fail" REVIEW_REQUIRED = "review_required" # Default priority per dependency type (lower = higher priority) DEFAULT_PRIORITIES: dict[str, int] = { "supersedes": 10, "scope_exclusion": 20, "prerequisite": 50, "conditional_requirement": 70, "compensating_control": 80, } # ============================================================================ # DATA CLASSES # ============================================================================ @dataclass class Dependency: id: str = "" source_control_id: str = "" target_control_id: str = "" dependency_type: str = "prerequisite" condition: dict = field(default_factory=dict) effect: dict = field(default_factory=dict) priority: int = 100 generation_method: str = "manual" is_active: bool = True @dataclass class ControlState: """In-memory representation of a control's evaluation state.""" control_id: str = "" raw_status: str = "fail" resolved_status: str = "" context: dict = field(default_factory=dict) @dataclass class EvaluationResult: control_id: str = "" evaluation_run_id: str = "" raw_status: str = "fail" resolved_status: str = "fail" dependency_resolution: list = field(default_factory=list) confidence: float = 1.0 reasoning: str = "" # ============================================================================ # CONDITION EVALUATOR # ============================================================================ def _resolve_field(field_path: str, context: dict) -> Any: """Resolve a dot-notation field path against a nested dict. Examples: _resolve_field("source.status", {"source": {"status": "pass"}}) -> "pass" _resolve_field("context.company_size", {"context": {"company_size": "large"}}) -> "large" """ parts = field_path.split(".") current = context for part in parts: if isinstance(current, dict): current = current.get(part) else: return None return current def _evaluate_single_clause(clause: dict, context: dict) -> bool: """Evaluate a single {field, op, value} clause.""" field_path = clause.get("field", "") op = clause.get("op", "==") expected = clause.get("value") actual = _resolve_field(field_path, context) if op == "==": return actual == expected elif op == "!=": return actual != expected elif op == "in": if isinstance(expected, list): return actual in expected return False elif op == "not_in": if isinstance(expected, list): return actual not in expected return True elif op == ">": try: return float(actual) > float(expected) except (TypeError, ValueError): return False elif op == "<": try: return float(actual) < float(expected) except (TypeError, ValueError): return False elif op == ">=": try: return float(actual) >= float(expected) except (TypeError, ValueError): return False elif op == "<=": try: return float(actual) <= float(expected) except (TypeError, ValueError): return False elif op == "contains": if isinstance(actual, (list, set, tuple)): return expected in actual if isinstance(actual, str): return str(expected) in actual return False return False def evaluate_condition(condition: dict, context: dict) -> bool: """Evaluate a generic condition against a context dict. Supports: - Empty condition -> True (always matches) - Simple clause: {"field": "source.status", "op": "==", "value": "pass"} - Compound AND: {"operator": "AND", "clauses": [...]} - Compound OR: {"operator": "OR", "clauses": [...]} - Negation: {"operator": "NOT", "clause": {...}} """ if not condition: return True operator = condition.get("operator") if operator == "AND": clauses = condition.get("clauses", []) return all(evaluate_condition(c, context) for c in clauses) if operator == "OR": clauses = condition.get("clauses", []) return any(evaluate_condition(c, context) for c in clauses) if operator == "NOT": clause = condition.get("clause", {}) return not evaluate_condition(clause, context) # Simple clause with field/op/value if "field" in condition: return _evaluate_single_clause(condition, context) return True # ============================================================================ # EFFECT APPLIER # ============================================================================ def apply_effect(effect: dict, current_status: str) -> str: """Apply a dependency effect to produce a new status. Effect schema: {"set_status": "not_applicable"} {"set_status": "compensated_fail"} """ new_status = effect.get("set_status") if new_status and new_status in {s.value for s in EvaluationStatus}: return new_status return current_status # ============================================================================ # CYCLE DETECTION # ============================================================================ WHITE, GRAY, BLACK = 0, 1, 2 def detect_cycles(dependencies: list[Dependency]) -> list[list[str]]: """Detect cycles in the dependency graph using DFS. Returns list of cycles (each cycle = list of control IDs). Empty list = no cycles. """ graph: dict[str, list[str]] = defaultdict(list) all_nodes: set[str] = set() for dep in dependencies: if dep.is_active: graph[dep.source_control_id].append(dep.target_control_id) all_nodes.add(dep.source_control_id) all_nodes.add(dep.target_control_id) color: dict[str, int] = {n: WHITE for n in all_nodes} parent: dict[str, Optional[str]] = {n: None for n in all_nodes} cycles: list[list[str]] = [] def dfs(node: str) -> None: color[node] = GRAY for neighbor in graph.get(node, []): if color.get(neighbor, WHITE) == GRAY: # Found a cycle — trace back cycle = [neighbor, node] current = parent.get(node) while current and current != neighbor: cycle.append(current) current = parent.get(current) cycles.append(cycle) elif color.get(neighbor, WHITE) == WHITE: parent[neighbor] = node dfs(neighbor) color[node] = BLACK for node in all_nodes: if color[node] == WHITE: dfs(node) return cycles def topological_sort(dependencies: list[Dependency]) -> list[str]: """Return control IDs in dependency-safe evaluation order. Sources (prerequisites) come before targets (dependents). Controls not involved in any dependency are omitted. """ graph: dict[str, list[str]] = defaultdict(list) in_degree: dict[str, int] = defaultdict(int) all_nodes: set[str] = set() for dep in dependencies: if dep.is_active: # source -> target means: source should be evaluated first graph[dep.source_control_id].append(dep.target_control_id) in_degree.setdefault(dep.source_control_id, 0) in_degree[dep.target_control_id] = in_degree.get(dep.target_control_id, 0) + 1 all_nodes.add(dep.source_control_id) all_nodes.add(dep.target_control_id) # Kahn's algorithm queue = [n for n in all_nodes if in_degree.get(n, 0) == 0] result: list[str] = [] while queue: queue.sort() # deterministic order node = queue.pop(0) result.append(node) for neighbor in graph.get(node, []): in_degree[neighbor] -= 1 if in_degree[neighbor] == 0: queue.append(neighbor) return result # ============================================================================ # MAIN EVALUATION ENGINE # ============================================================================ def evaluate_controls( control_states: dict[str, ControlState], dependencies: list[Dependency], context: dict, ) -> dict[str, EvaluationResult]: """Evaluate all controls considering dependencies. Args: control_states: control_id -> ControlState (with raw_status) dependencies: all active dependencies context: company profile (industry, company_size, scope_signals, etc.) Returns: control_id -> EvaluationResult (with resolved_status + trace) Algorithm: 1. Build adjacency (target -> dependencies) 2. Detect cycles -> involved controls = review_required 3. Topological sort for evaluation order 4. For each control: evaluate conditions, apply highest-priority effect 5. Record full dependency trace for MCP output """ evaluation_run_id = str(uuid.uuid4()) # 1. Build adjacency: target_control_id -> list of dependencies target_deps: dict[str, list[Dependency]] = defaultdict(list) for dep in dependencies: if dep.is_active: target_deps[dep.target_control_id].append(dep) # 2. Cycle detection cycles = detect_cycles(dependencies) cycle_controls: set[str] = set() for cycle in cycles: cycle_controls.update(cycle) # 3. Topological sort (excluding cycle controls) safe_deps = [ d for d in dependencies if d.is_active and d.source_control_id not in cycle_controls and d.target_control_id not in cycle_controls ] eval_order = topological_sort(safe_deps) # Add remaining controls (those not in any dependency + cycle controls) all_ids = set(control_states.keys()) remaining = all_ids - set(eval_order) eval_order.extend(sorted(remaining)) # 4. Iterate and evaluate results: dict[str, EvaluationResult] = {} for control_id in eval_order: state = control_states.get(control_id) if not state: continue # Cycle controls -> review_required if control_id in cycle_controls: results[control_id] = EvaluationResult( control_id=control_id, evaluation_run_id=evaluation_run_id, raw_status=state.raw_status, resolved_status=EvaluationStatus.REVIEW_REQUIRED.value, dependency_resolution=[{"cycle_detected": True}], confidence=0.5, reasoning="Zyklische Abhaengigkeit erkannt — manuelle Pruefung erforderlich.", ) continue # Collect dependencies targeting this control deps_for_target = target_deps.get(control_id, []) if not deps_for_target: results[control_id] = EvaluationResult( control_id=control_id, evaluation_run_id=evaluation_run_id, raw_status=state.raw_status, resolved_status=state.raw_status, confidence=1.0, ) continue # Evaluate each dependency's condition matching_effects: list[tuple[int, dict, Dependency]] = [] trace: list[dict] = [] for dep in sorted(deps_for_target, key=lambda d: d.priority): source_state = control_states.get(dep.source_control_id) source_result = results.get(dep.source_control_id) source_status = "unknown" if source_result: source_status = source_result.resolved_status elif source_state: source_status = source_state.raw_status eval_ctx = { "source": {"status": source_status}, "target": {"status": state.raw_status}, "context": context, } condition_met = evaluate_condition(dep.condition, eval_ctx) trace_entry = { "dependency_id": dep.id, "dependency_type": dep.dependency_type, "source_control_id": dep.source_control_id, "source_status": source_status, "condition_met": condition_met, "effect_applied": dep.effect if condition_met else None, "priority": dep.priority, } trace.append(trace_entry) if condition_met: matching_effects.append((dep.priority, dep.effect, dep)) # Apply highest-priority (lowest number) effect resolved = state.raw_status if matching_effects: matching_effects.sort(key=lambda x: x[0]) _, best_effect, _ = matching_effects[0] resolved = apply_effect(best_effect, state.raw_status) results[control_id] = EvaluationResult( control_id=control_id, evaluation_run_id=evaluation_run_id, raw_status=state.raw_status, resolved_status=resolved, dependency_resolution=trace, confidence=_compute_confidence(trace), ) return results def _compute_confidence(trace: list[dict]) -> float: """Compute confidence based on dependency resolution trace.""" if not trace: return 1.0 met_count = sum(1 for t in trace if t.get("condition_met")) total = len(trace) if met_count == 0: return 1.0 # No dependencies fired -> raw status stands if met_count == 1: return 0.95 # Single dependency resolved # Multiple dependencies -> slightly lower confidence return max(0.7, 1.0 - (met_count - 1) * 0.1) # ============================================================================ # DB INTERACTION (separate from pure logic) # ============================================================================ def load_dependencies_for_controls( db, control_ids: list[str], ) -> list[Dependency]: """Load all active dependencies involving the given control IDs.""" if not control_ids: return [] rows = db.execute( text(""" SELECT id, source_control_id, target_control_id, dependency_type, condition, effect, priority, generation_method, is_active FROM control_dependencies WHERE is_active = TRUE AND (source_control_id = ANY(CAST(:ids AS uuid[])) OR target_control_id = ANY(CAST(:ids AS uuid[]))) """), {"ids": control_ids}, ).fetchall() return [ Dependency( id=str(r[0]), source_control_id=str(r[1]), target_control_id=str(r[2]), dependency_type=r[3], condition=r[4] if isinstance(r[4], dict) else {}, effect=r[5] if isinstance(r[5], dict) else {}, priority=r[6], generation_method=r[7], is_active=r[8], ) for r in rows ] def load_all_active_dependencies(db) -> list[Dependency]: """Load all active dependencies.""" rows = db.execute( text(""" SELECT id, source_control_id, target_control_id, dependency_type, condition, effect, priority, generation_method, is_active FROM control_dependencies WHERE is_active = TRUE ORDER BY priority """), ).fetchall() return [ Dependency( id=str(r[0]), source_control_id=str(r[1]), target_control_id=str(r[2]), dependency_type=r[3], condition=r[4] if isinstance(r[4], dict) else {}, effect=r[5] if isinstance(r[5], dict) else {}, priority=r[6], generation_method=r[7], is_active=r[8], ) for r in rows ] def store_dependency(db, dep: Dependency) -> str: """Insert a dependency, return its UUID.""" row = db.execute( text(""" INSERT INTO control_dependencies (source_control_id, target_control_id, dependency_type, condition, effect, priority, generation_method, generation_metadata) VALUES (CAST(:src AS uuid), CAST(:tgt AS uuid), :dtype, CAST(:cond AS jsonb), CAST(:eff AS jsonb), :prio, :gmethod, CAST(:gmeta AS jsonb)) ON CONFLICT (source_control_id, target_control_id, dependency_type) DO UPDATE SET condition = EXCLUDED.condition, effect = EXCLUDED.effect, priority = EXCLUDED.priority, updated_at = NOW() RETURNING id::text """), { "src": dep.source_control_id, "tgt": dep.target_control_id, "dtype": dep.dependency_type, "cond": json.dumps(dep.condition), "eff": json.dumps(dep.effect), "prio": dep.priority, "gmethod": dep.generation_method, "gmeta": json.dumps({}), }, ).fetchone() return row[0] if row else "" def store_evaluation_results( db, results: dict[str, EvaluationResult], company_profile: dict, ) -> int: """Batch insert evaluation results. Returns row count.""" count = 0 for result in results.values(): db.execute( text(""" INSERT INTO control_evaluation_results (control_id, evaluation_run_id, company_profile, raw_status, resolved_status, dependency_resolution, confidence, reasoning) VALUES (CAST(:cid AS uuid), CAST(:rid AS uuid), CAST(:prof AS jsonb), :raw, :resolved, CAST(:trace AS jsonb), :conf, :reason) ON CONFLICT (control_id, evaluation_run_id) DO UPDATE SET resolved_status = EXCLUDED.resolved_status, dependency_resolution = EXCLUDED.dependency_resolution, confidence = EXCLUDED.confidence """), { "cid": result.control_id, "rid": result.evaluation_run_id, "prof": json.dumps(company_profile), "raw": result.raw_status, "resolved": result.resolved_status, "trace": json.dumps(result.dependency_resolution), "conf": result.confidence, "reason": result.reasoning, }, ) count += 1 return count