"""
Keycloak Authentication Module

Implements token validation against the Keycloak JWKS endpoint.
This module handles authentication (who is the user?), while
rbac.py handles authorization (what can the user do?).

Architecture:
- Keycloak validates JWT tokens and provides basic identity
- Our custom rbac.py handles fine-grained permissions
"""

import logging
import os
from dataclasses import dataclass
from datetime import datetime, timezone  # noqa: F401  (kept: may be re-exported)
from functools import lru_cache  # noqa: F401  (kept: may be re-exported)
from typing import Any, Dict, List, Optional

import httpx
import jwt
from fastapi import Depends, HTTPException, Request
from jwt import PyJWKClient

logger = logging.getLogger(__name__)

# Values of KEYCLOAK_VERIFY_SSL that explicitly turn TLS verification off.
# Anything else (including typos, "1", "yes", or an empty value) keeps
# verification ON so a misconfiguration fails closed instead of open.
_FALSY_ENV_VALUES = frozenset({"false", "0", "no", "off"})


@dataclass
class KeycloakConfig:
    """Keycloak connection configuration.

    The URL properties are derived from ``server_url`` and ``realm``.
    """

    server_url: str
    realm: str
    client_id: str
    client_secret: Optional[str] = None
    verify_ssl: bool = True

    @property
    def issuer_url(self) -> str:
        """Return the realm issuer URL (``<server_url>/realms/<realm>``)."""
        # Strip a trailing slash so "https://host/" does not produce
        # "https://host//realms/..." and break the issuer comparison.
        return f"{self.server_url.rstrip('/')}/realms/{self.realm}"

    @property
    def jwks_url(self) -> str:
        """Return the URL of the realm's JWKS (signing certificates) endpoint."""
        return f"{self.issuer_url}/protocol/openid-connect/certs"

    @property
    def token_url(self) -> str:
        """Return the URL of the realm's token endpoint."""
        return f"{self.issuer_url}/protocol/openid-connect/token"

    @property
    def userinfo_url(self) -> str:
        """Return the URL of the realm's userinfo endpoint."""
        return f"{self.issuer_url}/protocol/openid-connect/userinfo"


@dataclass
class KeycloakUser:
    """User information extracted from a Keycloak token."""

    user_id: str  # Keycloak subject (sub)
    email: str
    email_verified: bool
    name: Optional[str]
    given_name: Optional[str]
    family_name: Optional[str]
    realm_roles: List[str]  # Keycloak realm roles
    client_roles: Dict[str, List[str]]  # Client-specific roles
    groups: List[str]  # Keycloak groups
    tenant_id: Optional[str]  # Custom claim for school/tenant
    raw_claims: Dict[str, Any]  # All claims for debugging

    def has_realm_role(self, role: str) -> bool:
        """Check if user has a specific realm role."""
        return role in self.realm_roles

    def has_client_role(self, client_id: str, role: str) -> bool:
        """Check if user has a specific role for the given client."""
        return role in self.client_roles.get(client_id, [])

    def is_admin(self) -> bool:
        """Check if user has the ``admin`` or ``schul_admin`` realm role."""
        return self.has_realm_role("admin") or self.has_realm_role("schul_admin")

    def is_teacher(self) -> bool:
        """Check if user has the ``teacher`` or ``lehrer`` realm role."""
        return self.has_realm_role("teacher") or self.has_realm_role("lehrer")


class KeycloakAuthError(Exception):
    """Base exception for Keycloak authentication errors."""


class TokenExpiredError(KeycloakAuthError):
    """Token has expired."""


class TokenInvalidError(KeycloakAuthError):
    """Token is invalid."""


class KeycloakConfigError(KeycloakAuthError):
    """Keycloak configuration error."""


class KeycloakAuthenticator:
    """
    Validates JWT tokens against Keycloak.

    Usage:
        config = KeycloakConfig(
            server_url="https://keycloak.example.com",
            realm="breakpilot",
            client_id="breakpilot-backend"
        )
        auth = KeycloakAuthenticator(config)

        user = await auth.validate_token(token)
        if user.is_teacher():
            # Grant access
    """

    def __init__(self, config: KeycloakConfig):
        self.config = config
        self._jwks_client: Optional[PyJWKClient] = None
        self._http_client: Optional[httpx.AsyncClient] = None

    @property
    def jwks_client(self) -> PyJWKClient:
        """Lazy-load the JWKS client on first access."""
        if self._jwks_client is None:
            # NOTE(review): config.verify_ssl is not passed to PyJWKClient
            # here, so it only affects the httpx client below.
            self._jwks_client = PyJWKClient(
                self.config.jwks_url,
                cache_keys=True,
                lifespan=3600  # Cache keys for 1 hour
            )
        return self._jwks_client

    async def get_http_client(self) -> httpx.AsyncClient:
        """Get the async HTTP client, creating it if missing or closed."""
        if self._http_client is None or self._http_client.is_closed:
            self._http_client = httpx.AsyncClient(
                verify=self.config.verify_ssl,
                timeout=30.0
            )
        return self._http_client

    async def close(self):
        """Close the HTTP client if one is open."""
        if self._http_client and not self._http_client.is_closed:
            await self._http_client.aclose()

    def validate_token_sync(self, token: str) -> KeycloakUser:
        """
        Synchronously validate a JWT token against Keycloak JWKS.

        Args:
            token: The JWT access token

        Returns:
            KeycloakUser with extracted claims

        Raises:
            TokenExpiredError: If token has expired
            TokenInvalidError: If the token fails any other validation step
                (signature, audience, issuer, key lookup, claim extraction)
        """
        try:
            # Get signing key from JWKS
            signing_key = self.jwks_client.get_signing_key_from_jwt(token)

            # Decode and validate token
            payload = jwt.decode(
                token,
                signing_key.key,
                algorithms=["RS256"],
                audience=self.config.client_id,
                issuer=self.config.issuer_url,
                options={
                    "verify_exp": True,
                    "verify_iat": True,
                    "verify_aud": True,
                    "verify_iss": True
                }
            )

            return self._extract_user(payload)

        # The specific subclasses must be listed before InvalidTokenError.
        except jwt.ExpiredSignatureError as e:
            raise TokenExpiredError("Token has expired") from e
        except jwt.InvalidAudienceError as e:
            raise TokenInvalidError("Invalid token audience") from e
        except jwt.InvalidIssuerError as e:
            raise TokenInvalidError("Invalid token issuer") from e
        except jwt.InvalidTokenError as e:
            raise TokenInvalidError(f"Invalid token: {e}") from e
        except Exception as e:
            # Anything else (e.g. JWKS fetch failure) is logged and reported
            # to the caller as an invalid token.
            logger.error("Token validation failed: %s", e)
            raise TokenInvalidError(f"Token validation failed: {e}") from e

    async def validate_token(self, token: str) -> KeycloakUser:
        """
        Asynchronously validate a JWT token.

        Note: JWKS fetching is synchronous due to PyJWKClient limitations,
        but this wrapper allows async context usage. The call runs inline,
        so it is not offloaded from the calling coroutine.
        """
        return self.validate_token_sync(token)

    async def get_userinfo(self, token: str) -> Dict[str, Any]:
        """
        Fetch user info from the Keycloak userinfo endpoint.

        This provides additional user claims not in the access token.

        Raises:
            TokenExpiredError: If the endpoint answers with HTTP 401
            TokenInvalidError: If the endpoint answers with any other
                error status
        """
        client = await self.get_http_client()

        try:
            response = await client.get(
                self.config.userinfo_url,
                headers={"Authorization": f"Bearer {token}"}
            )
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 401:
                raise TokenExpiredError("Token is invalid or expired") from e
            raise TokenInvalidError(f"Failed to fetch userinfo: {e}") from e

    def _extract_user(self, payload: Dict[str, Any]) -> KeycloakUser:
        """Extract a KeycloakUser from a decoded JWT payload.

        Claims that are missing or explicitly null are treated as empty.
        """
        # Extract realm roles ("or" also covers a claim that is present
        # but null, which .get(key, default) would not)
        realm_access = payload.get("realm_access") or {}
        realm_roles = realm_access.get("roles") or []

        # Extract client roles
        resource_access = payload.get("resource_access") or {}
        client_roles = {
            client_id: (access or {}).get("roles") or []
            for client_id, access in resource_access.items()
        }

        # Extract groups
        groups = payload.get("groups") or []

        # Extract custom tenant claim (if configured in Keycloak)
        tenant_id = payload.get("tenant_id") or payload.get("school_id")

        return KeycloakUser(
            user_id=payload.get("sub", ""),
            email=payload.get("email", ""),
            email_verified=payload.get("email_verified", False),
            name=payload.get("name"),
            given_name=payload.get("given_name"),
            family_name=payload.get("family_name"),
            realm_roles=realm_roles,
            client_roles=client_roles,
            groups=groups,
            tenant_id=tenant_id,
            raw_claims=payload
        )


# =============================================
# HYBRID AUTH: Keycloak + Local JWT
# =============================================

class HybridAuthenticator:
    """
    Hybrid authenticator supporting both Keycloak and local JWT tokens.

    This allows gradual migration from local JWT to Keycloak:
    1. Development: Use local JWT (fast, no external dependencies)
    2. Production: Use Keycloak for full IAM capabilities

    Token type detection:
    - Keycloak tokens: 'iss' claim equals the configured Keycloak issuer URL
    - Everything else is validated as a local token (if a secret is set)
    """

    def __init__(
        self,
        keycloak_config: Optional[KeycloakConfig] = None,
        local_jwt_secret: Optional[str] = None,
        environment: str = "development"
    ):
        self.environment = environment
        self.keycloak_enabled = keycloak_config is not None
        self.local_jwt_secret = local_jwt_secret

        self.keycloak_auth: Optional[KeycloakAuthenticator] = (
            KeycloakAuthenticator(keycloak_config) if keycloak_config else None
        )

    async def validate_token(self, token: str) -> Dict[str, Any]:
        """
        Validate token using the appropriate method.

        Returns a unified user dict compatible with existing code.

        Raises:
            TokenInvalidError: If the token is empty, undecodable, invalid,
                or no validation method is configured
            TokenExpiredError: If the token has expired
        """
        if not token:
            raise TokenInvalidError("No token provided")

        # Peek at the token to determine its type. The unverified claims are
        # used ONLY for routing; the chosen validator verifies the signature.
        try:
            unverified = jwt.decode(token, options={"verify_signature": False})
        except jwt.InvalidTokenError as e:
            raise TokenInvalidError("Cannot decode token") from e

        issuer = unverified.get("iss", "")
        if not isinstance(issuer, str):
            # A non-string 'iss' cannot be a Keycloak issuer; treat as absent.
            issuer = ""

        # Check if it's a Keycloak token. Exact match: the Keycloak validator
        # enforces the same issuer, so a mere substring match could never pass.
        if self.keycloak_auth and issuer == self.keycloak_auth.config.issuer_url:
            kc_user = await self.keycloak_auth.validate_token(token)
            return self._keycloak_user_to_dict(kc_user)

        # Fall back to local JWT validation
        if self.local_jwt_secret:
            return self._validate_local_token(token)

        raise TokenInvalidError("No valid authentication method available")

    def _validate_local_token(self, token: str) -> Dict[str, Any]:
        """Validate a token with the local JWT secret (HS256).

        Raises:
            KeycloakConfigError: If no local secret is configured
            TokenExpiredError: If the token has expired
            TokenInvalidError: If the token is otherwise invalid
        """
        if not self.local_jwt_secret:
            raise KeycloakConfigError("Local JWT secret not configured")

        try:
            payload = jwt.decode(
                token,
                self.local_jwt_secret,
                algorithms=["HS256"]
            )
        except jwt.ExpiredSignatureError as e:
            raise TokenExpiredError("Token has expired") from e
        except jwt.InvalidTokenError as e:
            raise TokenInvalidError(f"Invalid local token: {e}") from e

        # Map local token claims to unified format
        role = payload.get("role", "user")
        return {
            "user_id": payload.get("user_id", payload.get("sub", "")),
            "email": payload.get("email", ""),
            "name": payload.get("name", ""),
            "role": role,
            "realm_roles": [role],
            "tenant_id": payload.get("tenant_id", payload.get("school_id")),
            "auth_method": "local_jwt"
        }

    def _keycloak_user_to_dict(self, user: KeycloakUser) -> Dict[str, Any]:
        """Convert a KeycloakUser to a dict compatible with existing code."""
        # Map Keycloak roles to our role system (admin wins over teacher)
        if user.is_admin():
            role = "admin"
        elif user.is_teacher():
            role = "teacher"
        else:
            role = "user"

        # Fall back to "<given> <family>" when no display name is present
        display_name = (
            user.name
            or f"{user.given_name or ''} {user.family_name or ''}".strip()
        )

        return {
            "user_id": user.user_id,
            "email": user.email,
            "name": display_name,
            "role": role,
            "realm_roles": user.realm_roles,
            "client_roles": user.client_roles,
            "groups": user.groups,
            "tenant_id": user.tenant_id,
            "email_verified": user.email_verified,
            "auth_method": "keycloak"
        }

    async def close(self):
        """Cleanup resources held by the Keycloak authenticator, if any."""
        if self.keycloak_auth:
            await self.keycloak_auth.close()


# =============================================
# FACTORY FUNCTIONS
# =============================================

def get_keycloak_config_from_env() -> Optional[KeycloakConfig]:
    """
    Create KeycloakConfig from environment variables.

    Required env vars:
    - KEYCLOAK_SERVER_URL: e.g., https://keycloak.breakpilot.app
    - KEYCLOAK_REALM: e.g., breakpilot
    - KEYCLOAK_CLIENT_ID: e.g., breakpilot-backend

    Optional:
    - KEYCLOAK_CLIENT_SECRET: For confidential clients
    - KEYCLOAK_VERIFY_SSL: Default true; only "false", "0", "no" or "off"
      (case-insensitive) disable verification

    Returns:
        The config, or None when any required variable is missing or empty.
    """
    server_url = os.environ.get("KEYCLOAK_SERVER_URL")
    realm = os.environ.get("KEYCLOAK_REALM")
    client_id = os.environ.get("KEYCLOAK_CLIENT_ID")

    if not all([server_url, realm, client_id]):
        logger.info("Keycloak not configured, using local JWT only")
        return None

    # Fail closed: unknown values keep TLS verification enabled.
    verify_ssl_raw = os.environ.get("KEYCLOAK_VERIFY_SSL", "true").strip().lower()
    verify_ssl = verify_ssl_raw not in _FALSY_ENV_VALUES

    return KeycloakConfig(
        server_url=server_url,
        realm=realm,
        client_id=client_id,
        client_secret=os.environ.get("KEYCLOAK_CLIENT_SECRET"),
        verify_ssl=verify_ssl
    )


def get_authenticator() -> HybridAuthenticator:
    """
    Get a configured authenticator instance.

    Uses environment variables to determine configuration.

    Raises:
        KeycloakConfigError: If ENVIRONMENT is "production" and JWT_SECRET
            is not set
    """
    keycloak_config = get_keycloak_config_from_env()

    # JWT_SECRET is required - no default fallback in production
    jwt_secret = os.environ.get("JWT_SECRET")
    environment = os.environ.get("ENVIRONMENT", "development")

    if not jwt_secret and environment == "production":
        raise KeycloakConfigError(
            "JWT_SECRET environment variable is required in production"
        )

    return HybridAuthenticator(
        keycloak_config=keycloak_config,
        local_jwt_secret=jwt_secret,
        environment=environment
    )


# =============================================
# FASTAPI DEPENDENCY
# =============================================

# Global authenticator instance (lazy-initialized)
_authenticator: Optional[HybridAuthenticator] = None


def get_auth() -> HybridAuthenticator:
    """Get or create the global authenticator."""
    global _authenticator
    if _authenticator is None:
        _authenticator = get_authenticator()
    return _authenticator


async def get_current_user(request: Request) -> Dict[str, Any]:
    """
    FastAPI dependency to get the current authenticated user.

    Usage:
        @app.get("/api/protected")
        async def protected_endpoint(user: dict = Depends(get_current_user)):
            return {"user_id": user["user_id"]}

    Raises:
        HTTPException: 401 when the header is missing (outside development)
            or the token cannot be validated
    """
    auth_header = request.headers.get("authorization", "")

    if not auth_header.startswith("Bearer "):
        # Check for development mode
        environment = os.environ.get("ENVIRONMENT", "development")
        if environment == "development":
            # SECURITY NOTE(review): ENVIRONMENT defaults to "development",
            # so a deployment that forgets to set it hands out this admin
            # demo user to every unauthenticated request. Behavior is kept
            # as-is; make sure ENVIRONMENT is always set outside local dev.
            return {
                "user_id": "10000000-0000-0000-0000-000000000024",
                "email": "demo@breakpilot.app",
                "role": "admin",
                "realm_roles": ["admin"],
                "tenant_id": "a0000000-0000-0000-0000-000000000001",
                "auth_method": "development_bypass"
            }
        raise HTTPException(status_code=401, detail="Missing authorization header")

    # Everything after the scheme is the token; strip so that extra
    # whitespace after "Bearer" does not produce an empty token.
    _, _, token = auth_header.partition(" ")
    token = token.strip()

    try:
        auth = get_auth()
        return await auth.validate_token(token)
    except TokenExpiredError as e:
        raise HTTPException(status_code=401, detail="Token expired") from e
    except TokenInvalidError as e:
        raise HTTPException(status_code=401, detail=str(e)) from e
    except Exception as e:
        # Do not leak internal error details to the client.
        logger.error("Authentication failed: %s", e)
        raise HTTPException(status_code=401, detail="Authentication failed") from e


def require_role(required_role: str):
    """
    FastAPI dependency factory for role-based access.

    This is a plain (non-async) factory on purpose: ``Depends`` needs the
    returned callable, not a coroutine object.

    Usage:
        @app.get("/api/admin-only")
        async def admin_endpoint(user: dict = Depends(require_role("admin"))):
            return {"message": "Admin access granted"}

    Args:
        required_role: Role that must equal the user's ``role`` or appear in
            the user's ``realm_roles``

    Returns:
        An async dependency that returns the user dict or raises a 403
        HTTPException.
    """
    async def role_checker(user: dict = Depends(get_current_user)) -> dict:
        user_role = user.get("role", "user")
        realm_roles = user.get("realm_roles", [])

        if user_role == required_role or required_role in realm_roles:
            return user

        raise HTTPException(
            status_code=403,
            detail=f"Role '{required_role}' required"
        )

    return role_checker