# mypy: disable-error-code="arg-type,assignment,no-any-return,union-attr"
"""
Company Profile service — Stammdaten (company master data) CRUD with raw-SQL
persistence and an audit log.

Phase 1 Step 4: extracted from ``compliance.api.company_profile_routes``.

Unusual for this repo: persistence uses raw SQL via ``sqlalchemy.text()``
rather than ORM models, because the table has ~45 columns with complex jsonb
coercion and there is no SQLAlchemy model for it.
"""

import json
import logging
from typing import Any, Optional

from sqlalchemy import text
from sqlalchemy.orm import Session

from compliance.domain import NotFoundError, ValidationError
from compliance.schemas.company_profile import (
    AuditEntryResponse,
    AuditListResponse,
    CompanyProfileRequest,
    CompanyProfileResponse,
)

logger = logging.getLogger(__name__)


# ============================================================================
# SQL column list — keep in sync with SELECT/INSERT
# ============================================================================

# Order is significant: rows fetched with ``SELECT {_BASE_COLUMNS}`` are mapped
# back to names positionally (zip) in row_to_response, so only append here.
_BASE_COLUMNS_LIST = [
    "id", "tenant_id",
    "company_name", "legal_form", "industry", "founded_year",
    "business_model", "offerings",
    "company_size", "employee_count", "annual_revenue",
    "headquarters_country", "headquarters_city",
    "has_international_locations", "international_countries",
    "target_markets", "primary_jurisdiction",
    "is_data_controller", "is_data_processor", "uses_ai", "ai_use_cases",
    "dpo_name", "dpo_email", "legal_contact_name", "legal_contact_email",
    "machine_builder",
    "is_complete", "completed_at", "created_at", "updated_at",
    "repos", "document_sources", "processing_systems", "ai_systems",
    "technical_contacts",
    "subject_to_nis2", "subject_to_ai_act", "subject_to_iso27001",
    "supervisory_authority", "review_cycle_months",
    "project_id", "offering_urls", "headquarters_country_other",
    "headquarters_street", "headquarters_zip", "headquarters_state",
]

# Comma-separated projection for SELECT statements.
_BASE_COLUMNS = ", ".join(_BASE_COLUMNS_LIST)

# Per-field defaults and type coercions for row_to_response.
# Each entry maps ``column -> (default, coercion)``:
#   "STR"          -> always ``str(value)``
#   "STR_OR_NONE"  -> ``str(value)`` when truthy, otherwise ``None``
#   list / dict    -> value kept only if it is an instance of that type,
#                     otherwise a fresh copy of ``default``
#   None           -> booleans: value kept unless NULL; other columns: value
#                     kept when truthy, otherwise ``default`` (or the raw value
#                     when ``default`` is None)
_FIELD_DEFAULTS: dict[str, tuple[Any, Any]] = {
    "id": (None, "STR"),
    "tenant_id": (None, None),
    "company_name": ("", None),
    "legal_form": ("GmbH", None),
    "industry": ("", None),
    "founded_year": (None, None),
    "business_model": ("B2B", None),
    "offerings": ([], list),
    "offering_urls": ({}, dict),
    "company_size": ("small", None),
    "employee_count": ("1-9", None),
    "annual_revenue": ("< 2 Mio", None),
    "headquarters_country": ("DE", None),
    "headquarters_country_other": ("", None),
    "headquarters_street": ("", None),
    "headquarters_zip": ("", None),
    "headquarters_city": ("", None),
    "headquarters_state": ("", None),
    "has_international_locations": (False, None),
    "international_countries": ([], list),
    "target_markets": (["DE"], list),
    "primary_jurisdiction": ("DE", None),
    "is_data_controller": (True, None),
    "is_data_processor": (False, None),
    "uses_ai": (False, None),
    "ai_use_cases": ([], list),
    "dpo_name": (None, None),
    "dpo_email": (None, None),
    "legal_contact_name": (None, None),
    "legal_contact_email": (None, None),
    "machine_builder": (None, dict),
    "is_complete": (False, None),
    "completed_at": (None, "STR_OR_NONE"),
    "created_at": (None, "STR"),
    "updated_at": (None, "STR"),
    "repos": ([], list),
    "document_sources": ([], list),
    "processing_systems": ([], list),
    "ai_systems": ([], list),
    "technical_contacts": ([], list),
    "subject_to_nis2": (False, None),
    "subject_to_ai_act": (False, None),
    "subject_to_iso27001": (False, None),
    "supervisory_authority": (None, None),
    "review_cycle_months": (12, None),
    "project_id": (None, "STR_OR_NONE"),
}

# Columns stored as jsonb: values are JSON-encoded before binding and cast
# server-side.
_JSONB_FIELDS = {
    "offerings",
    "offering_urls",
    "international_countries",
    "target_markets",
    "ai_use_cases",
    "machine_builder",
    "repos",
    "document_sources",
    "processing_systems",
    "ai_systems",
    "technical_contacts",
}

# Columns PATCH must never write directly: identity/scoping keys and
# server-maintained timestamps.
_READ_ONLY_COLUMNS = frozenset(
    {"id", "tenant_id", "project_id", "created_at", "updated_at", "completed_at"}
)


def _where_clause() -> str:
    """Return the WHERE clause selecting one profile by tenant + project.

    ``IS NOT DISTINCT FROM`` lets a ``None`` project id match rows whose
    project_id is NULL (plain ``=`` never matches NULL). Expects the bind
    parameters ``:tid`` and ``:pid``.
    """
    return "tenant_id = :tid AND project_id IS NOT DISTINCT FROM :pid"


def _fresh_default(default: Any) -> Any:
    """Return ``default``, copying list/dict defaults.

    Prevents responses from sharing (and mutating) the module-level objects
    stored in ``_FIELD_DEFAULTS``.
    """
    if isinstance(default, (list, dict)):
        return default.copy()
    return default


def row_to_response(row: Any) -> CompanyProfileResponse:
    """Convert a DB row to the response model using zip-based column mapping.

    ``row`` must yield values in ``_BASE_COLUMNS_LIST`` order, as produced by
    ``SELECT {_BASE_COLUMNS} ...``. Each value is coerced per
    ``_FIELD_DEFAULTS``.
    """
    raw = dict(zip(_BASE_COLUMNS_LIST, row))
    coerced: dict[str, Any] = {}
    for col in _BASE_COLUMNS_LIST:
        default, expected_type = _FIELD_DEFAULTS[col]
        value = raw[col]
        if expected_type == "STR":
            coerced[col] = str(value)
        elif expected_type == "STR_OR_NONE":
            coerced[col] = str(value) if value else None
        elif expected_type is not None:
            coerced[col] = (
                value if isinstance(value, expected_type) else _fresh_default(default)
            )
        elif isinstance(default, bool):
            # ``value or default`` would turn a stored False into True for
            # True-default flags such as is_data_controller; only NULL falls back.
            coerced[col] = value if value is not None else default
        elif default is None:
            coerced[col] = value
        else:
            coerced[col] = value or default
    return CompanyProfileResponse(**coerced)


def log_audit(
    db: Session,
    tenant_id: str,
    action: str,
    changed_fields: Optional[dict[str, Any]],
    changed_by: Optional[str],
    project_id: Optional[str] = None,
) -> None:
    """Write an audit log entry. Failures only log a warning and never raise.

    The INSERT runs inside a SAVEPOINT (``db.begin_nested()``). On PostgreSQL
    a failed statement aborts the whole enclosing transaction. Without the
    savepoint, a swallowed audit failure would make the caller's later
    ``commit()`` silently discard the profile change being audited.

    ``CAST(:fields AS jsonb)`` is used instead of ``:fields::jsonb`` because a
    ``::`` cast immediately after a ``:name`` bind parameter collides with
    ``text()``'s bind-parameter syntax.

    Does not commit; the caller owns the outer transaction.
    """
    try:
        with db.begin_nested():
            db.execute(
                text(
                    "INSERT INTO compliance_company_profile_audit "
                    "(tenant_id, project_id, action, changed_fields, changed_by) "
                    "VALUES (:tenant_id, :project_id, :action, "
                    "CAST(:fields AS jsonb), :changed_by)"
                ),
                {
                    "tenant_id": tenant_id,
                    "project_id": project_id,
                    "action": action,
                    "fields": json.dumps(changed_fields) if changed_fields else None,
                    "changed_by": changed_by,
                },
            )
    except Exception as exc:
        # Deliberately best-effort: auditing must not break the main operation.
        logger.warning("Failed to write audit log: %s", exc)


# ============================================================================
# Service
# ============================================================================


class CompanyProfileService:
    """Business logic for company profile persistence + audit.

    A profile is addressed by ``(tid, pid)``: tenant id plus an optional
    project id. ``None`` matches rows whose project_id is NULL. The commands
    ``upsert``, ``delete`` and ``patch`` commit the session themselves.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _fetch_row(self, tid: str, pid: Optional[str]) -> Any:
        """Return the full profile row (``_BASE_COLUMNS`` order), or None."""
        return self.db.execute(
            text(
                f"SELECT {_BASE_COLUMNS} FROM compliance_company_profiles "
                f"WHERE {_where_clause()}"
            ),
            {"tid": tid, "pid": pid},
        ).fetchone()

    def _exists(self, tid: str, pid: Optional[str]) -> bool:
        """Return True if a profile row exists for ``(tid, pid)``."""
        return (
            self.db.execute(
                text(f"SELECT id FROM compliance_company_profiles WHERE {_where_clause()}"),
                {"tid": tid, "pid": pid},
            ).fetchone()
            is not None
        )

    def _require_row(self, tid: str, pid: Optional[str]) -> Any:
        """Return the profile row, raising NotFoundError if it is missing."""
        row = self._fetch_row(tid, pid)
        if not row:
            raise NotFoundError("Company profile not found")
        return row

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def get(self, tid: str, pid: Optional[str]) -> CompanyProfileResponse:
        """Return the profile for ``(tid, pid)``.

        Raises:
            NotFoundError: no profile exists.
        """
        return row_to_response(self._require_row(tid, pid))

    def template_context(self, tid: str, pid: Optional[str]) -> dict[str, Any]:
        """Build a flat dict of profile values (presumably for document templates).

        Optional contact fields are rendered as ``""`` instead of None. The
        dict also includes derived counts and flags for AI/processing systems.

        Raises:
            NotFoundError: no profile exists yet.
        """
        row = self._fetch_row(tid, pid)
        if not row:
            raise NotFoundError("Company profile not found — fill Stammdaten first")
        resp = row_to_response(row)
        return {
            "company_name": resp.company_name,
            "legal_form": resp.legal_form,
            "industry": resp.industry,
            "business_model": resp.business_model,
            "company_size": resp.company_size,
            "employee_count": resp.employee_count,
            "headquarters_country": resp.headquarters_country,
            "headquarters_city": resp.headquarters_city,
            "primary_jurisdiction": resp.primary_jurisdiction,
            "is_data_controller": resp.is_data_controller,
            "is_data_processor": resp.is_data_processor,
            "uses_ai": resp.uses_ai,
            "dpo_name": resp.dpo_name or "",
            "dpo_email": resp.dpo_email or "",
            "legal_contact_name": resp.legal_contact_name or "",
            "legal_contact_email": resp.legal_contact_email or "",
            "supervisory_authority": resp.supervisory_authority or "",
            "review_cycle_months": resp.review_cycle_months,
            "subject_to_nis2": resp.subject_to_nis2,
            "subject_to_ai_act": resp.subject_to_ai_act,
            "subject_to_iso27001": resp.subject_to_iso27001,
            "offerings": resp.offerings,
            "target_markets": resp.target_markets,
            "international_countries": resp.international_countries,
            "ai_use_cases": resp.ai_use_cases,
            "repos": resp.repos,
            "document_sources": resp.document_sources,
            "processing_systems": resp.processing_systems,
            "ai_systems": resp.ai_systems,
            "technical_contacts": resp.technical_contacts,
            "has_ai_systems": len(resp.ai_systems) > 0,
            "processing_system_count": len(resp.processing_systems),
            "ai_system_count": len(resp.ai_systems),
            "is_complete": resp.is_complete,
        }

    def audit_log(self, tid: str, pid: Optional[str]) -> AuditListResponse:
        """Return the 100 most recent audit entries, newest first.

        ``total`` is the number of entries returned (at most 100), not the
        overall count. ``changed_fields`` is None unless stored as a JSON object.
        """
        result = self.db.execute(
            text(
                "SELECT id, action, changed_fields, changed_by, created_at "
                "FROM compliance_company_profile_audit "
                f"WHERE {_where_clause()} "
                "ORDER BY created_at DESC LIMIT 100"
            ),
            {"tid": tid, "pid": pid},
        )
        entries = [
            AuditEntryResponse(
                id=str(r[0]),
                action=r[1],
                changed_fields=r[2] if isinstance(r[2], dict) else None,
                changed_by=r[3],
                created_at=str(r[4]),
            )
            for r in result.fetchall()
        ]
        return AuditListResponse(entries=entries, total=len(entries))

    # ------------------------------------------------------------------
    # Commands
    # ------------------------------------------------------------------

    def upsert(
        self, tid: str, pid: Optional[str], profile: CompanyProfileRequest
    ) -> CompanyProfileResponse:
        """Create or fully overwrite the profile, audit it, commit, and re-read.

        ``completed_at`` is set to NOW() when ``profile.is_complete`` is true,
        and to NULL otherwise, on both insert and update.

        Raises:
            NotFoundError: the row cannot be read back after the commit.
        """
        # Imported lazily — presumably to avoid an import cycle; verify.
        from compliance.services._company_profile_sql import (
            build_upsert_params,
            execute_insert,
            execute_update,
        )

        existing = self._exists(tid, pid)
        action = "update" if existing else "create"
        params = build_upsert_params(tid, pid, profile)
        completed_at_sql = "NOW()" if profile.is_complete else "NULL"

        if existing:
            execute_update(self.db, params, completed_at_sql, _where_clause())
        else:
            execute_insert(self.db, params, completed_at_sql)

        log_audit(self.db, tid, action, profile.model_dump(), None, pid)
        self.db.commit()
        return row_to_response(self._require_row(tid, pid))

    def delete(self, tid: str, pid: Optional[str]) -> dict[str, Any]:
        """Delete the profile, audit the deletion, and commit.

        Raises:
            NotFoundError: no profile exists.
        """
        if not self._exists(tid, pid):
            raise NotFoundError("Company profile not found")
        self.db.execute(
            text(f"DELETE FROM compliance_company_profiles WHERE {_where_clause()}"),
            {"tid": tid, "pid": pid},
        )
        log_audit(self.db, tid, "delete", None, None, pid)
        self.db.commit()
        return {"success": True, "message": "Company profile deleted"}

    def patch(
        self, tid: str, pid: Optional[str], updates: dict[str, Any]
    ) -> CompanyProfileResponse:
        """Partially update an existing profile, audit it, commit, and re-read.

        Only keys that are known, writable columns are applied. Unknown and
        read-only keys are ignored, so caller-supplied keys never reach the
        SQL text unvetted. JSONB columns are JSON-encoded and cast
        server-side. Patching ``is_complete`` sets or clears ``completed_at``
        the same way ``upsert`` does. The audit entry records only the
        applied fields.

        Raises:
            NotFoundError: no profile exists.
            ValidationError: ``updates`` contains no applicable field.
        """
        if not self._exists(tid, pid):
            raise NotFoundError("Company profile not found")

        allowed = set(_BASE_COLUMNS_LIST) - _READ_ONLY_COLUMNS
        applied = {key: value for key, value in updates.items() if key in allowed}
        if not applied:
            raise ValidationError("No valid fields to update")

        set_parts: list[str] = []
        params: dict[str, Any] = {"tid": tid, "pid": pid}
        for key, value in applied.items():
            param_name = f"p_{key}"
            if key in _JSONB_FIELDS:
                # CAST(...) rather than ``::jsonb``: a ``::`` directly after a
                # ``:name`` bind parameter collides with text()'s bind syntax.
                set_parts.append(f"{key} = CAST(:{param_name} AS jsonb)")
                params[param_name] = json.dumps(value) if value is not None else None
            else:
                set_parts.append(f"{key} = :{param_name}")
                params[param_name] = value

        if "is_complete" in applied:
            # Keep completed_at consistent with upsert's handling.
            set_parts.append(
                "completed_at = NOW()" if applied["is_complete"] else "completed_at = NULL"
            )
        set_parts.append("updated_at = NOW()")

        self.db.execute(
            text(
                f"UPDATE compliance_company_profiles SET {', '.join(set_parts)} "
                f"WHERE {_where_clause()}"
            ),
            params,
        )
        log_audit(self.db, tid, "patch", applied, None, pid)
        self.db.commit()
        return row_to_response(self._require_row(tid, pid))