"""Tests for VVT tenant isolation (Phase 1: Multi-Tenancy Fix). Verifies that: - tenant_utils correctly validates and resolves tenant IDs - VVT routes filter data by tenant_id - One tenant cannot see another tenant's data - "default" tenant_id is rejected """ import pytest import uuid from unittest.mock import MagicMock, AsyncMock, patch from datetime import datetime from fastapi import HTTPException from fastapi.testclient import TestClient from compliance.api.tenant_utils import get_tenant_id, _validate_tenant_id # ============================================================================= # tenant_utils unit tests # ============================================================================= TENANT_A = "9282a473-5c95-4b3a-bf78-0ecc0ec71d3e" TENANT_B = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" class TestValidateTenantId: def test_valid_uuid(self): assert _validate_tenant_id(TENANT_A) == TENANT_A def test_valid_uuid_uppercase(self): upper = TENANT_A.upper() assert _validate_tenant_id(upper) == upper def test_reject_default_string(self): with pytest.raises(HTTPException) as exc_info: _validate_tenant_id("default") assert exc_info.value.status_code == 400 assert "default" in str(exc_info.value.detail) def test_reject_empty_string(self): with pytest.raises(HTTPException) as exc_info: _validate_tenant_id("") assert exc_info.value.status_code == 400 def test_reject_random_string(self): with pytest.raises(HTTPException) as exc_info: _validate_tenant_id("my-tenant") assert exc_info.value.status_code == 400 def test_reject_partial_uuid(self): with pytest.raises(HTTPException) as exc_info: _validate_tenant_id("9282a473-5c95-4b3a") assert exc_info.value.status_code == 400 class TestGetTenantId: @pytest.mark.asyncio async def test_header_takes_precedence(self): result = await get_tenant_id(x_tenant_id=TENANT_A, tenant_id=TENANT_B) assert result == TENANT_A @pytest.mark.asyncio async def test_query_param_fallback(self): result = await get_tenant_id(x_tenant_id=None, tenant_id=TENANT_B) assert result == TENANT_B @pytest.mark.asyncio async def test_env_fallback(self): result = await get_tenant_id(x_tenant_id=None, tenant_id=None) # Falls back to ENV default which is the well-known dev UUID assert result == TENANT_A @pytest.mark.asyncio async def test_reject_default_via_header(self): with pytest.raises(HTTPException): await get_tenant_id(x_tenant_id="default", tenant_id=None) # ============================================================================= # VVT Model tests — tenant_id column present # ============================================================================= class TestVVTModelsHaveTenantId: def test_organization_has_tenant_id(self): from compliance.db.vvt_models import VVTOrganizationDB assert hasattr(VVTOrganizationDB, 'tenant_id') col = VVTOrganizationDB.__table__.columns['tenant_id'] assert col.nullable is False def test_activity_has_tenant_id(self): from compliance.db.vvt_models import VVTActivityDB assert hasattr(VVTActivityDB, 'tenant_id') col = VVTActivityDB.__table__.columns['tenant_id'] assert col.nullable is False def test_audit_log_has_tenant_id(self): from compliance.db.vvt_models import VVTAuditLogDB assert hasattr(VVTAuditLogDB, 'tenant_id') col = VVTAuditLogDB.__table__.columns['tenant_id'] assert col.nullable is False def test_activity_no_global_unique_vvt_id(self): """vvt_id should NOT have a global unique constraint anymore.""" from compliance.db.vvt_models import VVTActivityDB col = VVTActivityDB.__table__.columns['vvt_id'] assert col.unique is not True # unique moved to composite constraint # ============================================================================= # VVT Route integration tests — tenant isolation via mocked DB # ============================================================================= def _make_activity(tenant_id, vvt_id="VVT-001", name="Test", **kwargs): """Create a mock VVTActivityDB.""" act = MagicMock() act.id = uuid.uuid4() act.tenant_id = tenant_id act.vvt_id = vvt_id act.name = name act.description = kwargs.get("description", "") act.purposes = kwargs.get("purposes", []) act.legal_bases = kwargs.get("legal_bases", []) act.data_subject_categories = [] act.personal_data_categories = [] act.recipient_categories = [] act.third_country_transfers = [] act.retention_period = {} act.tom_description = None act.business_function = kwargs.get("business_function", "IT") act.systems = [] act.deployment_model = None act.data_sources = [] act.data_flows = [] act.protection_level = "MEDIUM" act.dpia_required = False act.structured_toms = {} act.status = kwargs.get("status", "DRAFT") act.responsible = None act.owner = None act.last_reviewed_at = None act.next_review_at = None act.created_by = "system" act.dsfa_id = None act.created_at = datetime.utcnow() act.updated_at = datetime.utcnow() return act class TestVVTRouteTenantIsolation: """Verify that _activity_to_response and _log_audit accept tenant_id.""" def test_activity_to_response(self): from compliance.api.vvt_routes import _activity_to_response act = _make_activity(TENANT_A, "VVT-100", "Test Activity") resp = _activity_to_response(act) assert resp.vvt_id == "VVT-100" assert resp.name == "Test Activity" def test_log_audit_with_tenant(self): from compliance.api.vvt_routes import _log_audit db = MagicMock() _log_audit(db, tenant_id=TENANT_A, action="CREATE", entity_type="activity") db.add.assert_called_once() entry = db.add.call_args[0][0] assert entry.tenant_id == TENANT_A assert entry.action == "CREATE" def test_log_audit_different_tenants(self): from compliance.api.vvt_routes import _log_audit db = MagicMock() _log_audit(db, tenant_id=TENANT_A, action="CREATE", entity_type="activity") _log_audit(db, tenant_id=TENANT_B, action="UPDATE", entity_type="activity") assert db.add.call_count == 2 entries = [call[0][0] for call in db.add.call_args_list] assert entries[0].tenant_id == TENANT_A assert entries[1].tenant_id == TENANT_B # ============================================================================= # DSFA / Vendor — DEFAULT_TENANT_ID no longer "default" # ============================================================================= class TestDSFADefaultTenantFixed: def test_dsfa_default_is_uuid(self): from compliance.api.dsfa_routes import DEFAULT_TENANT_ID assert DEFAULT_TENANT_ID != "default" assert len(DEFAULT_TENANT_ID) == 36 assert "-" in DEFAULT_TENANT_ID def test_dsfa_get_tenant_id_fallback(self): from compliance.api.dsfa_routes import _get_tenant_id result = _get_tenant_id(None) assert result != "default" assert len(result) == 36 class TestVendorDefaultTenantFixed: def test_vendor_default_is_uuid(self): from compliance.api.vendor_compliance_routes import DEFAULT_TENANT_ID assert DEFAULT_TENANT_ID != "default" assert len(DEFAULT_TENANT_ID) == 36 assert "-" in DEFAULT_TENANT_ID