Files
breakpilot-compliance/backend-compliance/compliance/services/regulation_requirement_service.py
Sharang Parnerkar 6658776610 refactor(backend/api): extract compliance routes services (Step 4 — file 13 of 18)
Split routes.py (991 LOC) into thin handlers + two service files:
- RegulationRequirementService: regulations CRUD, requirements CRUD
- ControlExportService: controls CRUD/review/domain, export, admin seeding

All 216 tests pass. Route module re-exports repository classes so
existing test patches (compliance.api.routes.*Repository) keep working.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-09 19:12:22 +02:00

411 lines
14 KiB
Python

# mypy: disable-error-code="arg-type,assignment,union-attr,no-any-return"
"""
Service for regulation and requirement business logic.
Phase 1 Step 4: extracted from ``compliance.api.routes``. All handler logic
for regulations CRUD and requirements CRUD lives here. The route module
delegates to this service and translates domain errors to HTTP responses.
"""
import logging
from datetime import datetime, timezone
from typing import Any, Optional
from sqlalchemy.orm import Session
from compliance.db import RegulationRepository, RequirementRepository
from compliance.db.models import RegulationDB, RequirementDB
from compliance.domain import NotFoundError, ValidationError
from compliance.schemas.regulation import (
RegulationListResponse,
RegulationResponse,
)
from compliance.schemas.requirement import (
PaginatedRequirementResponse,
RequirementCreate,
RequirementListResponse,
RequirementResponse,
)
from compliance.schemas.common import PaginationMeta
logger = logging.getLogger(__name__)
def _regulation_to_response(
    reg: Any, requirement_count: int
) -> RegulationResponse:
    """Convert a regulation row into a ``RegulationResponse``.

    Args:
        reg: Regulation object (e.g. a ``RegulationDB`` row) exposing the
            attributes copied below.
        requirement_count: Value reported as ``requirement_count``.

    Returns:
        The response model. ``regulation_type`` is flattened to the enum's
        ``.value``, or ``None`` when the attribute is falsy.
    """
    # Attributes copied verbatim from the row onto the response model.
    passthrough = (
        "id",
        "code",
        "name",
        "full_name",
        "source_url",
        "local_pdf_path",
        "effective_date",
        "description",
        "is_active",
        "created_at",
        "updated_at",
    )
    payload = {attr: getattr(reg, attr) for attr in passthrough}

    reg_type = reg.regulation_type
    payload["regulation_type"] = reg_type.value if reg_type else None
    payload["requirement_count"] = requirement_count
    return RegulationResponse(**payload)
def _requirement_to_response(r: Any, code: Optional[str] = None) -> RequirementResponse:
    """Convert a requirement row into a compact ``RequirementResponse``.

    Only the core requirement attributes are copied; implementation and
    audit fields are not populated here.

    Args:
        r: Requirement object (e.g. a ``RequirementDB`` row).
        code: Regulation code to report as ``regulation_code``; defaults
            to ``None``.

    Returns:
        The response model.
    """
    # Attributes copied verbatim from the row onto the response model.
    copied_attrs = (
        "id",
        "regulation_id",
        "article",
        "paragraph",
        "title",
        "description",
        "requirement_text",
        "breakpilot_interpretation",
        "is_applicable",
        "applicability_reason",
        "priority",
        "created_at",
        "updated_at",
    )
    values = {name: getattr(r, name) for name in copied_attrs}
    return RequirementResponse(regulation_code=code, **values)
class RegulationRequirementService:
    """Business logic for regulation and requirement endpoints.

    Regulations and requirement lists are read through the injected
    repositories; single-requirement lookups and updates use ``self.db``
    directly. Missing rows are reported as ``NotFoundError`` and invalid
    arguments as ``ValidationError``; mapping those to HTTP responses is
    left to the route layer.
    """

    # Attributes a client may overwrite via ``update_requirement``; any other
    # keys in the update payload are ignored.
    _UPDATABLE_FIELDS: tuple[str, ...] = (
        "implementation_status",
        "implementation_details",
        "code_references",
        "documentation_links",
        "evidence_description",
        "evidence_artifacts",
        "auditor_notes",
        "audit_status",
        "is_applicable",
        "applicability_reason",
        "breakpilot_interpretation",
    )

    # ``top_k`` passed to the RAG search when attaching legal context.
    _LEGAL_CONTEXT_TOP_K = 3

    def __init__(
        self,
        db: Session,
        reg_repo_cls: Any = RegulationRepository,
        req_repo_cls: Any = RequirementRepository,
    ) -> None:
        """Bind the service to a database session.

        Args:
            db: SQLAlchemy session, used directly and handed to both
                repositories.
            reg_repo_cls: Callable building the regulation repository from
                ``db``; injectable so it can be substituted.
            req_repo_cls: Callable building the requirement repository from
                ``db``; injectable so it can be substituted.
        """
        self.db = db
        self.reg_repo = reg_repo_cls(db)
        self.req_repo = req_repo_cls(db)

    # ------------------------------------------------------------------
    # Regulations
    # ------------------------------------------------------------------

    def list_regulations(
        self,
        is_active: Optional[bool],
        regulation_type: Optional[str],
    ) -> RegulationListResponse:
        """List regulations with their requirement counts.

        Args:
            is_active: ``True`` restricts the list to
                ``reg_repo.get_active()``. ``False`` and ``None`` both return
                ``reg_repo.get_all()``.
                NOTE(review): ``False`` does not select inactive regulations —
                confirm this is intended.
            regulation_type: Optional ``RegulationTypeEnum`` value to filter
                by. An unrecognised value is ignored and no type filter is
                applied.

        Returns:
            The matching regulations and their total.
        """
        regulations = (
            self.reg_repo.get_active() if is_active else self.reg_repo.get_all()
        )

        if regulation_type:
            from compliance.db.models import RegulationTypeEnum

            try:
                reg_type = RegulationTypeEnum(regulation_type)
            except ValueError:
                logger.debug(
                    "Ignoring unknown regulation_type filter %r", regulation_type
                )
            else:
                regulations = [
                    r for r in regulations if r.regulation_type == reg_type
                ]

        results = [
            _regulation_to_response(reg, self._requirement_count(reg.id))
            for reg in regulations
        ]
        return RegulationListResponse(regulations=results, total=len(results))

    def get_regulation(self, code: str) -> RegulationResponse:
        """Return one regulation, looked up by its code.

        Raises:
            NotFoundError: If no regulation has ``code``.
        """
        regulation = self._get_regulation_or_raise(code)
        return _regulation_to_response(
            regulation, self._requirement_count(regulation.id)
        )

    def get_regulation_requirements(
        self,
        code: str,
        is_applicable: Optional[bool],
    ) -> RequirementListResponse:
        """List the requirements of the regulation identified by ``code``.

        Args:
            code: Regulation code; also reported as each item's
                ``regulation_code``.
            is_applicable: ``True`` uses ``req_repo.get_applicable``.
                ``False`` and ``None`` both return every requirement of the
                regulation.
                NOTE(review): ``False`` does not select non-applicable
                requirements — confirm this is intended.

        Raises:
            NotFoundError: If no regulation has ``code``.
        """
        regulation = self._get_regulation_or_raise(code)
        requirements = (
            self.req_repo.get_applicable(regulation.id)
            if is_applicable
            else self.req_repo.get_by_regulation(regulation.id)
        )
        results = [_requirement_to_response(r, code) for r in requirements]
        return RequirementListResponse(
            requirements=results, total=len(results)
        )

    # ------------------------------------------------------------------
    # Requirements
    # ------------------------------------------------------------------

    def get_requirement(
        self,
        requirement_id: str,
        include_legal_context: bool,
    ) -> dict[str, Any]:
        """Return the full detail view of one requirement as a plain dict.

        Args:
            requirement_id: Primary key of the requirement.
            include_legal_context: When true, add a ``legal_context`` list
                built by ``_fetch_legal_context``. That fetch is best-effort
                and yields ``[]`` when called while an event loop is running
                in this thread.

        Returns:
            The requirement's fields. ``implementation_status`` falls back
            to ``"not_started"`` and ``audit_status`` to ``"pending"`` when
            unset. ``regulation_code`` is ``None`` if the parent regulation
            row is missing.

        Raises:
            NotFoundError: If the requirement does not exist.
        """
        requirement = self._get_requirement_or_raise(requirement_id)
        regulation = (
            self.db.query(RegulationDB)
            .filter(RegulationDB.id == requirement.regulation_id)
            .first()
        )
        result: dict[str, Any] = {
            "id": requirement.id,
            "regulation_id": requirement.regulation_id,
            "regulation_code": regulation.code if regulation else None,
            "article": requirement.article,
            "paragraph": requirement.paragraph,
            "title": requirement.title,
            "description": requirement.description,
            "requirement_text": requirement.requirement_text,
            "breakpilot_interpretation": requirement.breakpilot_interpretation,
            "implementation_status": (
                requirement.implementation_status or "not_started"
            ),
            "implementation_details": requirement.implementation_details,
            "code_references": requirement.code_references,
            "documentation_links": requirement.documentation_links,
            "evidence_description": requirement.evidence_description,
            "evidence_artifacts": requirement.evidence_artifacts,
            "auditor_notes": requirement.auditor_notes,
            "audit_status": requirement.audit_status or "pending",
            "last_audit_date": requirement.last_audit_date,
            "last_auditor": requirement.last_auditor,
            "is_applicable": requirement.is_applicable,
            "applicability_reason": requirement.applicability_reason,
            "priority": requirement.priority,
            "source_page": requirement.source_page,
            "source_section": requirement.source_section,
        }
        if include_legal_context:
            result["legal_context"] = self._fetch_legal_context(
                requirement, regulation
            )
        return result

    def list_requirements_paginated(
        self,
        page: int,
        page_size: int,
        regulation_code: Optional[str],
        status: Optional[str],
        is_applicable: Optional[bool],
        search: Optional[str],
    ) -> PaginatedRequirementResponse:
        """Return one page of requirements plus pagination metadata.

        The filters are passed unchanged to ``req_repo.get_paginated``, which
        returns ``(rows, total)``.

        Raises:
            ValidationError: If ``page_size`` is smaller than 1. Without this
                check, computing the page count would divide by zero.
        """
        if page_size < 1:
            raise ValidationError("page_size must be at least 1")

        requirements, total = self.req_repo.get_paginated(
            page=page,
            page_size=page_size,
            regulation_code=regulation_code,
            status=status,
            is_applicable=is_applicable,
            search=search,
        )
        # Ceiling division: a partially filled last page still counts.
        total_pages = (total + page_size - 1) // page_size
        results = [self._requirement_to_detail_response(r) for r in requirements]
        return PaginatedRequirementResponse(
            data=results,
            pagination=PaginationMeta(
                page=page,
                page_size=page_size,
                total=total,
                total_pages=total_pages,
                has_next=page < total_pages,
                has_prev=page > 1,
            ),
        )

    def create_requirement(
        self, data: RequirementCreate
    ) -> RequirementResponse:
        """Create a requirement under an existing regulation.

        Raises:
            NotFoundError: If ``data.regulation_id`` does not exist.
        """
        regulation = self.reg_repo.get_by_id(data.regulation_id)
        if not regulation:
            raise NotFoundError(
                f"Regulation {data.regulation_id} not found"
            )
        requirement = self.req_repo.create(
            regulation_id=data.regulation_id,
            article=data.article,
            title=data.title,
            paragraph=data.paragraph,
            description=data.description,
            requirement_text=data.requirement_text,
            breakpilot_interpretation=data.breakpilot_interpretation,
            is_applicable=data.is_applicable,
            priority=data.priority,
        )
        return _requirement_to_response(requirement, regulation.code)

    def delete_requirement(self, requirement_id: str) -> dict[str, Any]:
        """Delete a requirement through the repository.

        Raises:
            NotFoundError: If the repository reports nothing was deleted.
        """
        if not self.req_repo.delete(requirement_id):
            raise NotFoundError(
                f"Requirement {requirement_id} not found"
            )
        return {"success": True, "message": "Requirement deleted"}

    def update_requirement(
        self, requirement_id: str, updates: dict[str, Any]
    ) -> dict[str, Any]:
        """Apply whitelisted field updates to a requirement and commit.

        Only keys listed in ``_UPDATABLE_FIELDS`` are written. If
        ``audit_status`` is among the updates, the audit date is stamped and
        ``last_auditor`` is taken from ``updates["auditor_name"]``, falling
        back to ``"api_user"`` when that key is absent.

        Raises:
            NotFoundError: If the requirement does not exist.
        """
        requirement = self._get_requirement_or_raise(requirement_id)

        for field in self._UPDATABLE_FIELDS:
            if field in updates:
                setattr(requirement, field, updates[field])

        # One timestamp so last_audit_date and updated_at agree exactly.
        now = datetime.now(timezone.utc)
        if "audit_status" in updates:
            requirement.last_audit_date = now
            requirement.last_auditor = updates.get(
                "auditor_name", "api_user"
            )
        requirement.updated_at = now

        self.db.commit()
        self.db.refresh(requirement)
        return {"success": True, "message": "Requirement updated"}

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _get_regulation_or_raise(self, code: str) -> Any:
        """Return the regulation with ``code`` or raise ``NotFoundError``."""
        regulation = self.reg_repo.get_by_code(code)
        if not regulation:
            raise NotFoundError(f"Regulation {code} not found")
        return regulation

    def _get_requirement_or_raise(self, requirement_id: str) -> Any:
        """Return the ``RequirementDB`` row or raise ``NotFoundError``."""
        requirement = (
            self.db.query(RequirementDB)
            .filter(RequirementDB.id == requirement_id)
            .first()
        )
        if not requirement:
            raise NotFoundError(
                f"Requirement {requirement_id} not found"
            )
        return requirement

    def _requirement_count(self, regulation_id: Any) -> int:
        """Return how many requirements belong to ``regulation_id``.

        NOTE(review): this loads every requirement row just to count them; a
        COUNT query in the repository would be cheaper for large regulations.
        """
        return len(self.req_repo.get_by_regulation(regulation_id))

    @staticmethod
    def _requirement_to_detail_response(r: Any) -> RequirementResponse:
        """Build a ``RequirementResponse`` including implementation/audit fields.

        ``regulation_code`` is read through ``r.regulation`` and is ``None``
        when that relationship is empty. Status fields use the same fallbacks
        as ``get_requirement``.
        """
        return RequirementResponse(
            id=r.id,
            regulation_id=r.regulation_id,
            regulation_code=(
                r.regulation.code if r.regulation else None
            ),
            article=r.article,
            paragraph=r.paragraph,
            title=r.title,
            description=r.description,
            requirement_text=r.requirement_text,
            breakpilot_interpretation=r.breakpilot_interpretation,
            is_applicable=r.is_applicable,
            applicability_reason=r.applicability_reason,
            priority=r.priority,
            implementation_status=(
                r.implementation_status or "not_started"
            ),
            implementation_details=r.implementation_details,
            code_references=r.code_references,
            documentation_links=r.documentation_links,
            evidence_description=r.evidence_description,
            evidence_artifacts=r.evidence_artifacts,
            auditor_notes=r.auditor_notes,
            audit_status=r.audit_status or "pending",
            last_audit_date=r.last_audit_date,
            last_auditor=r.last_auditor,
            source_page=r.source_page,
            source_section=r.source_section,
            created_at=r.created_at,
            updated_at=r.updated_at,
        )

    @staticmethod
    def _prepare_legal_context_search(
        requirement: Any, regulation: Any
    ) -> tuple[Any, str, Any]:
        """Return ``(rag_client, query, collection)`` for a legal-context search.

        The query is the requirement's title followed by its article (or an
        empty string). The collection is chosen from the regulation code, or
        from ``""`` when there is no regulation.
        """
        # Imported here so the module does not depend on these at import time.
        from compliance.services.rag_client import get_rag_client
        from compliance.services.ai_compliance_assistant import (
            AIComplianceAssistant,
        )

        rag = get_rag_client()
        assistant = AIComplianceAssistant()
        query = f"{requirement.title} {requirement.article or ''}"
        collection = assistant._collection_for_regulation(
            regulation.code if regulation else ""
        )
        return rag, query, collection

    @staticmethod
    def _format_legal_context(rag_results: Any) -> list[dict[str, Any]]:
        """Flatten RAG search hits into JSON-friendly dicts."""
        return [
            {
                "text": r.text,
                "regulation_code": r.regulation_code,
                "regulation_short": r.regulation_short,
                "article": r.article,
                "score": r.score,
                "source_url": r.source_url,
            }
            for r in rag_results
        ]

    async def _fetch_legal_context_async(
        self, requirement: Any, regulation: Any
    ) -> list[dict[str, Any]]:
        """Fetch RAG legal context for ``requirement`` by awaiting the search.

        Use this version from async code: it awaits ``rag.search`` directly
        instead of trying to drive an event loop. Any failure (import, client
        construction, search) is logged as a warning and yields ``[]``, so
        legal context stays best-effort.

        Args:
            requirement: Requirement row. ``title`` and ``article`` form the
                query; ``id`` is used in the log message.
            regulation: Parent regulation row or ``None``; its ``code``
                selects the RAG collection.

        Returns:
            The search hits as dicts. ``top_k`` is ``_LEGAL_CONTEXT_TOP_K``.
        """
        try:
            rag, query, collection = self._prepare_legal_context_search(
                requirement, regulation
            )
            rag_results = await rag.search(
                query, collection=collection, top_k=self._LEGAL_CONTEXT_TOP_K
            )
            return self._format_legal_context(rag_results)
        except Exception as e:  # best-effort: never fail the whole request
            logger.warning(
                "Failed to fetch legal context for %s: %s",
                requirement.id,
                e,
            )
            return []

    def _fetch_legal_context(
        self, requirement: Any, regulation: Any
    ) -> list[dict[str, Any]]:
        """Synchronous wrapper around ``_fetch_legal_context_async``.

        Two cases:

        - No event loop is running in the current thread (plain sync code,
          or a sync handler executed in a worker thread). The search is then
          run on a private, short-lived event loop.
        - A loop is already running. Blocking here would stall it, so
          ``[]`` is returned; async callers should await
          ``_fetch_legal_context_async`` instead.
        """
        import asyncio

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # No running loop in this thread: safe to run our own.
        else:
            logger.debug(
                "Skipping sync legal-context fetch for %s: event loop running",
                requirement.id,
            )
            return []

        # A dedicated loop (rather than get_event_loop(), which is deprecated
        # without a current loop and raises in non-main threads) leaves the
        # thread's current-loop setting untouched.
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(
                self._fetch_legal_context_async(requirement, regulation)
            )
        finally:
            loop.close()