"""
Compliance Pipeline Execution.

Pipeline phases (ingestion, extraction, control generation, measures) and
orchestration logic.
"""

import asyncio
import json
import logging
import os
import sys
import time
from datetime import datetime
from typing import Dict, List, Any
from dataclasses import asdict

from compliance_models import Checkpoint, Control, Measure
from compliance_extraction import (
    extract_checkpoints_from_chunk,
    generate_control_for_checkpoints,
    generate_measure_for_control,
)

logger = logging.getLogger(__name__)

# Import checkpoint manager. It is optional: when it is missing, every
# `if self.checkpoint_mgr:` branch below is skipped.
try:
    from pipeline_checkpoints import CheckpointManager, EXPECTED_VALUES, ValidationStatus
except ImportError:
    logger.warning("Checkpoint manager not available, running without checkpoints")
    CheckpointManager = None
    EXPECTED_VALUES = {}
    ValidationStatus = None

# Set environment variables for Docker network. These defaults are applied at
# import time, before the klausur-service modules below are imported.
if not os.getenv("QDRANT_URL") and not os.getenv("QDRANT_HOST"):
    os.environ["QDRANT_HOST"] = "qdrant"
os.environ.setdefault("EMBEDDING_SERVICE_URL", "http://embedding-service:8087")

# Try to import from klausur-service
try:
    from legal_corpus_ingestion import LegalCorpusIngestion, REGULATIONS, LEGAL_CORPUS_COLLECTION
    from qdrant_client import QdrantClient
    from qdrant_client.models import Filter, FieldCondition, MatchValue
except ImportError as e:
    # Include the underlying error so the missing module can be identified.
    logger.error(
        "Could not import required modules (%s). "
        "Make sure you're in the klausur-service container.",
        e,
    )
    sys.exit(1)


class CompliancePipeline:
    """Handles the full compliance pipeline.

    Phases run in this order: document ingestion, checkpoint extraction,
    control generation and measure generation. After that, results are
    written as JSON. Intermediate results accumulate on the instance
    (``checkpoints``, ``controls``, ``measures``, ``stats``).
    """

    def __init__(self):
        # Support both QDRANT_URL and QDRANT_HOST/QDRANT_PORT.
        qdrant_url = os.getenv("QDRANT_URL", "")
        # None is the Qdrant client's own default for `https`, so plain-http
        # setups behave exactly as before.
        use_https = None
        if qdrant_url:
            from urllib.parse import urlparse

            parsed = urlparse(qdrant_url)
            qdrant_host = parsed.hostname or "qdrant"
            qdrant_port = parsed.port or 6333
            if parsed.scheme == "https":
                # host/port alone would drop the scheme and connect over HTTP.
                use_https = True
        else:
            qdrant_host = os.getenv("QDRANT_HOST", "qdrant")
            qdrant_port = int(os.getenv("QDRANT_PORT", "6333"))

        self.qdrant = QdrantClient(host=qdrant_host, port=qdrant_port, https=use_https)
        self.checkpoints: List[Checkpoint] = []
        self.controls: List[Control] = []
        self.measures: List[Measure] = []
        self.stats = {
            "chunks_processed": 0,
            "checkpoints_extracted": 0,
            "controls_created": 0,
            "measures_defined": 0,
            "by_regulation": {},
            "by_domain": {},
        }

        # Initialize checkpoint manager (None when the module is unavailable).
        self.checkpoint_mgr = CheckpointManager() if CheckpointManager else None

    def _count_existing_chunks(self) -> Dict[str, int]:
        """Return the number of chunks already stored in Qdrant per regulation code.

        Best effort: if any count query raises, a warning is logged and the
        counts gathered so far are returned. Regulations missing from the
        result are treated as having 0 chunks by the caller.
        """
        existing_chunks: Dict[str, int] = {}
        try:
            for regulation in REGULATIONS:
                count_result = self.qdrant.count(
                    collection_name=LEGAL_CORPUS_COLLECTION,
                    count_filter=Filter(
                        must=[FieldCondition(key="regulation_code", match=MatchValue(value=regulation.code))]
                    ),
                )
                existing_chunks[regulation.code] = count_result.count
                logger.info(f" {regulation.code}: {count_result.count} existing chunks")
        except Exception as e:
            logger.warning(f"Could not check existing chunks: {e}")
        return existing_chunks

    async def run_ingestion_phase(self, force_reindex: bool = False) -> int:
        """Phase 1: Ingest documents (incremental - only missing ones).

        Args:
            force_reindex: Re-ingest every regulation, even ones that already
                have chunks in the collection.

        Returns:
            The chunk total, also stored in ``stats["chunks_processed"]``.

        Raises:
            Re-raises any exception from setup or validation after marking the
            "ingestion" checkpoint as failed. Per-regulation ingestion errors
            are logged and recorded as 0 chunks instead.
        """
        logger.info("\n" + "=" * 60)
        logger.info("PHASE 1: DOCUMENT INGESTION (INCREMENTAL)")
        logger.info("=" * 60)

        if self.checkpoint_mgr:
            self.checkpoint_mgr.start_checkpoint("ingestion", "Document Ingestion")

        # Guard construction separately so a failing constructor still closes
        # the checkpoint. The `finally` below needs `ingestion` to exist.
        try:
            ingestion = LegalCorpusIngestion()
        except Exception as e:
            if self.checkpoint_mgr:
                self.checkpoint_mgr.fail_checkpoint(str(e))
            raise

        try:
            existing_chunks = self._count_existing_chunks()

            # Determine which regulations need ingestion
            regulations_to_ingest = []
            for regulation in REGULATIONS:
                existing = existing_chunks.get(regulation.code, 0)
                if force_reindex or existing == 0:
                    regulations_to_ingest.append(regulation)
                    logger.info(f" -> Will ingest: {regulation.code} (existing: {existing}, force: {force_reindex})")
                else:
                    logger.info(f" -> Skipping: {regulation.code} (already has {existing} chunks)")
                    self.stats["by_regulation"][regulation.code] = existing

            total_chunks = sum(existing_chunks.values())

            if not regulations_to_ingest:
                logger.info("All regulations already indexed. Skipping ingestion phase.")
                self.stats["chunks_processed"] = total_chunks
                if self.checkpoint_mgr:
                    self.checkpoint_mgr.add_metric("total_chunks", total_chunks)
                    self.checkpoint_mgr.add_metric("skipped", True)
                    self.checkpoint_mgr.complete_checkpoint(success=True)
                return total_chunks

            # Ingest only missing regulations.
            # NOTE(review): with force_reindex=True, a re-ingested regulation's
            # pre-existing count stays in total_chunks and its new count is
            # added on top. If ingest_regulation replaces chunks rather than
            # appending, the reported total is overstated. Confirm.
            for i, regulation in enumerate(regulations_to_ingest, 1):
                logger.info(f"[{i}/{len(regulations_to_ingest)}] Ingesting {regulation.code}...")
                try:
                    count = await ingestion.ingest_regulation(regulation)
                    total_chunks += count
                    self.stats["by_regulation"][regulation.code] = count
                    logger.info(f" -> {count} chunks")
                    if self.checkpoint_mgr:
                        self.checkpoint_mgr.add_metric(f"chunks_{regulation.code}", count)
                except Exception as e:
                    # One failing regulation must not abort the others.
                    logger.error(f" -> FAILED: {e}")
                    self.stats["by_regulation"][regulation.code] = 0

            self.stats["chunks_processed"] = total_chunks
            logger.info(f"\nTotal chunks in collection: {total_chunks}")

            # Validate ingestion results
            if self.checkpoint_mgr:
                self.checkpoint_mgr.add_metric("total_chunks", total_chunks)
                self.checkpoint_mgr.add_metric("regulations_count", len(REGULATIONS))

                expected = EXPECTED_VALUES.get("ingestion", {})
                self.checkpoint_mgr.validate(
                    "total_chunks",
                    expected=expected.get("total_chunks", 8000),
                    actual=total_chunks,
                    min_value=expected.get("min_chunks", 7000),
                )

                reg_expected = expected.get("regulations", {})
                for reg_code, reg_exp in reg_expected.items():
                    actual = self.stats["by_regulation"].get(reg_code, 0)
                    self.checkpoint_mgr.validate(
                        f"chunks_{reg_code}",
                        expected=reg_exp.get("expected", 0),
                        actual=actual,
                        min_value=reg_exp.get("min", 0),
                    )

                self.checkpoint_mgr.complete_checkpoint(success=True)

            return total_chunks

        except Exception as e:
            if self.checkpoint_mgr:
                self.checkpoint_mgr.fail_checkpoint(str(e))
            raise
        finally:
            await ingestion.close()

    async def run_extraction_phase(self) -> int:
        """Phase 2: Extract checkpoints from chunks.

        Scrolls through the whole Qdrant collection in pages of 100 points
        (payloads only, no vectors). Checkpoints extracted from each chunk's
        "text" are appended to ``self.checkpoints``.

        Returns:
            The number of checkpoints held in ``self.checkpoints``.
        """
        logger.info("\n" + "=" * 60)
        logger.info("PHASE 2: CHECKPOINT EXTRACTION")
        logger.info("=" * 60)

        if self.checkpoint_mgr:
            self.checkpoint_mgr.start_checkpoint("extraction", "Checkpoint Extraction")

        try:
            offset = None
            total_checkpoints = 0

            while True:
                points, next_offset = self.qdrant.scroll(
                    collection_name=LEGAL_CORPUS_COLLECTION,
                    limit=100,
                    offset=offset,
                    with_payload=True,
                    with_vectors=False,
                )

                if not points:
                    break

                for point in points:
                    # Defensive: treat a missing (None) payload as empty.
                    payload = point.payload or {}
                    text = payload.get("text", "")
                    cps = extract_checkpoints_from_chunk(text, payload)
                    self.checkpoints.extend(cps)
                    total_checkpoints += len(cps)

                logger.info(f"Processed {len(points)} chunks, extracted {total_checkpoints} checkpoints so far...")

                if next_offset is None:
                    break
                offset = next_offset

            self.stats["checkpoints_extracted"] = len(self.checkpoints)
            logger.info(f"\nTotal checkpoints extracted: {len(self.checkpoints)}")

            by_reg: Dict[str, int] = {}
            for cp in self.checkpoints:
                by_reg[cp.regulation_code] = by_reg.get(cp.regulation_code, 0) + 1
            for reg, count in sorted(by_reg.items()):
                logger.info(f" {reg}: {count} checkpoints")

            if self.checkpoint_mgr:
                self.checkpoint_mgr.add_metric("total_checkpoints", len(self.checkpoints))
                self.checkpoint_mgr.add_metric("checkpoints_by_regulation", by_reg)

                expected = EXPECTED_VALUES.get("extraction", {})
                self.checkpoint_mgr.validate(
                    "total_checkpoints",
                    expected=expected.get("total_checkpoints", 3500),
                    actual=len(self.checkpoints),
                    min_value=expected.get("min_checkpoints", 3000),
                )

                self.checkpoint_mgr.complete_checkpoint(success=True)

            return len(self.checkpoints)

        except Exception as e:
            if self.checkpoint_mgr:
                self.checkpoint_mgr.fail_checkpoint(str(e))
            raise

    async def run_control_generation_phase(self) -> int:
        """Phase 3: Generate controls from checkpoints.

        Checkpoints are grouped by regulation code. Within each regulation,
        consecutive batches of 4 produce at most one control each. Falsy
        results from the generator are dropped.

        Returns:
            The number of controls held in ``self.controls``.
        """
        logger.info("\n" + "=" * 60)
        logger.info("PHASE 3: CONTROL GENERATION")
        logger.info("=" * 60)

        if self.checkpoint_mgr:
            self.checkpoint_mgr.start_checkpoint("controls", "Control Generation")

        try:
            # Group checkpoints by regulation (insertion order is preserved).
            by_regulation: Dict[str, List[Checkpoint]] = {}
            for cp in self.checkpoints:
                by_regulation.setdefault(cp.regulation_code, []).append(cp)

            # Generate controls per regulation (one control per batch of checkpoints)
            batch_size = 4
            for regulation, checkpoints in by_regulation.items():
                logger.info(f"Generating controls for {regulation} ({len(checkpoints)} checkpoints)...")

                for i in range(0, len(checkpoints), batch_size):
                    batch = checkpoints[i:i + batch_size]
                    control = generate_control_for_checkpoints(batch, self.stats.get("by_domain", {}))

                    if control:
                        self.controls.append(control)
                        self.stats["by_domain"][control.domain] = self.stats["by_domain"].get(control.domain, 0) + 1

            self.stats["controls_created"] = len(self.controls)
            logger.info(f"\nTotal controls created: {len(self.controls)}")

            for domain, count in sorted(self.stats["by_domain"].items()):
                logger.info(f" {domain}: {count} controls")

            if self.checkpoint_mgr:
                self.checkpoint_mgr.add_metric("total_controls", len(self.controls))
                self.checkpoint_mgr.add_metric("controls_by_domain", dict(self.stats["by_domain"]))

                expected = EXPECTED_VALUES.get("controls", {})
                self.checkpoint_mgr.validate(
                    "total_controls",
                    expected=expected.get("total_controls", 900),
                    actual=len(self.controls),
                    min_value=expected.get("min_controls", 800),
                )

                self.checkpoint_mgr.complete_checkpoint(success=True)

            return len(self.controls)

        except Exception as e:
            if self.checkpoint_mgr:
                self.checkpoint_mgr.fail_checkpoint(str(e))
            raise

    async def run_measure_generation_phase(self) -> int:
        """Phase 4: Generate measures for controls.

        Calls the measure generator once per control. Falsy results are
        skipped, consistent with the control phase, because ``save_results``
        would otherwise fail on ``asdict(None)``.

        Returns:
            The number of measures held in ``self.measures``.
        """
        logger.info("\n" + "=" * 60)
        logger.info("PHASE 4: MEASURE GENERATION")
        logger.info("=" * 60)

        if self.checkpoint_mgr:
            self.checkpoint_mgr.start_checkpoint("measures", "Measure Generation")

        try:
            for control in self.controls:
                measure = generate_measure_for_control(control)
                if measure:
                    self.measures.append(measure)

            self.stats["measures_defined"] = len(self.measures)
            logger.info(f"\nTotal measures defined: {len(self.measures)}")

            if self.checkpoint_mgr:
                self.checkpoint_mgr.add_metric("total_measures", len(self.measures))

                expected = EXPECTED_VALUES.get("measures", {})
                self.checkpoint_mgr.validate(
                    "total_measures",
                    expected=expected.get("total_measures", 900),
                    actual=len(self.measures),
                    min_value=expected.get("min_measures", 800),
                )

                self.checkpoint_mgr.complete_checkpoint(success=True)

            return len(self.measures)

        except Exception as e:
            if self.checkpoint_mgr:
                self.checkpoint_mgr.fail_checkpoint(str(e))
            raise

    @staticmethod
    def _write_json(path: str, data: Any) -> None:
        """Write *data* to *path* as pretty-printed UTF-8 JSON."""
        # ensure_ascii=False emits raw non-ASCII characters, so the file
        # encoding must be explicit rather than the platform default.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

    def save_results(self, output_dir: str = "/tmp/compliance_output"):
        """Save results to JSON files.

        Writes checkpoints.json, controls.json, measures.json and
        statistics.json into *output_dir*, creating the directory if needed.
        A ``generated_at`` timestamp is added to ``self.stats`` before it
        is written.
        """
        logger.info("\n" + "=" * 60)
        logger.info("SAVING RESULTS")
        logger.info("=" * 60)

        os.makedirs(output_dir, exist_ok=True)

        checkpoints_file = os.path.join(output_dir, "checkpoints.json")
        self._write_json(checkpoints_file, [asdict(cp) for cp in self.checkpoints])
        logger.info(f"Saved {len(self.checkpoints)} checkpoints to {checkpoints_file}")

        controls_file = os.path.join(output_dir, "controls.json")
        self._write_json(controls_file, [asdict(c) for c in self.controls])
        logger.info(f"Saved {len(self.controls)} controls to {controls_file}")

        measures_file = os.path.join(output_dir, "measures.json")
        self._write_json(measures_file, [asdict(m) for m in self.measures])
        logger.info(f"Saved {len(self.measures)} measures to {measures_file}")

        stats_file = os.path.join(output_dir, "statistics.json")
        self.stats["generated_at"] = datetime.now().isoformat()
        self._write_json(stats_file, self.stats)
        logger.info(f"Saved statistics to {stats_file}")

    async def run_full_pipeline(self, force_reindex: bool = False, skip_ingestion: bool = False):
        """Run the complete pipeline.

        Args:
            force_reindex: If True, re-ingest all documents even if they exist.
            skip_ingestion: If True, skip the ingestion phase entirely (use
                existing chunks). The chunk count is then read from the
                collection info.

        Raises:
            Re-raises any phase failure after marking the pipeline state as
            "failed" (when a checkpoint manager is available).
        """
        start_time = time.time()

        logger.info("=" * 60)
        logger.info("FULL COMPLIANCE PIPELINE (INCREMENTAL)")
        logger.info(f"Started at: {datetime.now().isoformat()}")
        logger.info(f"Force reindex: {force_reindex}")
        logger.info(f"Skip ingestion: {skip_ingestion}")
        if self.checkpoint_mgr:
            logger.info(f"Pipeline ID: {self.checkpoint_mgr.pipeline_id}")
        logger.info("=" * 60)

        try:
            if skip_ingestion:
                logger.info("Skipping ingestion phase as requested...")
                try:
                    collection_info = self.qdrant.get_collection(LEGAL_CORPUS_COLLECTION)
                    # points_count may be None; keep the stat numeric.
                    self.stats["chunks_processed"] = collection_info.points_count or 0
                except Exception as e:
                    # Best effort: the count is informational only.
                    logger.warning(f"Could not read collection info: {e}")
                    self.stats["chunks_processed"] = 0
            else:
                await self.run_ingestion_phase(force_reindex=force_reindex)

            await self.run_extraction_phase()
            await self.run_control_generation_phase()
            await self.run_measure_generation_phase()
            self.save_results()

            elapsed = time.time() - start_time
            logger.info("\n" + "=" * 60)
            logger.info("PIPELINE COMPLETE")
            logger.info("=" * 60)
            logger.info(f"Duration: {elapsed:.1f} seconds")
            logger.info(f"Chunks processed: {self.stats['chunks_processed']}")
            logger.info(f"Checkpoints extracted: {self.stats['checkpoints_extracted']}")
            logger.info(f"Controls created: {self.stats['controls_created']}")
            logger.info(f"Measures defined: {self.stats['measures_defined']}")
            logger.info("\nResults saved to: /tmp/compliance_output/")
            logger.info("Checkpoint status: /tmp/pipeline_checkpoints.json")
            logger.info("=" * 60)

            if self.checkpoint_mgr:
                self.checkpoint_mgr.complete_pipeline({
                    "duration_seconds": elapsed,
                    "chunks_processed": self.stats['chunks_processed'],
                    "checkpoints_extracted": self.stats['checkpoints_extracted'],
                    "controls_created": self.stats['controls_created'],
                    "measures_defined": self.stats['measures_defined'],
                    "by_regulation": self.stats['by_regulation'],
                    "by_domain": self.stats['by_domain'],
                })

        except Exception as e:
            logger.error(f"Pipeline failed: {e}")
            if self.checkpoint_mgr:
                # NOTE(review): reaches into CheckpointManager internals
                # (`state`, `_save`); a public fail API would be preferable.
                self.checkpoint_mgr.state.status = "failed"
                self.checkpoint_mgr._save()
            raise