diff --git a/backend-compliance/compliance/api/vendor_compliance_routes.py b/backend-compliance/compliance/api/vendor_compliance_routes.py index 7ed5e3f..380d228 100644 --- a/backend-compliance/compliance/api/vendor_compliance_routes.py +++ b/backend-compliance/compliance/api/vendor_compliance_routes.py @@ -42,319 +42,86 @@ Endpoints: GET /vendor-compliance/export/{id} — 501 GET /vendor-compliance/export/{id}/download — 501 -DB tables (Go Migration 011, schema: vendor_vendors, vendor_contracts, -vendor_findings, vendor_control_instances). +Phase 1 Step 4 refactor: handlers delegate to VendorService, +ContractService, FindingService, ControlInstanceService, and +ControlsLibraryService. Module-level helpers re-exported for legacy +test imports. """ -import json import logging -import uuid -from datetime import datetime, timezone from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Query -from sqlalchemy import text +from fastapi import APIRouter, Depends, Query from sqlalchemy.orm import Session from classroom_engine.database import get_db +from compliance.api._http_errors import translate_domain_errors +from compliance.services.vendor_compliance_service import ( + DEFAULT_TENANT_ID, # noqa: F401 — re-export + VendorService, + _get, # noqa: F401 — re-export + _now_iso, + _ok, # noqa: F401 — re-export + _parse_json, # noqa: F401 — re-export + _to_camel, # noqa: F401 — re-export + _to_snake, # noqa: F401 — re-export + _ts, # noqa: F401 — re-export + _vendor_to_response, # noqa: F401 — re-export + _VENDOR_CAMEL_TO_SNAKE, # noqa: F401 — re-export + _VENDOR_SNAKE_TO_CAMEL, # noqa: F401 — re-export +) +from compliance.services.vendor_compliance_sub_service import ( + ContractService, + _contract_to_response, # noqa: F401 — re-export + _control_instance_to_response, # noqa: F401 — re-export + _finding_to_response, # noqa: F401 — re-export +) +from compliance.services.vendor_compliance_extra_service import ( + ControlInstanceService, + ControlsLibraryService, + FindingService, +) logger = logging.getLogger(__name__) router = APIRouter(prefix="/vendor-compliance", tags=["vendor-compliance"]) -# Default tenant UUID — "default" string no longer accepted -DEFAULT_TENANT_ID = "9282a473-5c95-4b3a-bf78-0ecc0ec71d3e" -# ============================================================================= -# Helpers -# ============================================================================= +# --------------------------------------------------------------------------- +# Service factories +# --------------------------------------------------------------------------- -def _now_iso() -> str: - return datetime.now(timezone.utc).isoformat() + "Z" +def _vendor_svc(db: Session = Depends(get_db)) -> VendorService: + return VendorService(db) -def _ok(data, status_code: int = 200): - """Wrap response in {success, data, timestamp} envelope.""" - return {"success": True, "data": data, "timestamp": _now_iso()} +def _contract_svc(db: Session = Depends(get_db)) -> ContractService: + return ContractService(db) -def _parse_json(val, default=None): - """Parse a JSONB/TEXT field → Python object.""" - if val is None: - return default if default is not None else None - if isinstance(val, (dict, list)): - return val - if isinstance(val, str): - try: - return json.loads(val) - except Exception: - return default if default is not None else val - return val +def _finding_svc(db: Session = Depends(get_db)) -> FindingService: + return FindingService(db) -def _ts(val): - """Timestamp → ISO string or None.""" - if not val: - return None - if isinstance(val, str): - return val - return val.isoformat() +def _ci_svc(db: Session = Depends(get_db)) -> ControlInstanceService: + return ControlInstanceService(db) -def _get(row, key, default=None): - """Safe row access.""" - try: - v = row[key] - return default if v is None and default is not None else v - except (KeyError, IndexError): - return default +def _ctrl_svc(db: Session = Depends(get_db)) -> ControlsLibraryService: + return ControlsLibraryService(db) -# camelCase ↔ snake_case conversion maps -_VENDOR_CAMEL_TO_SNAKE = { - # Vendor fields - "legalForm": "legal_form", - "serviceDescription": "service_description", - "serviceCategory": "service_category", - "dataAccessLevel": "data_access_level", - "processingLocations": "processing_locations", - "transferMechanisms": "transfer_mechanisms", - "primaryContact": "primary_contact", - "dpoContact": "dpo_contact", - "securityContact": "security_contact", - "contractTypes": "contract_types", - "inherentRiskScore": "inherent_risk_score", - "residualRiskScore": "residual_risk_score", - "manualRiskAdjustment": "manual_risk_adjustment", - "riskJustification": "risk_justification", - "reviewFrequency": "review_frequency", - "lastReviewDate": "last_review_date", - "nextReviewDate": "next_review_date", - "processingActivityIds": "processing_activity_ids", - "contactName": "contact_name", - "contactEmail": "contact_email", - "contactPhone": "contact_phone", - "contactDepartment": "contact_department", - # Common / cross-entity fields - "tenantId": "tenant_id", - "createdAt": "created_at", - "updatedAt": "updated_at", - "createdBy": "created_by", - "vendorId": "vendor_id", - "contractId": "contract_id", - "controlId": "control_id", - "controlDomain": "control_domain", - "evidenceIds": "evidence_ids", - "lastAssessedAt": "last_assessed_at", - "lastAssessedBy": "last_assessed_by", - "nextAssessmentDate": "next_assessment_date", - # Contract fields - "fileName": "file_name", - "originalName": "original_name", - "mimeType": "mime_type", - "fileSize": "file_size", - "storagePath": "storage_path", - "documentType": "document_type", - "previousVersionId": "previous_version_id", - "effectiveDate": "effective_date", - "expirationDate": "expiration_date", - "autoRenewal": "auto_renewal", - "renewalNoticePeriod": "renewal_notice_period", - "terminationNoticePeriod": "termination_notice_period", - "reviewStatus": "review_status", - "reviewCompletedAt": "review_completed_at", - "complianceScore": "compliance_score", - "extractedText": "extracted_text", - "pageCount": "page_count", - # Finding fields - "findingType": "finding_type", - "dueDate": "due_date", - "resolvedAt": "resolved_at", - "resolvedBy": "resolved_by", -} - -_VENDOR_SNAKE_TO_CAMEL = {v: k for k, v in _VENDOR_CAMEL_TO_SNAKE.items()} - - -def _to_snake(data: dict) -> dict: - """Convert camelCase keys in data to snake_case for DB storage.""" - result = {} - for k, v in data.items(): - snake = _VENDOR_CAMEL_TO_SNAKE.get(k, k) - result[snake] = v - return result - - -def _to_camel(data: dict) -> dict: - """Convert snake_case keys to camelCase for frontend.""" - result = {} - for k, v in data.items(): - camel = _VENDOR_SNAKE_TO_CAMEL.get(k, k) - result[camel] = v - return result - - -# ============================================================================= -# Row → Response converters -# ============================================================================= - -def _vendor_to_response(row) -> dict: - return _to_camel({ - "id": str(row["id"]), - "tenant_id": row["tenant_id"], - "name": row["name"], - "legal_form": _get(row, "legal_form", ""), - "country": _get(row, "country", ""), - "address": _get(row, "address", ""), - "website": _get(row, "website", ""), - "role": _get(row, "role", "PROCESSOR"), - "service_description": _get(row, "service_description", ""), - "service_category": _get(row, "service_category", "OTHER"), - "data_access_level": _get(row, "data_access_level", "NONE"), - "processing_locations": _parse_json(_get(row, "processing_locations"), []), - "transfer_mechanisms": _parse_json(_get(row, "transfer_mechanisms"), []), - "certifications": _parse_json(_get(row, "certifications"), []), - "primary_contact": _parse_json(_get(row, "primary_contact"), {}), - "dpo_contact": _parse_json(_get(row, "dpo_contact"), {}), - "security_contact": _parse_json(_get(row, "security_contact"), {}), - "contract_types": _parse_json(_get(row, "contract_types"), []), - "inherent_risk_score": _get(row, "inherent_risk_score", 50), - "residual_risk_score": _get(row, "residual_risk_score", 50), - "manual_risk_adjustment": _get(row, "manual_risk_adjustment"), - "risk_justification": _get(row, "risk_justification", ""), - "review_frequency": _get(row, "review_frequency", "ANNUAL"), - "last_review_date": _ts(_get(row, "last_review_date")), - "next_review_date": _ts(_get(row, "next_review_date")), - "status": _get(row, "status", "ACTIVE"), - "processing_activity_ids": _parse_json(_get(row, "processing_activity_ids"), []), - "notes": _get(row, "notes", ""), - "contact_name": _get(row, "contact_name", ""), - "contact_email": _get(row, "contact_email", ""), - "contact_phone": _get(row, "contact_phone", ""), - "contact_department": _get(row, "contact_department", ""), - "created_at": _ts(row["created_at"]), - "updated_at": _ts(row["updated_at"]), - "created_by": _get(row, "created_by", "system"), - }) - - -def _contract_to_response(row) -> dict: - return _to_camel({ - "id": str(row["id"]), - "tenant_id": row["tenant_id"], - "vendor_id": str(row["vendor_id"]), - "file_name": _get(row, "file_name", ""), - "original_name": _get(row, "original_name", ""), - "mime_type": _get(row, "mime_type", ""), - "file_size": _get(row, "file_size", 0), - "storage_path": _get(row, "storage_path", ""), - "document_type": _get(row, "document_type", "AVV"), - "version": _get(row, "version", 1), - "previous_version_id": str(_get(row, "previous_version_id")) if _get(row, "previous_version_id") else None, - "parties": _parse_json(_get(row, "parties"), []), - "effective_date": _ts(_get(row, "effective_date")), - "expiration_date": _ts(_get(row, "expiration_date")), - "auto_renewal": _get(row, "auto_renewal", False), - "renewal_notice_period": _get(row, "renewal_notice_period", ""), - "termination_notice_period": _get(row, "termination_notice_period", ""), - "review_status": _get(row, "review_status", "PENDING"), - "review_completed_at": _ts(_get(row, "review_completed_at")), - "compliance_score": _get(row, "compliance_score"), - "status": _get(row, "status", "DRAFT"), - "extracted_text": _get(row, "extracted_text", ""), - "page_count": _get(row, "page_count", 0), - "created_at": _ts(row["created_at"]), - "updated_at": _ts(row["updated_at"]), - "created_by": _get(row, "created_by", "system"), - }) - - -def _finding_to_response(row) -> dict: - return _to_camel({ - "id": str(row["id"]), - "tenant_id": row["tenant_id"], - "vendor_id": str(row["vendor_id"]), - "contract_id": str(_get(row, "contract_id")) if _get(row, "contract_id") else None, - "finding_type": _get(row, "finding_type", "UNKNOWN"), - "category": _get(row, "category", ""), - "severity": _get(row, "severity", "MEDIUM"), - "title": _get(row, "title", ""), - "description": _get(row, "description", ""), - "recommendation": _get(row, "recommendation", ""), - "citations": _parse_json(_get(row, "citations"), []), - "status": _get(row, "status", "OPEN"), - "assignee": _get(row, "assignee", ""), - "due_date": _ts(_get(row, "due_date")), - "resolution": _get(row, "resolution", ""), - "resolved_at": _ts(_get(row, "resolved_at")), - "resolved_by": _get(row, "resolved_by", ""), - "created_at": _ts(row["created_at"]), - "updated_at": _ts(row["updated_at"]), - "created_by": _get(row, "created_by", "system"), - }) - - -def _control_instance_to_response(row) -> dict: - return _to_camel({ - "id": str(row["id"]), - "tenant_id": row["tenant_id"], - "vendor_id": str(row["vendor_id"]), - "control_id": _get(row, "control_id", ""), - "control_domain": _get(row, "control_domain", ""), - "status": _get(row, "status", "PLANNED"), - "evidence_ids": _parse_json(_get(row, "evidence_ids"), []), - "notes": _get(row, "notes", ""), - "last_assessed_at": _ts(_get(row, "last_assessed_at")), - "last_assessed_by": _get(row, "last_assessed_by", ""), - "next_assessment_date": _ts(_get(row, "next_assessment_date")), - "created_at": _ts(row["created_at"]), - "updated_at": _ts(row["updated_at"]), - "created_by": _get(row, "created_by", "system"), - }) - - -# ============================================================================= +# ============================================================================ # Vendors -# ============================================================================= +# ============================================================================ + @router.get("/vendors/stats") def get_vendor_stats( tenant_id: Optional[str] = Query(None), - db: Session = Depends(get_db), + svc: VendorService = Depends(_vendor_svc), ): - tid = tenant_id or DEFAULT_TENANT_ID - result = db.execute(text(""" - SELECT - COUNT(*) AS total, - COUNT(*) FILTER (WHERE status = 'ACTIVE') AS active, - COUNT(*) FILTER (WHERE status = 'INACTIVE') AS inactive, - COUNT(*) FILTER (WHERE status = 'PENDING_REVIEW') AS pending_review, - COUNT(*) FILTER (WHERE status = 'TERMINATED') AS terminated, - COALESCE(AVG(inherent_risk_score), 0) AS avg_inherent_risk, - COALESCE(AVG(residual_risk_score), 0) AS avg_residual_risk, - COUNT(*) FILTER (WHERE inherent_risk_score >= 75) AS high_risk_count - FROM vendor_vendors - WHERE tenant_id = :tid - """), {"tid": tid}) - row = result.fetchone() - if row is None: - stats = { - "total": 0, "active": 0, "inactive": 0, - "pending_review": 0, "terminated": 0, - "avg_inherent_risk": 0, "avg_residual_risk": 0, - "high_risk_count": 0, - } - else: - stats = { - "total": row["total"] or 0, - "active": row["active"] or 0, - "inactive": row["inactive"] or 0, - "pendingReview": row["pending_review"] or 0, - "terminated": row["terminated"] or 0, - "avgInherentRisk": round(float(row["avg_inherent_risk"] or 0), 1), - "avgResidualRisk": round(float(row["avg_residual_risk"] or 0), 1), - "highRiskCount": row["high_risk_count"] or 0, - } - return _ok(stats) + with translate_domain_errors(): + return svc.get_stats(tenant_id) @router.get("/vendors") @@ -365,212 +132,63 @@ def list_vendors( search: Optional[str] = Query(None), skip: int = Query(0, ge=0), limit: int = Query(100, ge=1, le=500), - db: Session = Depends(get_db), + svc: VendorService = Depends(_vendor_svc), ): - tid = tenant_id or DEFAULT_TENANT_ID - where = ["tenant_id = :tid"] - params: dict = {"tid": tid} - - if status: - where.append("status = :status") - params["status"] = status - if risk_level: - if risk_level == "HIGH": - where.append("inherent_risk_score >= 75") - elif risk_level == "MEDIUM": - where.append("inherent_risk_score >= 40 AND inherent_risk_score < 75") - elif risk_level == "LOW": - where.append("inherent_risk_score < 40") - if search: - where.append("(name ILIKE :search OR service_description ILIKE :search)") - params["search"] = f"%{search}%" - - where_clause = " AND ".join(where) - params["lim"] = limit - params["off"] = skip - - rows = db.execute(text(f""" - SELECT * FROM vendor_vendors - WHERE {where_clause} - ORDER BY created_at DESC - LIMIT :lim OFFSET :off - """), params).fetchall() - - count_row = db.execute(text(f""" - SELECT COUNT(*) AS cnt FROM vendor_vendors WHERE {where_clause} - """), {k: v for k, v in params.items() if k not in ("lim", "off")}).fetchone() - total = count_row["cnt"] if count_row else 0 - - return _ok({"items": [_vendor_to_response(r) for r in rows], "total": total}) + with translate_domain_errors(): + return svc.list_vendors(tenant_id, status, risk_level, search, skip, limit) @router.get("/vendors/{vendor_id}") -def get_vendor(vendor_id: str, db: Session = Depends(get_db)): - row = db.execute(text("SELECT * FROM vendor_vendors WHERE id = :id"), - {"id": vendor_id}).fetchone() - if not row: - raise HTTPException(404, "Vendor not found") - return _ok(_vendor_to_response(row)) +def get_vendor( + vendor_id: str, + svc: VendorService = Depends(_vendor_svc), +): + with translate_domain_errors(): + return svc.get_vendor(vendor_id) @router.post("/vendors", status_code=201) -def create_vendor(body: dict = {}, db: Session = Depends(get_db)): - data = _to_snake(body) - vid = str(uuid.uuid4()) - tid = data.get("tenant_id", DEFAULT_TENANT_ID) - now = datetime.now(timezone.utc).isoformat() - - db.execute(text(""" - INSERT INTO vendor_vendors ( - id, tenant_id, name, legal_form, country, address, website, - role, service_description, service_category, data_access_level, - processing_locations, transfer_mechanisms, certifications, - primary_contact, dpo_contact, security_contact, - contract_types, inherent_risk_score, residual_risk_score, - manual_risk_adjustment, risk_justification, - review_frequency, last_review_date, next_review_date, - status, processing_activity_ids, notes, - contact_name, contact_email, contact_phone, contact_department, - created_at, updated_at, created_by - ) VALUES ( - :id, :tenant_id, :name, :legal_form, :country, :address, :website, - :role, :service_description, :service_category, :data_access_level, - CAST(:processing_locations AS jsonb), CAST(:transfer_mechanisms AS jsonb), - CAST(:certifications AS jsonb), - CAST(:primary_contact AS jsonb), CAST(:dpo_contact AS jsonb), - CAST(:security_contact AS jsonb), - CAST(:contract_types AS jsonb), :inherent_risk_score, :residual_risk_score, - :manual_risk_adjustment, :risk_justification, - :review_frequency, :last_review_date, :next_review_date, - :status, CAST(:processing_activity_ids AS jsonb), :notes, - :contact_name, :contact_email, :contact_phone, :contact_department, - :created_at, :updated_at, :created_by - ) - """), { - "id": vid, - "tenant_id": tid, - "name": data.get("name", ""), - "legal_form": data.get("legal_form", ""), - "country": data.get("country", ""), - "address": data.get("address", ""), - "website": data.get("website", ""), - "role": data.get("role", "PROCESSOR"), - "service_description": data.get("service_description", ""), - "service_category": data.get("service_category", "OTHER"), - "data_access_level": data.get("data_access_level", "NONE"), - "processing_locations": json.dumps(data.get("processing_locations", [])), - "transfer_mechanisms": json.dumps(data.get("transfer_mechanisms", [])), - "certifications": json.dumps(data.get("certifications", [])), - "primary_contact": json.dumps(data.get("primary_contact", {})), - "dpo_contact": json.dumps(data.get("dpo_contact", {})), - "security_contact": json.dumps(data.get("security_contact", {})), - "contract_types": json.dumps(data.get("contract_types", [])), - "inherent_risk_score": data.get("inherent_risk_score", 50), - "residual_risk_score": data.get("residual_risk_score", 50), - "manual_risk_adjustment": data.get("manual_risk_adjustment"), - "risk_justification": data.get("risk_justification", ""), - "review_frequency": data.get("review_frequency", "ANNUAL"), - "last_review_date": data.get("last_review_date"), - "next_review_date": data.get("next_review_date"), - "status": data.get("status", "ACTIVE"), - "processing_activity_ids": json.dumps(data.get("processing_activity_ids", [])), - "notes": data.get("notes", ""), - "contact_name": data.get("contact_name", ""), - "contact_email": data.get("contact_email", ""), - "contact_phone": data.get("contact_phone", ""), - "contact_department": data.get("contact_department", ""), - "created_at": now, - "updated_at": now, - "created_by": data.get("created_by", "system"), - }) - db.commit() - - row = db.execute(text("SELECT * FROM vendor_vendors WHERE id = :id"), - {"id": vid}).fetchone() - return _ok(_vendor_to_response(row)) +def create_vendor( + body: dict = {}, + svc: VendorService = Depends(_vendor_svc), +): + with translate_domain_errors(): + return svc.create_vendor(body) @router.put("/vendors/{vendor_id}") -def update_vendor(vendor_id: str, body: dict = {}, db: Session = Depends(get_db)): - existing = db.execute(text("SELECT id FROM vendor_vendors WHERE id = :id"), - {"id": vendor_id}).fetchone() - if not existing: - raise HTTPException(404, "Vendor not found") - - data = _to_snake(body) - now = datetime.now(timezone.utc).isoformat() - - # Build dynamic SET clause - allowed = [ - "name", "legal_form", "country", "address", "website", - "role", "service_description", "service_category", "data_access_level", - "inherent_risk_score", "residual_risk_score", - "manual_risk_adjustment", "risk_justification", - "review_frequency", "last_review_date", "next_review_date", - "status", "notes", - "contact_name", "contact_email", "contact_phone", "contact_department", - ] - jsonb_fields = [ - "processing_locations", "transfer_mechanisms", "certifications", - "primary_contact", "dpo_contact", "security_contact", - "contract_types", "processing_activity_ids", - ] - - sets = ["updated_at = :updated_at"] - params: dict = {"id": vendor_id, "updated_at": now} - - for col in allowed: - if col in data: - sets.append(f"{col} = :{col}") - params[col] = data[col] - - for col in jsonb_fields: - if col in data: - sets.append(f"{col} = CAST(:{col} AS jsonb)") - params[col] = json.dumps(data[col]) - - db.execute(text(f"UPDATE vendor_vendors SET {', '.join(sets)} WHERE id = :id"), params) - db.commit() - - row = db.execute(text("SELECT * FROM vendor_vendors WHERE id = :id"), - {"id": vendor_id}).fetchone() - return _ok(_vendor_to_response(row)) +def update_vendor( + vendor_id: str, + body: dict = {}, + svc: VendorService = Depends(_vendor_svc), +): + with translate_domain_errors(): + return svc.update_vendor(vendor_id, body) @router.delete("/vendors/{vendor_id}") -def delete_vendor(vendor_id: str, db: Session = Depends(get_db)): - result = db.execute(text("DELETE FROM vendor_vendors WHERE id = :id"), - {"id": vendor_id}) - db.commit() - if result.rowcount == 0: - raise HTTPException(404, "Vendor not found") - return _ok({"deleted": True}) +def delete_vendor( + vendor_id: str, + svc: VendorService = Depends(_vendor_svc), +): + with translate_domain_errors(): + return svc.delete_vendor(vendor_id) @router.patch("/vendors/{vendor_id}/status") -def patch_vendor_status(vendor_id: str, body: dict = {}, db: Session = Depends(get_db)): - new_status = body.get("status") - if not new_status: - raise HTTPException(400, "status is required") - valid = {"ACTIVE", "INACTIVE", "PENDING_REVIEW", "TERMINATED"} - if new_status not in valid: - raise HTTPException(400, f"Invalid status. Must be one of: {', '.join(sorted(valid))}") - - result = db.execute(text(""" - UPDATE vendor_vendors SET status = :status, updated_at = :now WHERE id = :id - """), {"id": vendor_id, "status": new_status, "now": datetime.now(timezone.utc).isoformat()}) - db.commit() - if result.rowcount == 0: - raise HTTPException(404, "Vendor not found") - - row = db.execute(text("SELECT * FROM vendor_vendors WHERE id = :id"), - {"id": vendor_id}).fetchone() - return _ok(_vendor_to_response(row)) +def patch_vendor_status( + vendor_id: str, + body: dict = {}, + svc: VendorService = Depends(_vendor_svc), +): + with translate_domain_errors(): + return svc.patch_status(vendor_id, body) -# ============================================================================= +# ============================================================================ # Contracts -# ============================================================================= +# ============================================================================ + @router.get("/contracts") def list_contracts( @@ -579,155 +197,53 @@ def list_contracts( status: Optional[str] = Query(None), skip: int = Query(0, ge=0), limit: int = Query(100, ge=1, le=500), - db: Session = Depends(get_db), + svc: ContractService = Depends(_contract_svc), ): - tid = tenant_id or DEFAULT_TENANT_ID - where = ["tenant_id = :tid"] - params: dict = {"tid": tid} - - if vendor_id: - where.append("vendor_id = :vendor_id") - params["vendor_id"] = vendor_id - if status: - where.append("status = :status") - params["status"] = status - - where_clause = " AND ".join(where) - params["lim"] = limit - params["off"] = skip - - rows = db.execute(text(f""" - SELECT * FROM vendor_contracts - WHERE {where_clause} - ORDER BY created_at DESC - LIMIT :lim OFFSET :off - """), params).fetchall() - - return _ok([_contract_to_response(r) for r in rows]) + with translate_domain_errors(): + return svc.list_contracts(tenant_id, vendor_id, status, skip, limit) @router.get("/contracts/{contract_id}") -def get_contract(contract_id: str, db: Session = Depends(get_db)): - row = db.execute(text("SELECT * FROM vendor_contracts WHERE id = :id"), - {"id": contract_id}).fetchone() - if not row: - raise HTTPException(404, "Contract not found") - return _ok(_contract_to_response(row)) +def get_contract( + contract_id: str, + svc: ContractService = Depends(_contract_svc), +): + with translate_domain_errors(): + return svc.get_contract(contract_id) @router.post("/contracts", status_code=201) -def create_contract(body: dict = {}, db: Session = Depends(get_db)): - data = _to_snake(body) - cid = str(uuid.uuid4()) - tid = data.get("tenant_id", DEFAULT_TENANT_ID) - now = datetime.now(timezone.utc).isoformat() - - db.execute(text(""" - INSERT INTO vendor_contracts ( - id, tenant_id, vendor_id, file_name, original_name, mime_type, - file_size, storage_path, document_type, version, previous_version_id, - parties, effective_date, expiration_date, - auto_renewal, renewal_notice_period, termination_notice_period, - review_status, status, compliance_score, - extracted_text, page_count, - created_at, updated_at, created_by - ) VALUES ( - :id, :tenant_id, :vendor_id, :file_name, :original_name, :mime_type, - :file_size, :storage_path, :document_type, :version, :previous_version_id, - CAST(:parties AS jsonb), :effective_date, :expiration_date, - :auto_renewal, :renewal_notice_period, :termination_notice_period, - :review_status, :status, :compliance_score, - :extracted_text, :page_count, - :created_at, :updated_at, :created_by - ) - """), { - "id": cid, - "tenant_id": tid, - "vendor_id": data.get("vendor_id", ""), - "file_name": data.get("file_name", ""), - "original_name": data.get("original_name", ""), - "mime_type": data.get("mime_type", ""), - "file_size": data.get("file_size", 0), - "storage_path": data.get("storage_path", ""), - "document_type": data.get("document_type", "AVV"), - "version": data.get("version", 1), - "previous_version_id": data.get("previous_version_id"), - "parties": json.dumps(data.get("parties", [])), - "effective_date": data.get("effective_date"), - "expiration_date": data.get("expiration_date"), - "auto_renewal": data.get("auto_renewal", False), - "renewal_notice_period": data.get("renewal_notice_period", ""), - "termination_notice_period": data.get("termination_notice_period", ""), - "review_status": data.get("review_status", "PENDING"), - "status": data.get("status", "DRAFT"), - "compliance_score": data.get("compliance_score"), - "extracted_text": data.get("extracted_text", ""), - "page_count": data.get("page_count", 0), - "created_at": now, - "updated_at": now, - "created_by": data.get("created_by", "system"), - }) - db.commit() - - row = db.execute(text("SELECT * FROM vendor_contracts WHERE id = :id"), - {"id": cid}).fetchone() - return _ok(_contract_to_response(row)) +def create_contract( + body: dict = {}, + svc: ContractService = Depends(_contract_svc), +): + with translate_domain_errors(): + return svc.create_contract(body) @router.put("/contracts/{contract_id}") -def update_contract(contract_id: str, body: dict = {}, db: Session = Depends(get_db)): - existing = db.execute(text("SELECT id FROM vendor_contracts WHERE id = :id"), - {"id": contract_id}).fetchone() - if not existing: - raise HTTPException(404, "Contract not found") - - data = _to_snake(body) - now = datetime.now(timezone.utc).isoformat() - - allowed = [ - "vendor_id", "file_name", "original_name", "mime_type", "file_size", - "storage_path", "document_type", "version", "previous_version_id", - "effective_date", "expiration_date", "auto_renewal", - "renewal_notice_period", "termination_notice_period", - "review_status", "review_completed_at", "compliance_score", - "status", "extracted_text", "page_count", - ] - jsonb_fields = ["parties"] - - sets = ["updated_at = :updated_at"] - params: dict = {"id": contract_id, "updated_at": now} - - for col in allowed: - if col in data: - sets.append(f"{col} = :{col}") - params[col] = data[col] - - for col in jsonb_fields: - if col in data: - sets.append(f"{col} = CAST(:{col} AS jsonb)") - params[col] = json.dumps(data[col]) - - db.execute(text(f"UPDATE vendor_contracts SET {', '.join(sets)} WHERE id = :id"), params) - db.commit() - - row = db.execute(text("SELECT * FROM vendor_contracts WHERE id = :id"), - {"id": contract_id}).fetchone() - return _ok(_contract_to_response(row)) +def update_contract( + contract_id: str, + body: dict = {}, + svc: ContractService = Depends(_contract_svc), +): + with translate_domain_errors(): + return svc.update_contract(contract_id, body) @router.delete("/contracts/{contract_id}") -def delete_contract(contract_id: str, db: Session = Depends(get_db)): - result = db.execute(text("DELETE FROM vendor_contracts WHERE id = :id"), - {"id": contract_id}) - db.commit() - if result.rowcount == 0: - raise HTTPException(404, "Contract not found") - return _ok({"deleted": True}) +def delete_contract( + contract_id: str, + svc: ContractService = Depends(_contract_svc), +): + with translate_domain_errors(): + return svc.delete_contract(contract_id) -# ============================================================================= +# ============================================================================ # Findings -# ============================================================================= +# ============================================================================ + @router.get("/findings") def list_findings( @@ -737,144 +253,53 @@ def list_findings( status: Optional[str] = Query(None), skip: int = Query(0, ge=0), limit: int = Query(100, ge=1, le=500), - db: Session = Depends(get_db), + svc: FindingService = Depends(_finding_svc), ): - tid = tenant_id or DEFAULT_TENANT_ID - where = ["tenant_id = :tid"] - params: dict = {"tid": tid} - - if vendor_id: - where.append("vendor_id = :vendor_id") - params["vendor_id"] = vendor_id - if severity: - where.append("severity = :severity") - params["severity"] = severity - if status: - where.append("status = :status") - params["status"] = status - - where_clause = " AND ".join(where) - params["lim"] = limit - params["off"] = skip - - rows = db.execute(text(f""" - SELECT * FROM vendor_findings - WHERE {where_clause} - ORDER BY created_at DESC - LIMIT :lim OFFSET :off - """), params).fetchall() - - return _ok([_finding_to_response(r) for r in rows]) + with translate_domain_errors(): + return svc.list_findings(tenant_id, vendor_id, severity, status, skip, limit) @router.get("/findings/{finding_id}") -def get_finding(finding_id: str, db: Session = Depends(get_db)): - row = db.execute(text("SELECT * FROM vendor_findings WHERE id = :id"), - {"id": finding_id}).fetchone() - if not row: - raise HTTPException(404, "Finding not found") - return _ok(_finding_to_response(row)) +def get_finding( + finding_id: str, + svc: FindingService = Depends(_finding_svc), +): + with translate_domain_errors(): + return svc.get_finding(finding_id) @router.post("/findings", status_code=201) -def create_finding(body: dict = {}, db: Session = Depends(get_db)): - data = _to_snake(body) - fid = str(uuid.uuid4()) - tid = data.get("tenant_id", DEFAULT_TENANT_ID) - now = datetime.now(timezone.utc).isoformat() - - db.execute(text(""" - INSERT INTO vendor_findings ( - id, tenant_id, vendor_id, contract_id, - finding_type, category, severity, - title, description, recommendation, - citations, status, assignee, due_date, - created_at, updated_at, created_by - ) VALUES ( - :id, :tenant_id, :vendor_id, :contract_id, - :finding_type, :category, :severity, - :title, :description, :recommendation, - CAST(:citations AS jsonb), :status, :assignee, :due_date, - :created_at, :updated_at, :created_by - ) - """), { - "id": fid, - "tenant_id": tid, - "vendor_id": data.get("vendor_id", ""), - "contract_id": data.get("contract_id"), - "finding_type": data.get("finding_type", "UNKNOWN"), - "category": data.get("category", ""), - "severity": data.get("severity", "MEDIUM"), - "title": data.get("title", ""), - "description": data.get("description", ""), - "recommendation": data.get("recommendation", ""), - "citations": json.dumps(data.get("citations", [])), - "status": data.get("status", "OPEN"), - "assignee": data.get("assignee", ""), - "due_date": data.get("due_date"), - "created_at": now, - "updated_at": now, - "created_by": data.get("created_by", "system"), - }) - db.commit() - - row = db.execute(text("SELECT * FROM vendor_findings WHERE id = :id"), - {"id": fid}).fetchone() - return _ok(_finding_to_response(row)) +def create_finding( + body: dict = {}, + svc: FindingService = Depends(_finding_svc), +): + with translate_domain_errors(): + return svc.create_finding(body) @router.put("/findings/{finding_id}") -def update_finding(finding_id: str, body: dict = {}, db: Session = Depends(get_db)): - existing = db.execute(text("SELECT id FROM vendor_findings WHERE id = :id"), - {"id": finding_id}).fetchone() - if not existing: - raise HTTPException(404, "Finding not found") - - data = _to_snake(body) - now = datetime.now(timezone.utc).isoformat() - - allowed = [ - "vendor_id", "contract_id", "finding_type", "category", "severity", - "title", "description", "recommendation", - "status", "assignee", "due_date", - "resolution", "resolved_at", "resolved_by", - ] - jsonb_fields = ["citations"] - - sets = ["updated_at = :updated_at"] - params: dict = {"id": finding_id, "updated_at": now} - - for col in allowed: - if col in data: - sets.append(f"{col} = :{col}") - params[col] = data[col] - - for col in jsonb_fields: - if col in data: - sets.append(f"{col} = CAST(:{col} AS jsonb)") - params[col] = json.dumps(data[col]) - - db.execute(text(f"UPDATE vendor_findings SET {', '.join(sets)} WHERE id = :id"), params) - db.commit() - - row = db.execute(text("SELECT * FROM vendor_findings WHERE id = :id"), - {"id": finding_id}).fetchone() - return _ok(_finding_to_response(row)) +def update_finding( + finding_id: str, + body: dict = {}, + svc: FindingService = Depends(_finding_svc), +): + with translate_domain_errors(): + return svc.update_finding(finding_id, body) @router.delete("/findings/{finding_id}") -def delete_finding(finding_id: str, db: Session = Depends(get_db)): - result = db.execute(text("DELETE FROM vendor_findings WHERE id = :id"), - {"id": finding_id}) - db.commit() - if result.rowcount == 0: - raise HTTPException(404, "Finding not found") - return _ok({"deleted": True}) +def delete_finding( + finding_id: str, + svc: FindingService = Depends(_finding_svc), +): + with translate_domain_errors(): + return svc.delete_finding(finding_id) -# ============================================================================= +# ============================================================================ # Control Instances -# ============================================================================= +# ============================================================================ + @router.get("/control-instances") def list_control_instances( @@ -882,215 +307,86 @@ def list_control_instances( vendor_id: Optional[str] = Query(None), skip: int = Query(0, ge=0), limit: int = Query(100, ge=1, le=500), - db: Session = Depends(get_db), + svc: ControlInstanceService = Depends(_ci_svc), ): - tid = tenant_id or DEFAULT_TENANT_ID - where = ["tenant_id = :tid"] - params: dict = {"tid": tid} - - if vendor_id: - where.append("vendor_id = :vendor_id") - params["vendor_id"] = vendor_id - - where_clause = " AND ".join(where) - params["lim"] = limit - params["off"] = skip - - rows = db.execute(text(f""" - SELECT * FROM vendor_control_instances - WHERE {where_clause} - ORDER BY created_at DESC - LIMIT :lim OFFSET :off - """), params).fetchall() - - return _ok([_control_instance_to_response(r) for r in rows]) + with translate_domain_errors(): + return svc.list_instances(tenant_id, vendor_id, skip, limit) @router.get("/control-instances/{instance_id}") -def get_control_instance(instance_id: str, db: Session = Depends(get_db)): - row = db.execute(text("SELECT * FROM vendor_control_instances WHERE id = :id"), - {"id": instance_id}).fetchone() - if not row: - raise HTTPException(404, "Control instance not found") - return _ok(_control_instance_to_response(row)) +def get_control_instance( + instance_id: str, + svc: ControlInstanceService = Depends(_ci_svc), +): + with translate_domain_errors(): + return svc.get_instance(instance_id) @router.post("/control-instances", status_code=201) -def create_control_instance(body: dict = {}, db: Session = Depends(get_db)): - data = _to_snake(body) - ciid = str(uuid.uuid4()) - tid = data.get("tenant_id", DEFAULT_TENANT_ID) - now = datetime.now(timezone.utc).isoformat() - - db.execute(text(""" - INSERT INTO vendor_control_instances ( - id, tenant_id, vendor_id, control_id, control_domain, - status, evidence_ids, notes, - last_assessed_at, last_assessed_by, next_assessment_date, - created_at, updated_at, created_by - ) VALUES ( - :id, :tenant_id, :vendor_id, :control_id, :control_domain, - :status, CAST(:evidence_ids AS jsonb), :notes, - :last_assessed_at, :last_assessed_by, :next_assessment_date, - :created_at, :updated_at, :created_by - ) - """), { - "id": ciid, - "tenant_id": tid, - "vendor_id": data.get("vendor_id", ""), - "control_id": data.get("control_id", ""), - "control_domain": data.get("control_domain", ""), - "status": data.get("status", "PLANNED"), - "evidence_ids": json.dumps(data.get("evidence_ids", [])), - "notes": data.get("notes", ""), - "last_assessed_at": data.get("last_assessed_at"), - "last_assessed_by": data.get("last_assessed_by", ""), - "next_assessment_date": data.get("next_assessment_date"), - "created_at": now, - "updated_at": now, - "created_by": data.get("created_by", "system"), - }) - db.commit() - - row = db.execute(text("SELECT * FROM vendor_control_instances WHERE id = :id"), - {"id": ciid}).fetchone() - return _ok(_control_instance_to_response(row)) +def create_control_instance( + body: dict = {}, + svc: ControlInstanceService = Depends(_ci_svc), +): + with translate_domain_errors(): + return svc.create_instance(body) @router.put("/control-instances/{instance_id}") -def update_control_instance(instance_id: str, body: dict = {}, db: Session = Depends(get_db)): - existing = db.execute(text("SELECT id FROM vendor_control_instances WHERE id = :id"), - {"id": instance_id}).fetchone() - if not existing: - raise HTTPException(404, "Control instance not found") - - data = _to_snake(body) - now = datetime.now(timezone.utc).isoformat() - - allowed = [ - "vendor_id", "control_id", "control_domain", - "status", "notes", - "last_assessed_at", "last_assessed_by", "next_assessment_date", - ] - jsonb_fields = ["evidence_ids"] - - sets = ["updated_at = :updated_at"] - params: dict = {"id": instance_id, "updated_at": now} - - for col in allowed: - if col in data: - sets.append(f"{col} = :{col}") - params[col] = data[col] - - for col in jsonb_fields: - if col in data: - sets.append(f"{col} = CAST(:{col} AS jsonb)") - params[col] = json.dumps(data[col]) - - db.execute(text(f"UPDATE vendor_control_instances SET {', '.join(sets)} WHERE id = :id"), params) - db.commit() - - row = db.execute(text("SELECT * FROM vendor_control_instances WHERE id = :id"), - {"id": instance_id}).fetchone() - return _ok(_control_instance_to_response(row)) +def update_control_instance( + instance_id: str, + body: dict = {}, + svc: ControlInstanceService = Depends(_ci_svc), +): + with translate_domain_errors(): + return svc.update_instance(instance_id, body) @router.delete("/control-instances/{instance_id}") -def delete_control_instance(instance_id: str, db: Session = Depends(get_db)): - result = db.execute(text("DELETE FROM vendor_control_instances WHERE id = :id"), - {"id": instance_id}) - db.commit() - if result.rowcount == 0: - raise HTTPException(404, "Control instance not found") - return _ok({"deleted": True}) +def delete_control_instance( + instance_id: str, + svc: ControlInstanceService = Depends(_ci_svc), +): + with translate_domain_errors(): + return svc.delete_instance(instance_id) -# ============================================================================= -# Controls Library (vendor_compliance_controls — lightweight catalog) -# ============================================================================= +# ============================================================================ +# Controls Library +# ============================================================================ + @router.get("/controls") def list_controls( tenant_id: Optional[str] = Query(None), domain: Optional[str] = Query(None), - db: Session = Depends(get_db), + svc: ControlsLibraryService = Depends(_ctrl_svc), ): - tid = tenant_id or DEFAULT_TENANT_ID - where = ["tenant_id = :tid"] - params: dict = {"tid": tid} - - if domain: - where.append("domain = :domain") - params["domain"] = domain - - where_clause = " AND ".join(where) - - rows = db.execute(text(f""" - SELECT * FROM vendor_compliance_controls - WHERE {where_clause} - ORDER BY domain, control_code - """), params).fetchall() - - items = [] - for r in rows: - items.append({ - "id": str(r["id"]), - "tenantId": r["tenant_id"], - "domain": _get(r, "domain", ""), - "controlCode": _get(r, "control_code", ""), - "title": _get(r, "title", ""), - "description": _get(r, "description", ""), - "createdAt": _ts(r["created_at"]), - }) - - return _ok(items) + with translate_domain_errors(): + return svc.list_controls(tenant_id, domain) @router.post("/controls", status_code=201) -def create_control(body: dict = {}, db: Session = Depends(get_db)): - cid = str(uuid.uuid4()) - tid = body.get("tenantId", body.get("tenant_id", DEFAULT_TENANT_ID)) - now = datetime.now(timezone.utc).isoformat() - - db.execute(text(""" - INSERT INTO vendor_compliance_controls ( - id, tenant_id, domain, control_code, title, description, created_at - ) VALUES (:id, :tenant_id, :domain, :control_code, :title, :description, :created_at) - """), { - "id": cid, - "tenant_id": tid, - "domain": body.get("domain", ""), - "control_code": body.get("controlCode", body.get("control_code", "")), - "title": body.get("title", ""), - "description": body.get("description", ""), - "created_at": now, - }) - db.commit() - - return _ok({ - "id": cid, - "tenantId": tid, - "domain": body.get("domain", ""), - "controlCode": body.get("controlCode", body.get("control_code", "")), - "title": body.get("title", ""), - "description": body.get("description", ""), - "createdAt": now, - }) +def create_control( + body: dict = {}, + svc: ControlsLibraryService = Depends(_ctrl_svc), +): + with translate_domain_errors(): + return svc.create_control(body) @router.delete("/controls/{control_id}") -def delete_control(control_id: str, db: Session = Depends(get_db)): - result = db.execute(text("DELETE FROM vendor_compliance_controls WHERE id = :id"), - {"id": control_id}) - db.commit() - if result.rowcount == 0: - raise HTTPException(404, "Control not found") - return _ok({"deleted": True}) +def delete_control( + control_id: str, + svc: ControlsLibraryService = Depends(_ctrl_svc), +): + with translate_domain_errors(): + return svc.delete_control(control_id) -# ============================================================================= +# ============================================================================ # Export Stubs (501 Not Implemented) -# ============================================================================= +# ============================================================================ + @router.post("/export", status_code=501) def export_report(): diff --git a/backend-compliance/compliance/services/vendor_compliance_extra_service.py b/backend-compliance/compliance/services/vendor_compliance_extra_service.py new file mode 100644 index 0000000..660e47e --- /dev/null +++ b/backend-compliance/compliance/services/vendor_compliance_extra_service.py @@ -0,0 +1,408 @@ +# mypy: disable-error-code="arg-type,assignment,union-attr,no-any-return" +""" +Vendor compliance extra entities — Findings, Control Instances, and +Controls Library CRUD. + +Phase 1 Step 4: extracted from ``compliance.api.vendor_compliance_routes``. +Shares helpers with ``compliance.services.vendor_compliance_service`` and +row converters from ``compliance.services.vendor_compliance_sub_service``. +""" + +import json +import uuid +from datetime import datetime, timezone +from typing import Optional + +from sqlalchemy import text +from sqlalchemy.orm import Session + +from compliance.domain import NotFoundError +from compliance.services.vendor_compliance_service import ( + DEFAULT_TENANT_ID, + _get, + _ok, + _to_snake, + _ts, +) +from compliance.services.vendor_compliance_sub_service import ( + _control_instance_to_response, + _finding_to_response, +) + + +# ============================================================================ +# FindingService +# ============================================================================ + + +class FindingService: + """Vendor findings CRUD.""" + + def __init__(self, db: Session) -> None: + self._db = db + + def list_findings( + self, + tenant_id: Optional[str] = None, + vendor_id: Optional[str] = None, + severity: Optional[str] = None, + status: Optional[str] = None, + skip: int = 0, + limit: int = 100, + ) -> dict: + tid = tenant_id or DEFAULT_TENANT_ID + where = ["tenant_id = :tid"] + params: dict = {"tid": tid} + if vendor_id: + where.append("vendor_id = :vendor_id") + params["vendor_id"] = vendor_id + if severity: + where.append("severity = :severity") + params["severity"] = severity + if status: + where.append("status = :status") + params["status"] = status + where_clause = " AND ".join(where) + params["lim"] = limit + params["off"] = skip + + rows = self._db.execute(text(f""" + SELECT * FROM vendor_findings + WHERE {where_clause} + ORDER BY created_at DESC + LIMIT :lim OFFSET :off + """), params).fetchall() + return _ok([_finding_to_response(r) for r in rows]) + + def get_finding(self, finding_id: str) -> dict: + row = self._db.execute( + text("SELECT * FROM vendor_findings WHERE id = :id"), + {"id": finding_id}, + ).fetchone() + if not row: + raise NotFoundError("Finding not found") + return _ok(_finding_to_response(row)) + + def create_finding(self, body: dict) -> dict: + data = _to_snake(body) + fid = str(uuid.uuid4()) + tid = data.get("tenant_id", DEFAULT_TENANT_ID) + now = datetime.now(timezone.utc).isoformat() + + self._db.execute(text(""" + INSERT INTO vendor_findings ( + id, tenant_id, vendor_id, contract_id, + finding_type, category, severity, + title, description, recommendation, + citations, status, assignee, due_date, + created_at, updated_at, created_by + ) VALUES ( + :id, :tenant_id, :vendor_id, :contract_id, + :finding_type, :category, :severity, + :title, :description, :recommendation, + CAST(:citations AS jsonb), :status, :assignee, :due_date, + :created_at, :updated_at, :created_by + ) + """), { + "id": fid, "tenant_id": tid, + "vendor_id": data.get("vendor_id", ""), + "contract_id": data.get("contract_id"), + "finding_type": data.get("finding_type", "UNKNOWN"), + "category": data.get("category", ""), + "severity": data.get("severity", "MEDIUM"), + "title": data.get("title", ""), + "description": data.get("description", ""), + "recommendation": data.get("recommendation", ""), + "citations": json.dumps(data.get("citations", [])), + "status": data.get("status", "OPEN"), + "assignee": data.get("assignee", ""), + "due_date": data.get("due_date"), + "created_at": now, "updated_at": now, + "created_by": data.get("created_by", "system"), + }) + self._db.commit() + row = self._db.execute( + text("SELECT * FROM vendor_findings WHERE id = :id"), + {"id": fid}, + ).fetchone() + return _ok(_finding_to_response(row)) + + def update_finding(self, finding_id: str, body: dict) -> dict: + existing = self._db.execute( + text("SELECT id FROM vendor_findings WHERE id = :id"), + {"id": finding_id}, + ).fetchone() + if not existing: + raise NotFoundError("Finding not found") + + data = _to_snake(body) + now = datetime.now(timezone.utc).isoformat() + allowed = [ + "vendor_id", "contract_id", "finding_type", "category", + "severity", "title", "description", "recommendation", + "status", "assignee", "due_date", + "resolution", "resolved_at", "resolved_by", + ] + jsonb_fields = ["citations"] + + sets = ["updated_at = :updated_at"] + params: dict = {"id": finding_id, "updated_at": now} + for col in allowed: + if col in data: + sets.append(f"{col} = :{col}") + params[col] = data[col] + for col in jsonb_fields: + if col in data: + sets.append(f"{col} = CAST(:{col} AS jsonb)") + params[col] = json.dumps(data[col]) + + self._db.execute( + text( + f"UPDATE vendor_findings SET {', '.join(sets)} WHERE id = :id", + ), + params, + ) + self._db.commit() + row = self._db.execute( + text("SELECT * FROM vendor_findings WHERE id = :id"), + {"id": finding_id}, + ).fetchone() + return _ok(_finding_to_response(row)) + + def delete_finding(self, finding_id: str) -> dict: + result = self._db.execute( + text("DELETE FROM vendor_findings WHERE id = :id"), + {"id": finding_id}, + ) + self._db.commit() + if result.rowcount == 0: + raise NotFoundError("Finding not found") + return _ok({"deleted": True}) + + +# ============================================================================ +# ControlInstanceService +# ============================================================================ + + +class ControlInstanceService: + """Vendor control instances CRUD.""" + + def __init__(self, db: Session) -> None: + self._db = db + + def list_instances( + self, + tenant_id: Optional[str] = None, + vendor_id: Optional[str] = None, + skip: int = 0, + limit: int = 100, + ) -> dict: + tid = tenant_id or DEFAULT_TENANT_ID + where = ["tenant_id = :tid"] + params: dict = {"tid": tid} + if vendor_id: + where.append("vendor_id = :vendor_id") + params["vendor_id"] = vendor_id + where_clause = " AND ".join(where) + params["lim"] = limit + params["off"] = skip + + rows = self._db.execute(text(f""" + SELECT * FROM vendor_control_instances + WHERE {where_clause} + ORDER BY created_at DESC + LIMIT :lim OFFSET :off + """), params).fetchall() + return _ok([_control_instance_to_response(r) for r in rows]) + + def get_instance(self, instance_id: str) -> dict: + row = self._db.execute( + text("SELECT * FROM vendor_control_instances WHERE id = :id"), + {"id": instance_id}, + ).fetchone() + if not row: + raise NotFoundError("Control instance not found") + return _ok(_control_instance_to_response(row)) + + def create_instance(self, body: dict) -> dict: + data = _to_snake(body) + ciid = str(uuid.uuid4()) + tid = data.get("tenant_id", DEFAULT_TENANT_ID) + now = datetime.now(timezone.utc).isoformat() + + self._db.execute(text(""" + INSERT INTO vendor_control_instances ( + id, tenant_id, vendor_id, control_id, control_domain, + status, evidence_ids, notes, + last_assessed_at, last_assessed_by, next_assessment_date, + created_at, updated_at, created_by + ) VALUES ( + :id, :tenant_id, :vendor_id, :control_id, :control_domain, + :status, CAST(:evidence_ids AS jsonb), :notes, + :last_assessed_at, :last_assessed_by, + :next_assessment_date, + :created_at, :updated_at, :created_by + ) + """), { + "id": ciid, "tenant_id": tid, + "vendor_id": data.get("vendor_id", ""), + "control_id": data.get("control_id", ""), + "control_domain": data.get("control_domain", ""), + "status": data.get("status", "PLANNED"), + "evidence_ids": json.dumps(data.get("evidence_ids", [])), + "notes": data.get("notes", ""), + "last_assessed_at": data.get("last_assessed_at"), + "last_assessed_by": data.get("last_assessed_by", ""), + "next_assessment_date": data.get("next_assessment_date"), + "created_at": now, "updated_at": now, + "created_by": data.get("created_by", "system"), + }) + self._db.commit() + row = self._db.execute( + text("SELECT * FROM vendor_control_instances WHERE id = :id"), + {"id": ciid}, + ).fetchone() + return _ok(_control_instance_to_response(row)) + + def update_instance(self, instance_id: str, body: dict) -> dict: + existing = self._db.execute( + text("SELECT id FROM vendor_control_instances WHERE id = :id"), + {"id": instance_id}, + ).fetchone() + if not existing: + raise NotFoundError("Control instance not found") + + data = _to_snake(body) + now = datetime.now(timezone.utc).isoformat() + allowed = [ + "vendor_id", "control_id", "control_domain", + "status", "notes", + "last_assessed_at", "last_assessed_by", + "next_assessment_date", + ] + jsonb_fields = ["evidence_ids"] + + sets = ["updated_at = :updated_at"] + params: dict = {"id": instance_id, "updated_at": now} + for col in allowed: + if col in data: + sets.append(f"{col} = :{col}") + params[col] = data[col] + for col in jsonb_fields: + if col in data: + sets.append(f"{col} = CAST(:{col} AS jsonb)") + params[col] = json.dumps(data[col]) + + self._db.execute(text( + f"UPDATE vendor_control_instances SET {', '.join(sets)} " + f"WHERE id = :id", + ), params) + self._db.commit() + row = self._db.execute( + text("SELECT * FROM vendor_control_instances WHERE id = :id"), + {"id": instance_id}, + ).fetchone() + return _ok(_control_instance_to_response(row)) + + def delete_instance(self, instance_id: str) -> dict: + result = self._db.execute( + text("DELETE FROM vendor_control_instances WHERE id = :id"), + {"id": instance_id}, + ) + self._db.commit() + if result.rowcount == 0: + raise NotFoundError("Control instance not found") + return _ok({"deleted": True}) + + +# ============================================================================ +# ControlsLibraryService +# ============================================================================ + + +class ControlsLibraryService: + """Controls library (vendor_compliance_controls catalog).""" + + def __init__(self, db: Session) -> None: + self._db = db + + def list_controls( + self, + tenant_id: Optional[str] = None, + domain: Optional[str] = None, + ) -> dict: + tid = tenant_id or DEFAULT_TENANT_ID + where = ["tenant_id = :tid"] + params: dict = {"tid": tid} + if domain: + where.append("domain = :domain") + params["domain"] = domain + where_clause = " AND ".join(where) + + rows = self._db.execute(text(f""" + SELECT * FROM vendor_compliance_controls + WHERE {where_clause} + ORDER BY domain, control_code + """), params).fetchall() + + items = [] + for r in rows: + items.append({ + "id": str(r["id"]), + "tenantId": r["tenant_id"], + "domain": _get(r, "domain", ""), + "controlCode": _get(r, "control_code", ""), + "title": _get(r, "title", ""), + "description": _get(r, "description", ""), + "createdAt": _ts(r["created_at"]), + }) + return _ok(items) + + def create_control(self, body: dict) -> dict: + cid = str(uuid.uuid4()) + tid = body.get( + "tenantId", body.get("tenant_id", DEFAULT_TENANT_ID), + ) + now = datetime.now(timezone.utc).isoformat() + + self._db.execute(text(""" + INSERT INTO vendor_compliance_controls ( + id, tenant_id, domain, control_code, title, description, + created_at + ) VALUES ( + :id, :tenant_id, :domain, :control_code, :title, + :description, :created_at + ) + """), { + "id": cid, "tenant_id": tid, + "domain": body.get("domain", ""), + "control_code": body.get( + "controlCode", body.get("control_code", ""), + ), + "title": body.get("title", ""), + "description": body.get("description", ""), + "created_at": now, + }) + self._db.commit() + + return _ok({ + "id": cid, "tenantId": tid, + "domain": body.get("domain", ""), + "controlCode": body.get( + "controlCode", body.get("control_code", ""), + ), + "title": body.get("title", ""), + "description": body.get("description", ""), + "createdAt": now, + }) + + def delete_control(self, control_id: str) -> dict: + result = self._db.execute( + text("DELETE FROM vendor_compliance_controls WHERE id = :id"), + {"id": control_id}, + ) + self._db.commit() + if result.rowcount == 0: + raise NotFoundError("Control not found") + return _ok({"deleted": True}) diff --git a/backend-compliance/compliance/services/vendor_compliance_service.py b/backend-compliance/compliance/services/vendor_compliance_service.py new file mode 100644 index 0000000..d23580a --- /dev/null +++ b/backend-compliance/compliance/services/vendor_compliance_service.py @@ -0,0 +1,489 @@ +# mypy: disable-error-code="arg-type,assignment,union-attr,no-any-return" +""" +Vendor compliance service — Vendors CRUD + stats + status patch. + +Phase 1 Step 4: extracted from ``compliance.api.vendor_compliance_routes``. +Helpers (_now_iso, _ok, _parse_json, _ts, _get, _to_snake, _to_camel, +_vendor_to_response, camelCase maps) are shared by both vendor service +modules and re-exported from the routes module for legacy test imports. +""" + +import json +import logging +import uuid +from datetime import datetime, timezone +from typing import Any, Optional + +from sqlalchemy import text +from sqlalchemy.orm import Session + +from compliance.domain import NotFoundError, ValidationError + +logger = logging.getLogger(__name__) + +# Default tenant UUID +DEFAULT_TENANT_ID = "9282a473-5c95-4b3a-bf78-0ecc0ec71d3e" + +# ============================================================================ +# Helpers (shared across vendor service modules) +# ============================================================================ + + +def _now_iso() -> str: + return datetime.now(timezone.utc).isoformat() + "Z" + + +def _ok(data: Any, status_code: int = 200) -> dict: + """Wrap response in {success, data, timestamp} envelope.""" + return {"success": True, "data": data, "timestamp": _now_iso()} + + +def _parse_json(val: Any, default: Any = None) -> Any: + """Parse a JSONB/TEXT field -> Python object.""" + if val is None: + return default if default is not None else None + if isinstance(val, (dict, list)): + return val + if isinstance(val, str): + try: + return json.loads(val) + except Exception: + return default if default is not None else val + return val + + +def _ts(val: Any) -> Optional[str]: + """Timestamp -> ISO string or None.""" + if not val: + return None + if isinstance(val, str): + return val + return val.isoformat() + + +def _get(row: Any, key: str, default: Any = None) -> Any: + """Safe row access.""" + try: + v = row[key] + return default if v is None and default is not None else v + except (KeyError, IndexError): + return default + + +# camelCase <-> snake_case conversion maps +_VENDOR_CAMEL_TO_SNAKE = { + "legalForm": "legal_form", + "serviceDescription": "service_description", + "serviceCategory": "service_category", + "dataAccessLevel": "data_access_level", + "processingLocations": "processing_locations", + "transferMechanisms": "transfer_mechanisms", + "primaryContact": "primary_contact", + "dpoContact": "dpo_contact", + "securityContact": "security_contact", + "contractTypes": "contract_types", + "inherentRiskScore": "inherent_risk_score", + "residualRiskScore": "residual_risk_score", + "manualRiskAdjustment": "manual_risk_adjustment", + "riskJustification": "risk_justification", + "reviewFrequency": "review_frequency", + "lastReviewDate": "last_review_date", + "nextReviewDate": "next_review_date", + "processingActivityIds": "processing_activity_ids", + "contactName": "contact_name", + "contactEmail": "contact_email", + "contactPhone": "contact_phone", + "contactDepartment": "contact_department", + "tenantId": "tenant_id", + "createdAt": "created_at", + "updatedAt": "updated_at", + "createdBy": "created_by", + "vendorId": "vendor_id", + "contractId": "contract_id", + "controlId": "control_id", + "controlDomain": "control_domain", + "evidenceIds": "evidence_ids", + "lastAssessedAt": "last_assessed_at", + "lastAssessedBy": "last_assessed_by", + "nextAssessmentDate": "next_assessment_date", + "fileName": "file_name", + "originalName": "original_name", + "mimeType": "mime_type", + "fileSize": "file_size", + "storagePath": "storage_path", + "documentType": "document_type", + "previousVersionId": "previous_version_id", + "effectiveDate": "effective_date", + "expirationDate": "expiration_date", + "autoRenewal": "auto_renewal", + "renewalNoticePeriod": "renewal_notice_period", + "terminationNoticePeriod": "termination_notice_period", + "reviewStatus": "review_status", + "reviewCompletedAt": "review_completed_at", + "complianceScore": "compliance_score", + "extractedText": "extracted_text", + "pageCount": "page_count", + "findingType": "finding_type", + "dueDate": "due_date", + "resolvedAt": "resolved_at", + "resolvedBy": "resolved_by", +} + +_VENDOR_SNAKE_TO_CAMEL = {v: k for k, v in _VENDOR_CAMEL_TO_SNAKE.items()} + + +def _to_snake(data: dict) -> dict: + """Convert camelCase keys in data to snake_case for DB storage.""" + result = {} + for k, v in data.items(): + snake = _VENDOR_CAMEL_TO_SNAKE.get(k, k) + result[snake] = v + return result + + +def _to_camel(data: dict) -> dict: + """Convert snake_case keys to camelCase for frontend.""" + result = {} + for k, v in data.items(): + camel = _VENDOR_SNAKE_TO_CAMEL.get(k, k) + result[camel] = v + return result + + +# ============================================================================ +# Row -> Response converters +# ============================================================================ + + +def _vendor_to_response(row: Any) -> dict: + return _to_camel({ + "id": str(row["id"]), + "tenant_id": row["tenant_id"], + "name": row["name"], + "legal_form": _get(row, "legal_form", ""), + "country": _get(row, "country", ""), + "address": _get(row, "address", ""), + "website": _get(row, "website", ""), + "role": _get(row, "role", "PROCESSOR"), + "service_description": _get(row, "service_description", ""), + "service_category": _get(row, "service_category", "OTHER"), + "data_access_level": _get(row, "data_access_level", "NONE"), + "processing_locations": _parse_json(_get(row, "processing_locations"), []), + "transfer_mechanisms": _parse_json(_get(row, "transfer_mechanisms"), []), + "certifications": _parse_json(_get(row, "certifications"), []), + "primary_contact": _parse_json(_get(row, "primary_contact"), {}), + "dpo_contact": _parse_json(_get(row, "dpo_contact"), {}), + "security_contact": _parse_json(_get(row, "security_contact"), {}), + "contract_types": _parse_json(_get(row, "contract_types"), []), + "inherent_risk_score": _get(row, "inherent_risk_score", 50), + "residual_risk_score": _get(row, "residual_risk_score", 50), + "manual_risk_adjustment": _get(row, "manual_risk_adjustment"), + "risk_justification": _get(row, "risk_justification", ""), + "review_frequency": _get(row, "review_frequency", "ANNUAL"), + "last_review_date": _ts(_get(row, "last_review_date")), + "next_review_date": _ts(_get(row, "next_review_date")), + "status": _get(row, "status", "ACTIVE"), + "processing_activity_ids": _parse_json( + _get(row, "processing_activity_ids"), [], + ), + "notes": _get(row, "notes", ""), + "contact_name": _get(row, "contact_name", ""), + "contact_email": _get(row, "contact_email", ""), + "contact_phone": _get(row, "contact_phone", ""), + "contact_department": _get(row, "contact_department", ""), + "created_at": _ts(row["created_at"]), + "updated_at": _ts(row["updated_at"]), + "created_by": _get(row, "created_by", "system"), + }) + + +# ============================================================================ +# VendorService +# ============================================================================ + + +class VendorService: + """Vendor CRUD + stats + status patch.""" + + def __init__(self, db: Session) -> None: + self._db = db + + def get_stats(self, tenant_id: Optional[str] = None) -> dict: + tid = tenant_id or DEFAULT_TENANT_ID + result = self._db.execute(text(""" + SELECT + COUNT(*) AS total, + COUNT(*) FILTER (WHERE status = 'ACTIVE') AS active, + COUNT(*) FILTER (WHERE status = 'INACTIVE') AS inactive, + COUNT(*) FILTER (WHERE status = 'PENDING_REVIEW') AS pending_review, + COUNT(*) FILTER (WHERE status = 'TERMINATED') AS terminated, + COALESCE(AVG(inherent_risk_score), 0) AS avg_inherent_risk, + COALESCE(AVG(residual_risk_score), 0) AS avg_residual_risk, + COUNT(*) FILTER (WHERE inherent_risk_score >= 75) AS high_risk_count + FROM vendor_vendors + WHERE tenant_id = :tid + """), {"tid": tid}) + row = result.fetchone() + if row is None: + stats = { + "total": 0, "active": 0, "inactive": 0, + "pending_review": 0, "terminated": 0, + "avg_inherent_risk": 0, "avg_residual_risk": 0, + "high_risk_count": 0, + } + else: + stats = { + "total": row["total"] or 0, + "active": row["active"] or 0, + "inactive": row["inactive"] or 0, + "pendingReview": row["pending_review"] or 0, + "terminated": row["terminated"] or 0, + "avgInherentRisk": round( + float(row["avg_inherent_risk"] or 0), 1, + ), + "avgResidualRisk": round( + float(row["avg_residual_risk"] or 0), 1, + ), + "highRiskCount": row["high_risk_count"] or 0, + } + return _ok(stats) + + def list_vendors( + self, + tenant_id: Optional[str] = None, + status: Optional[str] = None, + risk_level: Optional[str] = None, + search: Optional[str] = None, + skip: int = 0, + limit: int = 100, + ) -> dict: + tid = tenant_id or DEFAULT_TENANT_ID + where = ["tenant_id = :tid"] + params: dict = {"tid": tid} + + if status: + where.append("status = :status") + params["status"] = status + if risk_level: + if risk_level == "HIGH": + where.append("inherent_risk_score >= 75") + elif risk_level == "MEDIUM": + where.append( + "inherent_risk_score >= 40 AND inherent_risk_score < 75", + ) + elif risk_level == "LOW": + where.append("inherent_risk_score < 40") + if search: + where.append( + "(name ILIKE :search OR service_description ILIKE :search)", + ) + params["search"] = f"%{search}%" + + where_clause = " AND ".join(where) + params["lim"] = limit + params["off"] = skip + + rows = self._db.execute(text(f""" + SELECT * FROM vendor_vendors + WHERE {where_clause} + ORDER BY created_at DESC + LIMIT :lim OFFSET :off + """), params).fetchall() + + count_row = self._db.execute(text(f""" + SELECT COUNT(*) AS cnt FROM vendor_vendors WHERE {where_clause} + """), {k: v for k, v in params.items() if k not in ("lim", "off")}).fetchone() + total = count_row["cnt"] if count_row else 0 + + return _ok({ + "items": [_vendor_to_response(r) for r in rows], + "total": total, + }) + + def get_vendor(self, vendor_id: str) -> dict: + row = self._db.execute( + text("SELECT * FROM vendor_vendors WHERE id = :id"), + {"id": vendor_id}, + ).fetchone() + if not row: + raise NotFoundError("Vendor not found") + return _ok(_vendor_to_response(row)) + + def create_vendor(self, body: dict) -> dict: + data = _to_snake(body) + vid = str(uuid.uuid4()) + tid = data.get("tenant_id", DEFAULT_TENANT_ID) + now = datetime.now(timezone.utc).isoformat() + + self._db.execute(text(""" + INSERT INTO vendor_vendors ( + id, tenant_id, name, legal_form, country, address, website, + role, service_description, service_category, data_access_level, + processing_locations, transfer_mechanisms, certifications, + primary_contact, dpo_contact, security_contact, + contract_types, inherent_risk_score, residual_risk_score, + manual_risk_adjustment, risk_justification, + review_frequency, last_review_date, next_review_date, + status, processing_activity_ids, notes, + contact_name, contact_email, contact_phone, + contact_department, + created_at, updated_at, created_by + ) VALUES ( + :id, :tenant_id, :name, :legal_form, :country, :address, + :website, :role, :service_description, :service_category, + :data_access_level, + CAST(:processing_locations AS jsonb), + CAST(:transfer_mechanisms AS jsonb), + CAST(:certifications AS jsonb), + CAST(:primary_contact AS jsonb), + CAST(:dpo_contact AS jsonb), + CAST(:security_contact AS jsonb), + CAST(:contract_types AS jsonb), + :inherent_risk_score, :residual_risk_score, + :manual_risk_adjustment, :risk_justification, + :review_frequency, :last_review_date, :next_review_date, + :status, CAST(:processing_activity_ids AS jsonb), :notes, + :contact_name, :contact_email, :contact_phone, + :contact_department, + :created_at, :updated_at, :created_by + ) + """), { + "id": vid, "tenant_id": tid, + "name": data.get("name", ""), + "legal_form": data.get("legal_form", ""), + "country": data.get("country", ""), + "address": data.get("address", ""), + "website": data.get("website", ""), + "role": data.get("role", "PROCESSOR"), + "service_description": data.get("service_description", ""), + "service_category": data.get("service_category", "OTHER"), + "data_access_level": data.get("data_access_level", "NONE"), + "processing_locations": json.dumps( + data.get("processing_locations", []), + ), + "transfer_mechanisms": json.dumps( + data.get("transfer_mechanisms", []), + ), + "certifications": json.dumps(data.get("certifications", [])), + "primary_contact": json.dumps(data.get("primary_contact", {})), + "dpo_contact": json.dumps(data.get("dpo_contact", {})), + "security_contact": json.dumps(data.get("security_contact", {})), + "contract_types": json.dumps(data.get("contract_types", [])), + "inherent_risk_score": data.get("inherent_risk_score", 50), + "residual_risk_score": data.get("residual_risk_score", 50), + "manual_risk_adjustment": data.get("manual_risk_adjustment"), + "risk_justification": data.get("risk_justification", ""), + "review_frequency": data.get("review_frequency", "ANNUAL"), + "last_review_date": data.get("last_review_date"), + "next_review_date": data.get("next_review_date"), + "status": data.get("status", "ACTIVE"), + "processing_activity_ids": json.dumps( + data.get("processing_activity_ids", []), + ), + "notes": data.get("notes", ""), + "contact_name": data.get("contact_name", ""), + "contact_email": data.get("contact_email", ""), + "contact_phone": data.get("contact_phone", ""), + "contact_department": data.get("contact_department", ""), + "created_at": now, "updated_at": now, + "created_by": data.get("created_by", "system"), + }) + self._db.commit() + + row = self._db.execute( + text("SELECT * FROM vendor_vendors WHERE id = :id"), + {"id": vid}, + ).fetchone() + return _ok(_vendor_to_response(row)) + + def update_vendor(self, vendor_id: str, body: dict) -> dict: + existing = self._db.execute( + text("SELECT id FROM vendor_vendors WHERE id = :id"), + {"id": vendor_id}, + ).fetchone() + if not existing: + raise NotFoundError("Vendor not found") + + data = _to_snake(body) + now = datetime.now(timezone.utc).isoformat() + + allowed = [ + "name", "legal_form", "country", "address", "website", + "role", "service_description", "service_category", + "data_access_level", + "inherent_risk_score", "residual_risk_score", + "manual_risk_adjustment", "risk_justification", + "review_frequency", "last_review_date", "next_review_date", + "status", "notes", + "contact_name", "contact_email", "contact_phone", + "contact_department", + ] + jsonb_fields = [ + "processing_locations", "transfer_mechanisms", "certifications", + "primary_contact", "dpo_contact", "security_contact", + "contract_types", "processing_activity_ids", + ] + + sets = ["updated_at = :updated_at"] + params: dict = {"id": vendor_id, "updated_at": now} + + for col in allowed: + if col in data: + sets.append(f"{col} = :{col}") + params[col] = data[col] + + for col in jsonb_fields: + if col in data: + sets.append(f"{col} = CAST(:{col} AS jsonb)") + params[col] = json.dumps(data[col]) + + self._db.execute( + text(f"UPDATE vendor_vendors SET {', '.join(sets)} WHERE id = :id"), + params, + ) + self._db.commit() + + row = self._db.execute( + text("SELECT * FROM vendor_vendors WHERE id = :id"), + {"id": vendor_id}, + ).fetchone() + return _ok(_vendor_to_response(row)) + + def delete_vendor(self, vendor_id: str) -> dict: + result = self._db.execute( + text("DELETE FROM vendor_vendors WHERE id = :id"), + {"id": vendor_id}, + ) + self._db.commit() + if result.rowcount == 0: + raise NotFoundError("Vendor not found") + return _ok({"deleted": True}) + + def patch_status(self, vendor_id: str, body: dict) -> dict: + new_status = body.get("status") + if not new_status: + raise ValidationError("status is required") + valid = {"ACTIVE", "INACTIVE", "PENDING_REVIEW", "TERMINATED"} + if new_status not in valid: + raise ValidationError( + f"Invalid status. Must be one of: {', '.join(sorted(valid))}", + ) + + result = self._db.execute(text(""" + UPDATE vendor_vendors + SET status = :status, updated_at = :now + WHERE id = :id + """), { + "id": vendor_id, + "status": new_status, + "now": datetime.now(timezone.utc).isoformat(), + }) + self._db.commit() + if result.rowcount == 0: + raise NotFoundError("Vendor not found") + + row = self._db.execute( + text("SELECT * FROM vendor_vendors WHERE id = :id"), + {"id": vendor_id}, + ).fetchone() + return _ok(_vendor_to_response(row)) diff --git a/backend-compliance/compliance/services/vendor_compliance_sub_service.py b/backend-compliance/compliance/services/vendor_compliance_sub_service.py new file mode 100644 index 0000000..e84d697 --- /dev/null +++ b/backend-compliance/compliance/services/vendor_compliance_sub_service.py @@ -0,0 +1,282 @@ +# mypy: disable-error-code="arg-type,assignment,union-attr,no-any-return" +""" +Vendor compliance sub-entities — Contracts CRUD + row converters for +contracts, findings, and control instances. + +Phase 1 Step 4: extracted from ``compliance.api.vendor_compliance_routes``. +Shares helpers with ``compliance.services.vendor_compliance_service``. +""" + +import json +import uuid +from datetime import datetime, timezone +from typing import Any, Optional + +from sqlalchemy import text +from sqlalchemy.orm import Session + +from compliance.domain import NotFoundError +from compliance.services.vendor_compliance_service import ( + DEFAULT_TENANT_ID, + _get, + _ok, + _parse_json, + _to_camel, + _to_snake, + _ts, +) + + +# ============================================================================ +# Row -> Response converters (shared with extra service) +# ============================================================================ + + +def _contract_to_response(row: Any) -> dict: + return _to_camel({ + "id": str(row["id"]), + "tenant_id": row["tenant_id"], + "vendor_id": str(row["vendor_id"]), + "file_name": _get(row, "file_name", ""), + "original_name": _get(row, "original_name", ""), + "mime_type": _get(row, "mime_type", ""), + "file_size": _get(row, "file_size", 0), + "storage_path": _get(row, "storage_path", ""), + "document_type": _get(row, "document_type", "AVV"), + "version": _get(row, "version", 1), + "previous_version_id": ( + str(_get(row, "previous_version_id")) + if _get(row, "previous_version_id") else None + ), + "parties": _parse_json(_get(row, "parties"), []), + "effective_date": _ts(_get(row, "effective_date")), + "expiration_date": _ts(_get(row, "expiration_date")), + "auto_renewal": _get(row, "auto_renewal", False), + "renewal_notice_period": _get(row, "renewal_notice_period", ""), + "termination_notice_period": _get( + row, "termination_notice_period", "", + ), + "review_status": _get(row, "review_status", "PENDING"), + "review_completed_at": _ts(_get(row, "review_completed_at")), + "compliance_score": _get(row, "compliance_score"), + "status": _get(row, "status", "DRAFT"), + "extracted_text": _get(row, "extracted_text", ""), + "page_count": _get(row, "page_count", 0), + "created_at": _ts(row["created_at"]), + "updated_at": _ts(row["updated_at"]), + "created_by": _get(row, "created_by", "system"), + }) + + +def _finding_to_response(row: Any) -> dict: + return _to_camel({ + "id": str(row["id"]), + "tenant_id": row["tenant_id"], + "vendor_id": str(row["vendor_id"]), + "contract_id": ( + str(_get(row, "contract_id")) + if _get(row, "contract_id") else None + ), + "finding_type": _get(row, "finding_type", "UNKNOWN"), + "category": _get(row, "category", ""), + "severity": _get(row, "severity", "MEDIUM"), + "title": _get(row, "title", ""), + "description": _get(row, "description", ""), + "recommendation": _get(row, "recommendation", ""), + "citations": _parse_json(_get(row, "citations"), []), + "status": _get(row, "status", "OPEN"), + "assignee": _get(row, "assignee", ""), + "due_date": _ts(_get(row, "due_date")), + "resolution": _get(row, "resolution", ""), + "resolved_at": _ts(_get(row, "resolved_at")), + "resolved_by": _get(row, "resolved_by", ""), + "created_at": _ts(row["created_at"]), + "updated_at": _ts(row["updated_at"]), + "created_by": _get(row, "created_by", "system"), + }) + + +def _control_instance_to_response(row: Any) -> dict: + return _to_camel({ + "id": str(row["id"]), + "tenant_id": row["tenant_id"], + "vendor_id": str(row["vendor_id"]), + "control_id": _get(row, "control_id", ""), + "control_domain": _get(row, "control_domain", ""), + "status": _get(row, "status", "PLANNED"), + "evidence_ids": _parse_json(_get(row, "evidence_ids"), []), + "notes": _get(row, "notes", ""), + "last_assessed_at": _ts(_get(row, "last_assessed_at")), + "last_assessed_by": _get(row, "last_assessed_by", ""), + "next_assessment_date": _ts(_get(row, "next_assessment_date")), + "created_at": _ts(row["created_at"]), + "updated_at": _ts(row["updated_at"]), + "created_by": _get(row, "created_by", "system"), + }) + + +# ============================================================================ +# ContractService +# ============================================================================ + + +class ContractService: + """Vendor contracts CRUD.""" + + def __init__(self, db: Session) -> None: + self._db = db + + def list_contracts( + self, + tenant_id: Optional[str] = None, + vendor_id: Optional[str] = None, + status: Optional[str] = None, + skip: int = 0, + limit: int = 100, + ) -> dict: + tid = tenant_id or DEFAULT_TENANT_ID + where = ["tenant_id = :tid"] + params: dict = {"tid": tid} + if vendor_id: + where.append("vendor_id = :vendor_id") + params["vendor_id"] = vendor_id + if status: + where.append("status = :status") + params["status"] = status + where_clause = " AND ".join(where) + params["lim"] = limit + params["off"] = skip + + rows = self._db.execute(text(f""" + SELECT * FROM vendor_contracts + WHERE {where_clause} + ORDER BY created_at DESC + LIMIT :lim OFFSET :off + """), params).fetchall() + return _ok([_contract_to_response(r) for r in rows]) + + def get_contract(self, contract_id: str) -> dict: + row = self._db.execute( + text("SELECT * FROM vendor_contracts WHERE id = :id"), + {"id": contract_id}, + ).fetchone() + if not row: + raise NotFoundError("Contract not found") + return _ok(_contract_to_response(row)) + + def create_contract(self, body: dict) -> dict: + data = _to_snake(body) + cid = str(uuid.uuid4()) + tid = data.get("tenant_id", DEFAULT_TENANT_ID) + now = datetime.now(timezone.utc).isoformat() + + self._db.execute(text(""" + INSERT INTO vendor_contracts ( + id, tenant_id, vendor_id, file_name, original_name, + mime_type, file_size, storage_path, document_type, + version, previous_version_id, + parties, effective_date, expiration_date, + auto_renewal, renewal_notice_period, + termination_notice_period, + review_status, status, compliance_score, + extracted_text, page_count, + created_at, updated_at, created_by + ) VALUES ( + :id, :tenant_id, :vendor_id, :file_name, :original_name, + :mime_type, :file_size, :storage_path, :document_type, + :version, :previous_version_id, + CAST(:parties AS jsonb), :effective_date, :expiration_date, + :auto_renewal, :renewal_notice_period, + :termination_notice_period, + :review_status, :status, :compliance_score, + :extracted_text, :page_count, + :created_at, :updated_at, :created_by + ) + """), { + "id": cid, "tenant_id": tid, + "vendor_id": data.get("vendor_id", ""), + "file_name": data.get("file_name", ""), + "original_name": data.get("original_name", ""), + "mime_type": data.get("mime_type", ""), + "file_size": data.get("file_size", 0), + "storage_path": data.get("storage_path", ""), + "document_type": data.get("document_type", "AVV"), + "version": data.get("version", 1), + "previous_version_id": data.get("previous_version_id"), + "parties": json.dumps(data.get("parties", [])), + "effective_date": data.get("effective_date"), + "expiration_date": data.get("expiration_date"), + "auto_renewal": data.get("auto_renewal", False), + "renewal_notice_period": data.get("renewal_notice_period", ""), + "termination_notice_period": data.get( + "termination_notice_period", "", + ), + "review_status": data.get("review_status", "PENDING"), + "status": data.get("status", "DRAFT"), + "compliance_score": data.get("compliance_score"), + "extracted_text": data.get("extracted_text", ""), + "page_count": data.get("page_count", 0), + "created_at": now, "updated_at": now, + "created_by": data.get("created_by", "system"), + }) + self._db.commit() + row = self._db.execute( + text("SELECT * FROM vendor_contracts WHERE id = :id"), + {"id": cid}, + ).fetchone() + return _ok(_contract_to_response(row)) + + def update_contract(self, contract_id: str, body: dict) -> dict: + existing = self._db.execute( + text("SELECT id FROM vendor_contracts WHERE id = :id"), + {"id": contract_id}, + ).fetchone() + if not existing: + raise NotFoundError("Contract not found") + + data = _to_snake(body) + now = datetime.now(timezone.utc).isoformat() + allowed = [ + "vendor_id", "file_name", "original_name", "mime_type", + "file_size", "storage_path", "document_type", "version", + "previous_version_id", "effective_date", "expiration_date", + "auto_renewal", "renewal_notice_period", + "termination_notice_period", + "review_status", "review_completed_at", "compliance_score", + "status", "extracted_text", "page_count", + ] + jsonb_fields = ["parties"] + + sets = ["updated_at = :updated_at"] + params: dict = {"id": contract_id, "updated_at": now} + for col in allowed: + if col in data: + sets.append(f"{col} = :{col}") + params[col] = data[col] + for col in jsonb_fields: + if col in data: + sets.append(f"{col} = CAST(:{col} AS jsonb)") + params[col] = json.dumps(data[col]) + + self._db.execute( + text( + f"UPDATE vendor_contracts SET {', '.join(sets)} WHERE id = :id", + ), + params, + ) + self._db.commit() + row = self._db.execute( + text("SELECT * FROM vendor_contracts WHERE id = :id"), + {"id": contract_id}, + ).fetchone() + return _ok(_contract_to_response(row)) + + def delete_contract(self, contract_id: str) -> dict: + result = self._db.execute( + text("DELETE FROM vendor_contracts WHERE id = :id"), + {"id": contract_id}, + ) + self._db.commit() + if result.rowcount == 0: + raise NotFoundError("Contract not found") + return _ok({"deleted": True})