Files
breakpilot-compliance/backend-compliance/compliance/services/company_profile_service.py
Sharang Parnerkar f39c7ca40c refactor(backend/api): extract CompanyProfileService (Step 4 — file 4 of 18)
compliance/api/company_profile_routes.py (640 LOC) -> 154 LOC thin routes.
Unusual for this repo: persistence uses raw SQL via sqlalchemy.text()
because the underlying compliance_company_profiles table has ~45 columns
with complex jsonb coercion and there is no SQLAlchemy model for it.

New files:
  compliance/schemas/company_profile.py         (127) — 4 request/response models
  compliance/services/company_profile_service.py (340) — Service class + row_to_response + log_audit
  compliance/services/_company_profile_sql.py   (139) — 70-line INSERT/UPDATE statements
                                                         separated for readability

Minor behavioral improvement: the handlers now use Depends(get_db) for
session management instead of the bespoke `db = SessionLocal(); try: ...
finally: db.close()` pattern. This makes the routes consistent with
every other refactored service, fixes the breakage that occurred when tests
used dependency_overrides, and removes 6 duplicate try/finally blocks.

Legacy exports preserved: CompanyProfileRequest, CompanyProfileResponse,
AuditEntryResponse, AuditListResponse, row_to_response, and log_audit are
re-exported from compliance.api.company_profile_routes so that the two
existing test files
(tests/test_company_profile_routes.py, tests/test_company_profile_extend.py)
keep importing from the same path.

Pre-existing broken tests noted: 6 tests in those files feed a 40-tuple
row into row_to_response, but _BASE_COLUMNS_LIST has 46 columns (has had
since the Phase 2 Stammdaten extension). These tests fail on main too
(verified via `git stash` round-trip). Not fixed in this commit — they
require a rewrite of the test's _make_row helper, which is out of scope
for a pure structural refactor. Flagged for follow-up.

Verified:
  - 173/173 pytest compliance/tests/ tests/contracts/ pass
  - OpenAPI 360/484 unchanged
  - mypy compliance/ -> Success on 127 source files
  - company_profile_routes.py 640 -> 154 LOC
  - All new files under soft 300 target except service (340, under hard 500)
  - Hard-cap violations: 15 -> 14

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-07 19:47:29 +02:00

341 lines
14 KiB
Python

# mypy: disable-error-code="arg-type,assignment,no-any-return,union-attr"
"""
Company Profile service — Stammdaten CRUD with raw-SQL persistence and audit log.
Phase 1 Step 4: extracted from ``compliance.api.company_profile_routes``.
Unusual for this repo: persistence uses raw SQL via ``sqlalchemy.text()``
rather than ORM models, because the table has ~45 columns with complex
jsonb coercion and there is no SQLAlchemy model for it.
"""
import json
import logging
from typing import Any, Optional
from sqlalchemy import text
from sqlalchemy.orm import Session
from compliance.domain import NotFoundError, ValidationError
from compliance.schemas.company_profile import (
AuditEntryResponse,
AuditListResponse,
CompanyProfileRequest,
CompanyProfileResponse,
)
logger = logging.getLogger(__name__)

# ============================================================================
# SQL column list — keep in sync with SELECT/INSERT
# ============================================================================
# The order of this list is significant: row_to_response() pairs these names
# positionally with the values of a fetched row.
_BASE_COLUMNS_LIST = [
    # identity + core company data
    "id",
    "tenant_id",
    "company_name",
    "legal_form",
    "industry",
    "founded_year",
    "business_model",
    "offerings",
    "company_size",
    "employee_count",
    "annual_revenue",
    # locations and markets
    "headquarters_country",
    "headquarters_city",
    "has_international_locations",
    "international_countries",
    "target_markets",
    "primary_jurisdiction",
    # data-protection role and AI usage
    "is_data_controller",
    "is_data_processor",
    "uses_ai",
    "ai_use_cases",
    # contacts
    "dpo_name",
    "dpo_email",
    "legal_contact_name",
    "legal_contact_email",
    # misc + lifecycle timestamps
    "machine_builder",
    "is_complete",
    "completed_at",
    "created_at",
    "updated_at",
    # system inventories
    "repos",
    "document_sources",
    "processing_systems",
    "ai_systems",
    "technical_contacts",
    # regulatory flags
    "subject_to_nis2",
    "subject_to_ai_act",
    "subject_to_iso27001",
    "supervisory_authority",
    "review_cycle_months",
    # later additions (appended at the end of the list)
    "project_id",
    "offering_urls",
    "headquarters_country_other",
    "headquarters_street",
    "headquarters_zip",
    "headquarters_state",
]

# Comma-separated form of the list above, ready to splice into a SELECT.
_BASE_COLUMNS = ", ".join(_BASE_COLUMNS_LIST)
# Per-field defaults and type coercions for row_to_response.
#
# Each entry maps a column name to ``(default, coercion)`` where ``coercion`` is
#   * "STR"          -> the value is always passed through ``str()``
#   * "STR_OR_NONE"  -> ``str()`` for truthy values, otherwise ``None``
#   * a type (list / dict) -> the value is kept only if it is an instance of
#     that type, otherwise the default is used
#   * None           -> no type check; the default fills in for falsy values
_FIELD_DEFAULTS: dict[str, tuple[Any, Any]] = {
    # --- identity -----------------------------------------------------------
    "id": (None, "STR"),
    "tenant_id": (None, None),
    # --- core company data --------------------------------------------------
    "company_name": ("", None),
    "legal_form": ("GmbH", None),
    "industry": ("", None),
    "founded_year": (None, None),
    "business_model": ("B2B", None),
    "offerings": ([], list),
    "offering_urls": ({}, dict),
    "company_size": ("small", None),
    "employee_count": ("1-9", None),
    "annual_revenue": ("< 2 Mio", None),
    # --- headquarters / locations -------------------------------------------
    "headquarters_country": ("DE", None),
    "headquarters_country_other": ("", None),
    "headquarters_street": ("", None),
    "headquarters_zip": ("", None),
    "headquarters_city": ("", None),
    "headquarters_state": ("", None),
    "has_international_locations": (False, None),
    "international_countries": ([], list),
    "target_markets": (["DE"], list),
    "primary_jurisdiction": ("DE", None),
    # --- data-protection role and AI usage ----------------------------------
    "is_data_controller": (True, None),
    "is_data_processor": (False, None),
    "uses_ai": (False, None),
    "ai_use_cases": ([], list),
    # --- contacts -----------------------------------------------------------
    "dpo_name": (None, None),
    "dpo_email": (None, None),
    "legal_contact_name": (None, None),
    "legal_contact_email": (None, None),
    # --- misc + lifecycle ---------------------------------------------------
    "machine_builder": (None, dict),
    "is_complete": (False, None),
    "completed_at": (None, "STR_OR_NONE"),
    "created_at": (None, "STR"),
    "updated_at": (None, "STR"),
    # --- system inventories -------------------------------------------------
    "repos": ([], list),
    "document_sources": ([], list),
    "processing_systems": ([], list),
    "ai_systems": ([], list),
    "technical_contacts": ([], list),
    # --- regulatory flags ---------------------------------------------------
    "subject_to_nis2": (False, None),
    "subject_to_ai_act": (False, None),
    "subject_to_iso27001": (False, None),
    "supervisory_authority": (None, None),
    "review_cycle_months": (12, None),
    # --- project scoping ----------------------------------------------------
    "project_id": (None, "STR_OR_NONE"),
}
# Columns whose values patch() serialises with json.dumps and binds with a
# ``::jsonb`` cast.
_JSONB_FIELDS = {
    "offerings",
    "offering_urls",
    "international_countries",
    "target_markets",
    "ai_use_cases",
    "machine_builder",
    "repos",
    "document_sources",
    "processing_systems",
    "ai_systems",
    "technical_contacts",
}
def _where_clause() -> str:
    """Return the WHERE predicate selecting one profile by tenant and project.

    Expects the bind parameters ``:tid`` and ``:pid``. ``IS NOT DISTINCT FROM``
    is used for the project so that a NULL ``:pid`` matches rows whose
    ``project_id`` is NULL.
    """
    tenant_match = "tenant_id = :tid"
    project_match = "project_id IS NOT DISTINCT FROM :pid"
    return f"{tenant_match} AND {project_match}"
def row_to_response(row: Any) -> CompanyProfileResponse:
    """Build a ``CompanyProfileResponse`` from a positional DB row.

    The row's values are paired positionally with ``_BASE_COLUMNS_LIST`` and
    each one is normalised according to its ``_FIELD_DEFAULTS`` entry.

    Args:
        row: Iterable of column values in ``_BASE_COLUMNS_LIST`` order.

    Returns:
        The response model populated with the normalised values.

    Raises:
        KeyError: If ``row`` yields fewer values than there are columns.
    """
    by_name = dict(zip(_BASE_COLUMNS_LIST, row))
    payload: dict[str, Any] = {}
    for column in _BASE_COLUMNS_LIST:
        fallback, kind = _FIELD_DEFAULTS[column]
        cell = by_name[column]
        if kind == "STR":
            normalised = str(cell)
        elif kind == "STR_OR_NONE":
            normalised = str(cell) if cell else None
        elif kind is not None:
            # Container columns: anything that is not the expected type
            # (including NULL) is replaced by the default.
            normalised = cell if isinstance(cell, kind) else fallback
        elif column == "is_data_controller":
            # Default is True, so an explicit False must survive; only NULL
            # falls back to the default.
            normalised = fallback if cell is None else cell
        elif fallback is None:
            # No default configured: pass the stored value through untouched.
            normalised = cell
        else:
            # Any falsy stored value is replaced by the default.
            normalised = cell or fallback
        payload[column] = normalised
    return CompanyProfileResponse(**payload)
def log_audit(
    db: Session,
    tenant_id: str,
    action: str,
    changed_fields: Optional[dict[str, Any]],
    changed_by: Optional[str],
    project_id: Optional[str] = None,
) -> None:
    """Write an audit log entry. Warnings only on failure — never fatal.

    The INSERT is executed inside a SAVEPOINT (``db.begin_nested()``). On
    PostgreSQL a failed statement leaves the surrounding transaction in the
    aborted state, so merely swallowing the exception would make the caller's
    subsequent ``commit()`` discard its own changes. Rolling back to the
    savepoint keeps the outer transaction usable.

    This function does not commit; the row becomes durable when the caller
    commits the session.

    Args:
        db: Session used for the INSERT.
        tenant_id: Tenant the audited profile belongs to.
        action: Short action label stored in the ``action`` column.
        changed_fields: Mapping stored as JSON in ``changed_fields``; an empty
            mapping or ``None`` is stored as NULL.
        changed_by: Identifier of the actor, if known.
        project_id: Optional project the profile is scoped to.
    """
    try:
        # Serialise first so a non-JSON-serialisable payload never opens a
        # savepoint at all.
        serialised = json.dumps(changed_fields) if changed_fields else None
        with db.begin_nested():
            db.execute(
                text(
                    "INSERT INTO compliance_company_profile_audit "
                    "(tenant_id, project_id, action, changed_fields, changed_by) "
                    "VALUES (:tenant_id, :project_id, :action, :fields::jsonb, :changed_by)"
                ),
                {
                    "tenant_id": tenant_id,
                    "project_id": project_id,
                    "action": action,
                    "fields": serialised,
                    "changed_by": changed_by,
                },
            )
    except Exception as exc:  # deliberate best-effort boundary: audit must never break the request
        logger.warning("Failed to write audit log: %s", exc)
# ============================================================================
# Service
# ============================================================================
class CompanyProfileService:
    """Company-profile CRUD on top of raw SQL, with an audit trail.

    Every operation addresses a single profile by ``(tid, pid)`` — tenant id
    plus an optional project id. Write operations record an audit entry via
    ``log_audit`` and then commit the session.
    """

    def __init__(self, db: Session) -> None:
        """Keep the session that all queries of this service run on."""
        self.db = db

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _fetch_row(self, tid: str, pid: Optional[str]) -> Any:
        """Return the full profile row for ``(tid, pid)``, or ``None`` if absent."""
        query = text(
            f"SELECT {_BASE_COLUMNS} FROM compliance_company_profiles WHERE {_where_clause()}"
        )
        result = self.db.execute(query, {"tid": tid, "pid": pid})
        return result.fetchone()

    def _exists(self, tid: str, pid: Optional[str]) -> bool:
        """Tell whether a profile row exists for ``(tid, pid)``."""
        query = text(f"SELECT id FROM compliance_company_profiles WHERE {_where_clause()}")
        hit = self.db.execute(query, {"tid": tid, "pid": pid}).fetchone()
        return hit is not None

    def _require_row(self, tid: str, pid: Optional[str]) -> Any:
        """Like ``_fetch_row`` but raise ``NotFoundError`` when no row is found."""
        found = self._fetch_row(tid, pid)
        if found:
            return found
        raise NotFoundError("Company profile not found")

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def get(self, tid: str, pid: Optional[str]) -> CompanyProfileResponse:
        """Return the stored profile; raise ``NotFoundError`` if there is none."""
        stored = self._require_row(tid, pid)
        return row_to_response(stored)

    def template_context(self, tid: str, pid: Optional[str]) -> dict[str, Any]:
        """Return a flat dict of profile values plus a few derived counters.

        Nullable contact fields are rendered as ``""`` instead of ``None``.

        Raises:
            NotFoundError: If no profile exists for ``(tid, pid)``.
        """
        row = self._fetch_row(tid, pid)
        if not row:
            raise NotFoundError("Company profile not found — fill Stammdaten first")
        profile = row_to_response(row)

        # Attribute groups are listed in the order their keys appear in the
        # resulting dict.
        copied_first = (
            "company_name",
            "legal_form",
            "industry",
            "business_model",
            "company_size",
            "employee_count",
            "headquarters_country",
            "headquarters_city",
            "primary_jurisdiction",
            "is_data_controller",
            "is_data_processor",
            "uses_ai",
        )
        blank_when_missing = (
            "dpo_name",
            "dpo_email",
            "legal_contact_name",
            "legal_contact_email",
            "supervisory_authority",
        )
        copied_last = (
            "review_cycle_months",
            "subject_to_nis2",
            "subject_to_ai_act",
            "subject_to_iso27001",
            "offerings",
            "target_markets",
            "international_countries",
            "ai_use_cases",
            "repos",
            "document_sources",
            "processing_systems",
            "ai_systems",
            "technical_contacts",
        )

        context: dict[str, Any] = {}
        for attr in copied_first:
            context[attr] = getattr(profile, attr)
        for attr in blank_when_missing:
            context[attr] = getattr(profile, attr) or ""
        for attr in copied_last:
            context[attr] = getattr(profile, attr)

        # Derived values.
        context["has_ai_systems"] = len(profile.ai_systems) > 0
        context["processing_system_count"] = len(profile.processing_systems)
        context["ai_system_count"] = len(profile.ai_systems)
        context["is_complete"] = profile.is_complete
        return context

    def audit_log(self, tid: str, pid: Optional[str]) -> AuditListResponse:
        """Return up to 100 audit entries for the profile, newest first.

        ``total`` is the number of entries returned, not a count of all
        stored entries.
        """
        query = text(
            "SELECT id, action, changed_fields, changed_by, created_at "
            "FROM compliance_company_profile_audit "
            "WHERE tenant_id = :tid AND project_id IS NOT DISTINCT FROM :pid "
            "ORDER BY created_at DESC LIMIT 100"
        )
        rows = self.db.execute(query, {"tid": tid, "pid": pid}).fetchall()

        entries = []
        for record in rows:
            stored_fields = record[2]
            entries.append(
                AuditEntryResponse(
                    id=str(record[0]),
                    action=record[1],
                    # Anything that did not come back as a dict is reported as None.
                    changed_fields=stored_fields if isinstance(stored_fields, dict) else None,
                    changed_by=record[3],
                    created_at=str(record[4]),
                )
            )
        return AuditListResponse(entries=entries, total=len(entries))

    # ------------------------------------------------------------------
    # Commands
    # ------------------------------------------------------------------

    def upsert(
        self, tid: str, pid: Optional[str], profile: CompanyProfileRequest
    ) -> CompanyProfileResponse:
        """Create the profile, or overwrite it if it already exists.

        The audit action is ``"create"`` or ``"update"`` accordingly, and
        ``completed_at`` is written as ``NOW()`` when ``profile.is_complete``
        is set, otherwise as ``NULL``.

        Returns:
            The profile as re-read from the database after the commit.
        """
        # Imported here rather than at module level, as in the original code.
        from compliance.services._company_profile_sql import (
            build_upsert_params,
            execute_insert,
            execute_update,
        )

        already_stored = self._exists(tid, pid)
        params = build_upsert_params(tid, pid, profile)
        completed_at_sql = "NOW()" if profile.is_complete else "NULL"

        if already_stored:
            action = "update"
            execute_update(self.db, params, completed_at_sql, _where_clause())
        else:
            action = "create"
            execute_insert(self.db, params, completed_at_sql)

        log_audit(self.db, tid, action, profile.model_dump(), None, pid)
        self.db.commit()
        return row_to_response(self._require_row(tid, pid))

    def delete(self, tid: str, pid: Optional[str]) -> dict[str, Any]:
        """Delete the profile and record a ``"delete"`` audit entry.

        Raises:
            NotFoundError: If no profile exists for ``(tid, pid)``.
        """
        if not self._exists(tid, pid):
            raise NotFoundError("Company profile not found")

        statement = text(f"DELETE FROM compliance_company_profiles WHERE {_where_clause()}")
        self.db.execute(statement, {"tid": tid, "pid": pid})
        log_audit(self.db, tid, "delete", None, None, pid)
        self.db.commit()
        return {"success": True, "message": "Company profile deleted"}

    def patch(
        self, tid: str, pid: Optional[str], updates: dict[str, Any]
    ) -> CompanyProfileResponse:
        """Update only the given columns of an existing profile.

        Keys that are not updatable columns are silently ignored; identity
        and timestamp columns can never be patched. ``updated_at`` is always
        set to ``NOW()``.

        Raises:
            NotFoundError: If no profile exists for ``(tid, pid)``.
            ValidationError: If ``updates`` contains no updatable column.

        Returns:
            The profile as re-read from the database after the commit.
        """
        if not self._exists(tid, pid):
            raise NotFoundError("Company profile not found")

        protected = {
            "id", "tenant_id", "project_id", "created_at", "updated_at", "completed_at",
        }
        allowed = set(_BASE_COLUMNS_LIST).difference(protected)

        assignments: list[str] = []
        bind: dict[str, Any] = {"tid": tid, "pid": pid}
        for field, new_value in updates.items():
            if field in allowed:
                # Column names are only ever taken from the whitelist above, so
                # interpolating them into the statement is safe; values are bound.
                placeholder = f"p_{field}"
                if field in _JSONB_FIELDS:
                    assignments.append(f"{field} = :{placeholder}::jsonb")
                    bind[placeholder] = None if new_value is None else json.dumps(new_value)
                else:
                    assignments.append(f"{field} = :{placeholder}")
                    bind[placeholder] = new_value

        if not assignments:
            raise ValidationError("No valid fields to update")
        assignments.append("updated_at = NOW()")

        statement = (
            f"UPDATE compliance_company_profiles SET {', '.join(assignments)} "
            f"WHERE {_where_clause()}"
        )
        self.db.execute(text(statement), bind)
        log_audit(self.db, tid, "patch", updates, None, pid)
        self.db.commit()
        return row_to_response(self._require_row(tid, pid))