#!/usr/bin/env python3
"""
Test script for Compliance AI API Endpoints.

Usage:
    python scripts/test_compliance_ai_endpoints.py

Environment:
    BACKEND_URL: Base URL of the backend (default: http://localhost:8000)
    COMPLIANCE_LLM_PROVIDER: Set to "mock" for testing without API keys

Exit status:
    0 when every executed endpoint test passed; 1 otherwise (backend
    unreachable, AI provider unavailable, no sample requirement found,
    any test failed, or the run was interrupted).
"""

import asyncio
import os
import sys
from typing import Any, Dict, List, Optional

import httpx


class ComplianceAITester:
    """Smoke-tester for the Compliance AI endpoints under ``/api/v1/compliance``.

    Each ``test_*`` coroutine calls one endpoint, prints a human-readable
    summary of the JSON response and returns the decoded body. Per-item
    endpoints print an error and return ``{}`` when the server answers 404.
    """

    def __init__(self, base_url: str = "http://localhost:8000"):
        self.base_url = base_url.rstrip("/")
        self.api_prefix = f"{self.base_url}/api/v1/compliance"

    async def _post_json(
        self,
        path: str,
        payload: Dict[str, Any],
        *,
        not_found_message: Optional[str] = None,
        timeout: float = 60.0,
    ) -> Optional[Dict[str, Any]]:
        """POST ``payload`` as JSON to ``path`` (relative to ``api_prefix``).

        Args:
            path: Endpoint path starting with ``/``, e.g. ``"/ai/interpret"``.
            payload: JSON-serialisable request body.
            not_found_message: When given, a 404 response prints this message
                and returns ``None`` instead of raising.
            timeout: httpx client timeout in seconds.

        Returns:
            The decoded JSON body, or ``None`` for a tolerated 404.

        Raises:
            httpx.HTTPStatusError: For any other 4xx/5xx response.
            httpx.RequestError: If the request could not be sent.
        """
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(f"{self.api_prefix}{path}", json=payload)
            if not_found_message is not None and response.status_code == 404:
                print(not_found_message)
                return None
            response.raise_for_status()
            return response.json()

    async def _first_id(
        self,
        path: str,
        collection_key: str,
        params: Optional[Dict[str, Any]] = None,
    ) -> Optional[str]:
        """GET ``path`` and return the ``id`` of the first item in ``collection_key``.

        Returns ``None`` on any non-200 status, or when the collection is
        missing or empty.
        """
        async with httpx.AsyncClient() as client:
            response = await client.get(f"{self.api_prefix}{path}", params=params)
            if response.status_code != 200:
                return None
            items = response.json().get(collection_key) or []
        return items[0]["id"] if items else None

    async def test_ai_status(self) -> Dict[str, Any]:
        """Test GET /ai/status endpoint and return the status payload."""
        print("\n=== Testing AI Status ===")

        async with httpx.AsyncClient() as client:
            response = await client.get(f"{self.api_prefix}/ai/status")
            response.raise_for_status()
            data = response.json()

        print(f"Provider: {data['provider']}")
        print(f"Model: {data['model']}")
        print(f"Available: {data['is_available']}")
        print(f"Is Mock: {data['is_mock']}")
        if data.get("error"):
            print(f"Error: {data['error']}")

        return data

    async def test_interpret_requirement(self, requirement_id: str) -> Dict[str, Any]:
        """Test POST /ai/interpret endpoint.

        Returns the interpretation payload, or ``{}`` if the requirement
        was not found (404).
        """
        print("\n=== Testing Requirement Interpretation ===")
        print(f"Requirement ID: {requirement_id}")

        data = await self._post_json(
            "/ai/interpret",
            {"requirement_id": requirement_id, "force_refresh": False},
            not_found_message=f"ERROR: Requirement {requirement_id} not found",
        )
        if data is None:
            return {}

        print(f"\nSummary: {data['summary'][:100]}...")
        print(f"Applicability: {data['applicability'][:100]}...")
        print(f"Risk Level: {data['risk_level']}")
        print(f"Affected Modules: {', '.join(data['affected_modules'])}")
        print(f"Technical Measures: {len(data['technical_measures'])} measures")
        print(f"Confidence: {data['confidence_score']:.2f}")
        if data.get("error"):
            print(f"Error: {data['error']}")

        return data

    async def test_suggest_controls(self, requirement_id: str) -> Dict[str, Any]:
        """Test POST /ai/suggest-controls endpoint.

        Returns the suggestions payload, or ``{}`` if the requirement was
        not found (404).
        """
        print("\n=== Testing Control Suggestions ===")
        print(f"Requirement ID: {requirement_id}")

        data = await self._post_json(
            "/ai/suggest-controls",
            {"requirement_id": requirement_id},
            not_found_message=f"ERROR: Requirement {requirement_id} not found",
        )
        if data is None:
            return {}

        suggestions = data["suggestions"]
        print(f"\nFound {len(suggestions)} control suggestions:")
        for i, ctrl in enumerate(suggestions, 1):
            print(f"\n{i}. {ctrl['control_id']}: {ctrl['title']}")
            print(f" Domain: {ctrl['domain']}")
            print(f" Priority: {ctrl['priority']}")
            print(f" Automated: {ctrl['is_automated']}")
            # automation_tool is optional; use .get so a missing key is not fatal.
            if ctrl.get("automation_tool"):
                print(f" Tool: {ctrl['automation_tool']}")
            print(f" Confidence: {ctrl['confidence_score']:.2f}")

        return data

    async def test_assess_module_risk(self, module_id: str) -> Dict[str, Any]:
        """Test POST /ai/assess-risk endpoint.

        Returns the risk assessment payload, or ``{}`` if the module was
        not found (404).
        """
        print("\n=== Testing Module Risk Assessment ===")
        print(f"Module ID: {module_id}")

        data = await self._post_json(
            "/ai/assess-risk",
            {"module_id": module_id},
            not_found_message=f"ERROR: Module {module_id} not found",
        )
        if data is None:
            return {}

        print(f"\nModule: {data['module_name']}")
        print(f"Overall Risk: {data['overall_risk']}")
        print("\nRisk Factors:")
        for factor in data["risk_factors"]:
            print(f" - {factor['factor']}")
            print(f" Severity: {factor['severity']}, Likelihood: {factor['likelihood']}")
        print("\nRecommendations:")
        for rec in data["recommendations"]:
            print(f" - {rec}")
        print("\nCompliance Gaps:")
        for gap in data["compliance_gaps"]:
            print(f" - {gap}")
        print(f"\nConfidence: {data['confidence_score']:.2f}")

        return data

    async def test_gap_analysis(self, requirement_id: str) -> Dict[str, Any]:
        """Test POST /ai/gap-analysis endpoint.

        Returns the gap analysis payload, or ``{}`` if the requirement was
        not found (404).
        """
        print("\n=== Testing Gap Analysis ===")
        print(f"Requirement ID: {requirement_id}")

        data = await self._post_json(
            "/ai/gap-analysis",
            {"requirement_id": requirement_id},
            not_found_message=f"ERROR: Requirement {requirement_id} not found",
        )
        if data is None:
            return {}

        print(f"\nRequirement: {data['requirement_title']}")
        print(f"Coverage Level: {data['coverage_level']}")
        print("\nExisting Controls:")
        for ctrl in data["existing_controls"]:
            print(f" - {ctrl}")
        print("\nMissing Coverage:")
        for missing in data["missing_coverage"]:
            print(f" - {missing}")
        print("\nSuggested Actions:")
        for action in data["suggested_actions"]:
            print(f" - {action}")

        return data

    async def test_batch_interpret(self, requirement_ids: List[str]) -> Dict[str, Any]:
        """Test POST /ai/batch-interpret endpoint.

        Unlike the single-item tests, a 404 here is raised as
        ``httpx.HTTPStatusError`` rather than returned as ``{}``.
        """
        print("\n=== Testing Batch Interpretation ===")
        print(f"Requirements: {len(requirement_ids)}")

        data = await self._post_json(
            "/ai/batch-interpret",
            {"requirement_ids": requirement_ids, "rate_limit": 1.0},
            timeout=120.0,
        )

        total = data["total"]
        processed = data["processed"]
        print(f"\nTotal: {total}")
        print(f"Processed: {processed}")
        # Guard against an empty batch (total == 0) to avoid ZeroDivisionError.
        success_rate = processed / total * 100 if total else 0.0
        print(f"Success Rate: {success_rate:.1f}%")

        if data["interpretations"]:
            print("\nFirst interpretation:")
            first = data["interpretations"][0]
            print(f" ID: {first['requirement_id']}")
            print(f" Summary: {first['summary'][:100]}...")
            print(f" Risk: {first['risk_level']}")

        return data

    async def get_sample_requirement_id(self) -> Optional[str]:
        """Return the ID of one requirement from the backend, or None if none is available."""
        return await self._first_id("/requirements", "requirements", params={"limit": 1})

    async def get_sample_module_id(self) -> Optional[str]:
        """Return the ID of one module from the backend, or None if none is available."""
        return await self._first_id("/modules", "modules")

    async def run_all_tests(self) -> bool:
        """Run all endpoint tests and print a pass/fail summary.

        Returns:
            True if the AI provider was available, a sample requirement
            existed and every executed test passed; False otherwise.
        """
        print("=" * 70)
        print("Compliance AI Endpoints Test Suite")
        print("=" * 70)

        # Test AI status first; the other endpoints are meaningless without a provider.
        try:
            status = await self.test_ai_status()
        except Exception as e:
            print(f"\n❌ ERROR: Could not connect to backend: {e}")
            return False
        if not status.get("is_available"):
            print("\n⚠️ WARNING: AI provider is not available!")
            print("Set COMPLIANCE_LLM_PROVIDER=mock for testing without API keys")
            return False

        # Get sample IDs
        print("\n--- Fetching sample data ---")
        requirement_id = await self.get_sample_requirement_id()
        module_id = await self.get_sample_module_id()

        if not requirement_id:
            print("\n⚠️ WARNING: No requirements found in database")
            print("Run seed command first: POST /api/v1/compliance/seed")
            return False

        print(f"Sample Requirement ID: {requirement_id}")
        if module_id:
            print(f"Sample Module ID: {module_id}")

        # Store (name, coroutine function, argument) and create each coroutine
        # only when it is awaited, so none is left un-awaited if the run aborts.
        tests = [
            ("Interpret Requirement", self.test_interpret_requirement, requirement_id),
            ("Suggest Controls", self.test_suggest_controls, requirement_id),
            ("Gap Analysis", self.test_gap_analysis, requirement_id),
        ]
        if module_id:
            tests.append(("Assess Module Risk", self.test_assess_module_risk, module_id))

        # Execute tests
        results = {"passed": 0, "failed": 0}
        for test_name, test_func, arg in tests:
            try:
                data = await test_func(arg)
            except Exception as e:
                results["failed"] += 1
                print(f"\n❌ {test_name} - FAILED: {e}")
                continue
            # The test_* methods print an error and return {} on a 404;
            # count that as a failure instead of reporting it as passed.
            if not data:
                results["failed"] += 1
                print(f"\n❌ {test_name} - FAILED: resource not found (404)")
            else:
                results["passed"] += 1
                print(f"\n✅ {test_name} - PASSED")

        # Summary
        print("\n" + "=" * 70)
        print("Test Summary")
        print("=" * 70)
        print(f"✅ Passed: {results['passed']}")
        print(f"❌ Failed: {results['failed']}")
        print(f"Total: {results['passed'] + results['failed']}")
        print("=" * 70)

        return results["failed"] == 0


async def main() -> int:
    """Run the test suite against BACKEND_URL and return a process exit code."""
    backend_url = os.getenv("BACKEND_URL", "http://localhost:8000")

    print(f"Backend URL: {backend_url}")
    print(f"Provider: {os.getenv('COMPLIANCE_LLM_PROVIDER', 'default')}")

    tester = ComplianceAITester(base_url=backend_url)
    success = await tester.run_all_tests()
    return 0 if success else 1


if __name__ == "__main__":
    # Ctrl+C is normally raised out of asyncio.run() itself (on 3.11+ the main
    # task is cancelled first), so handle it here rather than inside main().
    try:
        sys.exit(asyncio.run(main()))
    except KeyboardInterrupt:
        print("\n\nTest interrupted by user")
        sys.exit(1)