""" LLM Provider Abstraction for Compliance AI Features. Supports: - Anthropic Claude API (default) - Self-Hosted LLMs (Ollama, vLLM, LocalAI, etc.) - HashiCorp Vault integration for secure API key storage Configuration via environment variables: - COMPLIANCE_LLM_PROVIDER: "anthropic" or "self_hosted" - ANTHROPIC_API_KEY: API key for Claude (or loaded from Vault) - ANTHROPIC_MODEL: Model name (default: claude-sonnet-4-20250514) - SELF_HOSTED_LLM_URL: Base URL for self-hosted LLM - SELF_HOSTED_LLM_MODEL: Model name for self-hosted - SELF_HOSTED_LLM_KEY: Optional API key for self-hosted Vault Configuration: - VAULT_ADDR: Vault server address (e.g., http://vault:8200) - VAULT_TOKEN: Vault authentication token - USE_VAULT_SECRETS: Set to "true" to enable Vault integration - VAULT_SECRET_PATH: Path to secrets (default: secret/breakpilot/api_keys) """ import os import asyncio import logging from abc import ABC, abstractmethod from typing import List, Optional, Dict, Any from dataclasses import dataclass from enum import Enum import httpx logger = logging.getLogger(__name__) # ============================================================================= # Vault Integration # ============================================================================= class VaultClient: """ HashiCorp Vault client for retrieving secrets. Supports KV v2 secrets engine. """ def __init__( self, addr: Optional[str] = None, token: Optional[str] = None ): self.addr = addr or os.getenv("VAULT_ADDR", "http://localhost:8200") self.token = token or os.getenv("VAULT_TOKEN") self._cache: Dict[str, Any] = {} self._cache_ttl = 300 # 5 minutes cache def _get_headers(self) -> Dict[str, str]: """Get request headers with Vault token.""" headers = {"Content-Type": "application/json"} if self.token: headers["X-Vault-Token"] = self.token return headers def get_secret(self, path: str, key: str = "value") -> Optional[str]: """ Get a secret from Vault KV v2. Args: path: Secret path (e.g., "breakpilot/api_keys/anthropic") key: Key within the secret data (default: "value") Returns: Secret value or None if not found """ cache_key = f"{path}:{key}" # Check cache first if cache_key in self._cache: return self._cache[cache_key] try: # KV v2 uses /data/ in the path full_path = f"{self.addr}/v1/secret/data/{path}" response = httpx.get( full_path, headers=self._get_headers(), timeout=10.0 ) if response.status_code == 200: data = response.json() secret_data = data.get("data", {}).get("data", {}) secret_value = secret_data.get(key) if secret_value: self._cache[cache_key] = secret_value logger.info(f"Successfully loaded secret from Vault: {path}") return secret_value elif response.status_code == 404: logger.warning(f"Secret not found in Vault: {path}") else: logger.error(f"Vault error {response.status_code}: {response.text}") except httpx.RequestError as e: logger.error(f"Failed to connect to Vault at {self.addr}: {e}") except Exception as e: logger.error(f"Error retrieving secret from Vault: {e}") return None def get_anthropic_key(self) -> Optional[str]: """Get Anthropic API key from Vault.""" path = os.getenv("VAULT_ANTHROPIC_PATH", "breakpilot/api_keys/anthropic") return self.get_secret(path, "value") def is_available(self) -> bool: """Check if Vault is available and authenticated.""" try: response = httpx.get( f"{self.addr}/v1/sys/health", headers=self._get_headers(), timeout=5.0 ) return response.status_code in (200, 429, 472, 473, 501, 503) except Exception: return False # Singleton Vault client _vault_client: Optional[VaultClient] = None def get_vault_client() -> VaultClient: """Get shared Vault client instance.""" global _vault_client if _vault_client is None: _vault_client = VaultClient() return _vault_client def get_secret_from_vault_or_env( vault_path: str, env_var: str, vault_key: str = "value" ) -> Optional[str]: """ Get a secret, trying Vault first, then falling back to environment variable. Args: vault_path: Path in Vault (e.g., "breakpilot/api_keys/anthropic") env_var: Environment variable name as fallback vault_key: Key within Vault secret data Returns: Secret value or None """ use_vault = os.getenv("USE_VAULT_SECRETS", "").lower() in ("true", "1", "yes") if use_vault: vault = get_vault_client() secret = vault.get_secret(vault_path, vault_key) if secret: return secret logger.info(f"Vault secret not found, falling back to env: {env_var}") return os.getenv(env_var) class LLMProviderType(str, Enum): """Supported LLM provider types.""" ANTHROPIC = "anthropic" SELF_HOSTED = "self_hosted" OLLAMA = "ollama" # Alias for self_hosted (Ollama-specific) MOCK = "mock" # For testing @dataclass class LLMResponse: """Standard response from LLM.""" content: str model: str provider: str usage: Optional[Dict[str, int]] = None raw_response: Optional[Dict[str, Any]] = None @dataclass class LLMConfig: """Configuration for LLM provider.""" provider_type: LLMProviderType api_key: Optional[str] = None model: str = "claude-sonnet-4-20250514" base_url: Optional[str] = None max_tokens: int = 4096 temperature: float = 0.3 timeout: float = 60.0 class LLMProvider(ABC): """Abstract base class for LLM providers.""" def __init__(self, config: LLMConfig): self.config = config @abstractmethod async def complete( self, prompt: str, system_prompt: Optional[str] = None, max_tokens: Optional[int] = None, temperature: Optional[float] = None ) -> LLMResponse: """Generate a completion for the given prompt.""" pass @abstractmethod async def batch_complete( self, prompts: List[str], system_prompt: Optional[str] = None, max_tokens: Optional[int] = None, rate_limit: float = 1.0 ) -> List[LLMResponse]: """Generate completions for multiple prompts with rate limiting.""" pass @property @abstractmethod def provider_name(self) -> str: """Return the provider name.""" pass class AnthropicProvider(LLMProvider): """Claude API Provider using Anthropic's official API.""" ANTHROPIC_API_URL = "https://api.anthropic.com/v1/messages" def __init__(self, config: LLMConfig): super().__init__(config) if not config.api_key: raise ValueError("Anthropic API key is required") self.api_key = config.api_key self.model = config.model or "claude-sonnet-4-20250514" @property def provider_name(self) -> str: return "anthropic" async def complete( self, prompt: str, system_prompt: Optional[str] = None, max_tokens: Optional[int] = None, temperature: Optional[float] = None ) -> LLMResponse: """Generate completion using Claude API.""" headers = { "x-api-key": self.api_key, "anthropic-version": "2023-06-01", "content-type": "application/json" } messages = [{"role": "user", "content": prompt}] payload = { "model": self.model, "max_tokens": max_tokens or self.config.max_tokens, "messages": messages } if system_prompt: payload["system"] = system_prompt if temperature is not None: payload["temperature"] = temperature elif self.config.temperature is not None: payload["temperature"] = self.config.temperature async with httpx.AsyncClient(timeout=self.config.timeout) as client: try: response = await client.post( self.ANTHROPIC_API_URL, headers=headers, json=payload ) response.raise_for_status() data = response.json() content = "" if data.get("content"): content = data["content"][0].get("text", "") return LLMResponse( content=content, model=self.model, provider=self.provider_name, usage=data.get("usage"), raw_response=data ) except httpx.HTTPStatusError as e: logger.error(f"Anthropic API error: {e.response.status_code} - {e.response.text}") raise except Exception as e: logger.error(f"Anthropic API request failed: {e}") raise async def batch_complete( self, prompts: List[str], system_prompt: Optional[str] = None, max_tokens: Optional[int] = None, rate_limit: float = 1.0 ) -> List[LLMResponse]: """Process multiple prompts with rate limiting.""" results = [] for i, prompt in enumerate(prompts): if i > 0: await asyncio.sleep(rate_limit) try: result = await self.complete( prompt=prompt, system_prompt=system_prompt, max_tokens=max_tokens ) results.append(result) except Exception as e: logger.error(f"Failed to process prompt {i}: {e}") # Append error response results.append(LLMResponse( content=f"Error: {str(e)}", model=self.model, provider=self.provider_name )) return results class SelfHostedProvider(LLMProvider): """Self-Hosted LLM Provider supporting Ollama, vLLM, LocalAI, etc.""" def __init__(self, config: LLMConfig): super().__init__(config) if not config.base_url: raise ValueError("Base URL is required for self-hosted provider") self.base_url = config.base_url.rstrip("/") self.model = config.model self.api_key = config.api_key @property def provider_name(self) -> str: return "self_hosted" def _detect_api_format(self) -> str: """Detect the API format based on URL patterns.""" if "11434" in self.base_url or "ollama" in self.base_url.lower(): return "ollama" elif "openai" in self.base_url.lower() or "v1" in self.base_url: return "openai" else: return "ollama" # Default to Ollama format async def complete( self, prompt: str, system_prompt: Optional[str] = None, max_tokens: Optional[int] = None, temperature: Optional[float] = None ) -> LLMResponse: """Generate completion using self-hosted LLM.""" api_format = self._detect_api_format() headers = {"content-type": "application/json"} if self.api_key: headers["Authorization"] = f"Bearer {self.api_key}" if api_format == "ollama": # Ollama API format endpoint = f"{self.base_url}/api/generate" full_prompt = prompt if system_prompt: full_prompt = f"{system_prompt}\n\n{prompt}" payload = { "model": self.model, "prompt": full_prompt, "stream": False, "think": False, # Disable thinking mode (qwen3.5 etc.) "options": {} } if max_tokens: payload["options"]["num_predict"] = max_tokens if temperature is not None: payload["options"]["temperature"] = temperature else: # OpenAI-compatible format (vLLM, LocalAI, etc.) endpoint = f"{self.base_url}/v1/chat/completions" messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "user", "content": prompt}) payload = { "model": self.model, "messages": messages, "max_tokens": max_tokens or self.config.max_tokens, "temperature": temperature if temperature is not None else self.config.temperature } async with httpx.AsyncClient(timeout=self.config.timeout) as client: try: response = await client.post(endpoint, headers=headers, json=payload) response.raise_for_status() data = response.json() # Parse response based on format if api_format == "ollama": content = data.get("response", "") else: # OpenAI format content = data.get("choices", [{}])[0].get("message", {}).get("content", "") return LLMResponse( content=content, model=self.model, provider=self.provider_name, usage=data.get("usage"), raw_response=data ) except httpx.HTTPStatusError as e: logger.error(f"Self-hosted LLM error: {e.response.status_code} - {e.response.text}") raise except Exception as e: logger.error(f"Self-hosted LLM request failed: {e}") raise async def batch_complete( self, prompts: List[str], system_prompt: Optional[str] = None, max_tokens: Optional[int] = None, rate_limit: float = 0.5 # Self-hosted can be faster ) -> List[LLMResponse]: """Process multiple prompts with rate limiting.""" results = [] for i, prompt in enumerate(prompts): if i > 0: await asyncio.sleep(rate_limit) try: result = await self.complete( prompt=prompt, system_prompt=system_prompt, max_tokens=max_tokens ) results.append(result) except Exception as e: logger.error(f"Failed to process prompt {i}: {e}") results.append(LLMResponse( content=f"Error: {str(e)}", model=self.model, provider=self.provider_name )) return results class MockProvider(LLMProvider): """Mock provider for testing without actual API calls.""" def __init__(self, config: LLMConfig): super().__init__(config) self.responses: List[str] = [] self.call_count = 0 @property def provider_name(self) -> str: return "mock" def set_responses(self, responses: List[str]): """Set predetermined responses for testing.""" self.responses = responses self.call_count = 0 async def complete( self, prompt: str, system_prompt: Optional[str] = None, max_tokens: Optional[int] = None, temperature: Optional[float] = None ) -> LLMResponse: """Return mock response.""" if self.responses: content = self.responses[self.call_count % len(self.responses)] else: content = f"Mock response for: {prompt[:50]}..." self.call_count += 1 return LLMResponse( content=content, model="mock-model", provider=self.provider_name, usage={"input_tokens": len(prompt), "output_tokens": len(content)} ) async def batch_complete( self, prompts: List[str], system_prompt: Optional[str] = None, max_tokens: Optional[int] = None, rate_limit: float = 0.0 ) -> List[LLMResponse]: """Return mock responses for batch.""" return [await self.complete(p, system_prompt, max_tokens) for p in prompts] def get_llm_config() -> LLMConfig: """ Create LLM config from environment variables or Vault. Priority for API key: 1. Vault (if USE_VAULT_SECRETS=true and Vault is available) 2. Environment variable (ANTHROPIC_API_KEY) """ provider_type_str = os.getenv("COMPLIANCE_LLM_PROVIDER", "anthropic") try: provider_type = LLMProviderType(provider_type_str) except ValueError: logger.warning(f"Unknown LLM provider: {provider_type_str}, falling back to mock") provider_type = LLMProviderType.MOCK # Get API key from Vault or environment api_key = None if provider_type == LLMProviderType.ANTHROPIC: api_key = get_secret_from_vault_or_env( vault_path="breakpilot/api_keys/anthropic", env_var="ANTHROPIC_API_KEY" ) elif provider_type in (LLMProviderType.SELF_HOSTED, LLMProviderType.OLLAMA): api_key = get_secret_from_vault_or_env( vault_path="breakpilot/api_keys/self_hosted_llm", env_var="SELF_HOSTED_LLM_KEY" ) # Select model based on provider type if provider_type == LLMProviderType.ANTHROPIC: model = os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-20250514") elif provider_type in (LLMProviderType.SELF_HOSTED, LLMProviderType.OLLAMA): model = os.getenv("SELF_HOSTED_LLM_MODEL", "qwen2.5:14b") else: model = "mock-model" return LLMConfig( provider_type=provider_type, api_key=api_key, model=model, base_url=os.getenv("SELF_HOSTED_LLM_URL"), max_tokens=int(os.getenv("COMPLIANCE_LLM_MAX_TOKENS", "4096")), temperature=float(os.getenv("COMPLIANCE_LLM_TEMPERATURE", "0.3")), timeout=float(os.getenv("COMPLIANCE_LLM_TIMEOUT", "60.0")) ) def get_llm_provider(config: Optional[LLMConfig] = None) -> LLMProvider: """ Factory function to get the appropriate LLM provider based on configuration. Usage: provider = get_llm_provider() response = await provider.complete("Analyze this requirement...") """ if config is None: config = get_llm_config() if config.provider_type == LLMProviderType.ANTHROPIC: if not config.api_key: logger.warning("No Anthropic API key found, using mock provider") return MockProvider(config) return AnthropicProvider(config) elif config.provider_type in (LLMProviderType.SELF_HOSTED, LLMProviderType.OLLAMA): if not config.base_url: logger.warning("No self-hosted LLM URL found, using mock provider") return MockProvider(config) return SelfHostedProvider(config) elif config.provider_type == LLMProviderType.MOCK: return MockProvider(config) else: raise ValueError(f"Unsupported LLM provider type: {config.provider_type}") # Singleton instance for reuse _provider_instance: Optional[LLMProvider] = None def get_shared_provider() -> LLMProvider: """Get a shared LLM provider instance.""" global _provider_instance if _provider_instance is None: _provider_instance = get_llm_provider() return _provider_instance def reset_shared_provider(): """Reset the shared provider instance (useful for testing).""" global _provider_instance _provider_instance = None