"""Tests for License Gate service (license_gate.py).""" import pytest from unittest.mock import MagicMock, patch from collections import namedtuple from compliance.services.license_gate import ( check_source_allowed, get_license_matrix, get_source_permissions, USAGE_COLUMN_MAP, ) class TestUsageColumnMap: """Test the usage type to column mapping.""" def test_all_usage_types_mapped(self): expected = {"analysis", "store_excerpt", "ship_embeddings", "ship_in_product"} assert set(USAGE_COLUMN_MAP.keys()) == expected def test_column_names(self): assert USAGE_COLUMN_MAP["analysis"] == "allowed_analysis" assert USAGE_COLUMN_MAP["store_excerpt"] == "allowed_store_excerpt" assert USAGE_COLUMN_MAP["ship_embeddings"] == "allowed_ship_embeddings" assert USAGE_COLUMN_MAP["ship_in_product"] == "allowed_ship_in_product" class TestCheckSourceAllowed: """Tests for check_source_allowed().""" def _mock_db(self, return_value): db = MagicMock() mock_result = MagicMock() if return_value is None: mock_result.fetchone.return_value = None else: mock_result.fetchone.return_value = (return_value,) db.execute.return_value = mock_result return db def test_allowed_analysis(self): db = self._mock_db(True) assert check_source_allowed(db, "OWASP_ASVS", "analysis") is True def test_denied_ship_in_product(self): db = self._mock_db(False) assert check_source_allowed(db, "BSI_TR03161_1", "ship_in_product") is False def test_unknown_source(self): db = self._mock_db(None) assert check_source_allowed(db, "NONEXISTENT", "analysis") is False def test_unknown_usage_type(self): db = MagicMock() assert check_source_allowed(db, "OWASP_ASVS", "invalid_type") is False # DB should not be called for invalid usage type db.execute.assert_not_called() def test_allowed_store_excerpt(self): db = self._mock_db(True) assert check_source_allowed(db, "OWASP_ASVS", "store_excerpt") is True def test_denied_store_excerpt(self): db = self._mock_db(False) assert check_source_allowed(db, "BSI_TR03161_1", "store_excerpt") is False class TestGetLicenseMatrix: """Tests for get_license_matrix().""" def test_returns_list(self): LicRow = namedtuple("LicRow", [ "license_id", "name", "terms_url", "commercial_use", "ai_training_restriction", "tdm_allowed_under_44b", "deletion_required", "notes", ]) rows = [ LicRow("OWASP_CC_BY_SA", "CC BY-SA 4.0", "https://example.com", "allowed", None, "yes", False, "Open source"), LicRow("BSI_TOS_2025", "BSI ToS", "https://bsi.bund.de", "restricted", "unclear", "yes", True, "Commercial restricted"), ] db = MagicMock() db.execute.return_value.fetchall.return_value = rows result = get_license_matrix(db) assert len(result) == 2 assert result[0]["license_id"] == "OWASP_CC_BY_SA" assert result[0]["commercial_use"] == "allowed" assert result[0]["deletion_required"] is False assert result[1]["license_id"] == "BSI_TOS_2025" assert result[1]["commercial_use"] == "restricted" assert result[1]["deletion_required"] is True def test_empty_result(self): db = MagicMock() db.execute.return_value.fetchall.return_value = [] result = get_license_matrix(db) assert result == [] class TestGetSourcePermissions: """Tests for get_source_permissions().""" def test_returns_list_with_join(self): SrcRow = namedtuple("SrcRow", [ "source_id", "title", "publisher", "url", "version_label", "language", "license_id", "allowed_analysis", "allowed_store_excerpt", "allowed_ship_embeddings", "allowed_ship_in_product", "vault_retention_days", "vault_access_tier", "license_name", "commercial_use", ]) rows = [ SrcRow( "OWASP_ASVS", "OWASP ASVS", "OWASP Foundation", "https://owasp.org", "4.0.3", "en", "OWASP_CC_BY_SA", True, True, True, True, 30, "public", "CC BY-SA 4.0", "allowed", ), ] db = MagicMock() db.execute.return_value.fetchall.return_value = rows result = get_source_permissions(db) assert len(result) == 1 src = result[0] assert src["source_id"] == "OWASP_ASVS" assert src["allowed_analysis"] is True assert src["allowed_ship_in_product"] is True assert src["license_name"] == "CC BY-SA 4.0" assert src["commercial_use"] == "allowed" def test_restricted_source(self): SrcRow = namedtuple("SrcRow", [ "source_id", "title", "publisher", "url", "version_label", "language", "license_id", "allowed_analysis", "allowed_store_excerpt", "allowed_ship_embeddings", "allowed_ship_in_product", "vault_retention_days", "vault_access_tier", "license_name", "commercial_use", ]) rows = [ SrcRow( "BSI_TR03161_1", "BSI TR-03161 Teil 1", "BSI", "https://bsi.bund.de", "1.0", "de", "BSI_TOS_2025", True, False, False, False, 30, "restricted", "BSI Nutzungsbedingungen", "restricted", ), ] db = MagicMock() db.execute.return_value.fetchall.return_value = rows result = get_source_permissions(db) src = result[0] assert src["allowed_analysis"] is True assert src["allowed_store_excerpt"] is False assert src["allowed_ship_embeddings"] is False assert src["allowed_ship_in_product"] is False