"""
Keycloak Authentication Module

Implements token validation against the Keycloak JWKS endpoint.

This module handles authentication (who is the user?), while rbac.py
handles authorization (what can the user do?).

Architecture:
- Keycloak validates JWT tokens and provides basic identity
- Our custom rbac.py handles fine-grained permissions
"""

import asyncio
import logging
import os
from typing import Any, Dict, Optional

import httpx
import jwt
from fastapi import Depends, HTTPException, Request
from jwt import PyJWKClient

from .keycloak_models import (
    KeycloakConfig,
    KeycloakUser,
    KeycloakAuthError,
    TokenExpiredError,
    TokenInvalidError,
    KeycloakConfigError,
    get_keycloak_config_from_env,
)

logger = logging.getLogger(__name__)


class KeycloakAuthenticator:
    """Validates JWT tokens against Keycloak.

    Signing keys come from ``config.jwks_url`` through a lazily created
    :class:`PyJWKClient`; a lazily created ``httpx.AsyncClient`` is used for
    the userinfo endpoint.
    """

    def __init__(self, config: KeycloakConfig):
        self.config = config
        self._jwks_client: Optional[PyJWKClient] = None
        self._http_client: Optional[httpx.AsyncClient] = None

    @property
    def jwks_client(self) -> PyJWKClient:
        """Return the JWKS client, creating it on first use (keys cached, lifespan=3600)."""
        # validate_token runs validate_token_sync in worker threads; if two
        # threads race here, the worst case is one discarded duplicate client.
        if self._jwks_client is None:
            self._jwks_client = PyJWKClient(
                self.config.jwks_url, cache_keys=True, lifespan=3600
            )
        return self._jwks_client

    async def get_http_client(self) -> httpx.AsyncClient:
        """Return the shared AsyncClient, (re)creating it if missing or closed."""
        if self._http_client is None or self._http_client.is_closed:
            self._http_client = httpx.AsyncClient(
                verify=self.config.verify_ssl, timeout=30.0
            )
        return self._http_client

    async def close(self):
        """Close the shared HTTP client if one is open."""
        if self._http_client and not self._http_client.is_closed:
            await self._http_client.aclose()

    def validate_token_sync(self, token: str) -> KeycloakUser:
        """Synchronously validate a JWT token against Keycloak JWKS.

        Verifies the RS256 signature, ``exp``, ``iat``, the audience
        (``config.client_id``) and the issuer (``config.issuer_url``).
        May perform blocking network I/O when the signing key must be fetched.

        Args:
            token: The raw encoded JWT.

        Returns:
            The KeycloakUser built from the verified claims.

        Raises:
            TokenExpiredError: The token's ``exp`` has passed.
            TokenInvalidError: Any other validation or key-retrieval failure.
        """
        try:
            signing_key = self.jwks_client.get_signing_key_from_jwt(token)
            payload = jwt.decode(
                token,
                signing_key.key,
                # Pin the algorithm instead of trusting the token header's "alg".
                algorithms=["RS256"],
                audience=self.config.client_id,
                issuer=self.config.issuer_url,
                options={
                    "verify_exp": True,
                    "verify_iat": True,
                    "verify_aud": True,
                    "verify_iss": True,
                },
            )
            return self._extract_user(payload)
        except jwt.ExpiredSignatureError as e:
            raise TokenExpiredError("Token has expired") from e
        except jwt.InvalidAudienceError as e:
            raise TokenInvalidError("Invalid token audience") from e
        except jwt.InvalidIssuerError as e:
            raise TokenInvalidError("Invalid token issuer") from e
        except jwt.InvalidTokenError as e:
            raise TokenInvalidError(f"Invalid token: {e}") from e
        except Exception as e:
            # Anything else (for example a failure while obtaining the signing
            # key) is surfaced to callers as an invalid token.
            logger.error("Token validation failed: %s", e)
            raise TokenInvalidError(f"Token validation failed: {e}") from e

    async def validate_token(self, token: str) -> KeycloakUser:
        """Asynchronously validate a JWT token.

        Delegates to validate_token_sync in the default thread-pool executor so
        that a blocking JWKS fetch does not stall the event loop. Raises the
        same exceptions as validate_token_sync.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.validate_token_sync, token)

    async def get_userinfo(self, token: str) -> Dict[str, Any]:
        """Fetch user info from the Keycloak userinfo endpoint.

        Raises:
            TokenExpiredError: The endpoint answered 401.
            TokenInvalidError: The endpoint answered any other error status.
        """
        client = await self.get_http_client()
        try:
            response = await client.get(
                self.config.userinfo_url,
                headers={"Authorization": f"Bearer {token}"},
            )
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 401:
                raise TokenExpiredError("Token is invalid or expired") from e
            raise TokenInvalidError(f"Failed to fetch userinfo: {e}") from e

    def _extract_user(self, payload: Dict[str, Any]) -> KeycloakUser:
        """Build a KeycloakUser from a verified JWT payload.

        Claims that are missing *or* explicitly null fall back to empty
        containers, so a ``"realm_access": null`` claim cannot crash parsing.
        """
        realm_roles = (payload.get("realm_access") or {}).get("roles") or []
        client_roles = {
            client_id: (access or {}).get("roles") or []
            for client_id, access in (payload.get("resource_access") or {}).items()
        }
        groups = payload.get("groups") or []
        # Either claim name may carry the tenant identifier.
        tenant_id = payload.get("tenant_id") or payload.get("school_id")

        return KeycloakUser(
            user_id=payload.get("sub", ""),
            email=payload.get("email", ""),
            email_verified=payload.get("email_verified", False),
            name=payload.get("name"),
            given_name=payload.get("given_name"),
            family_name=payload.get("family_name"),
            realm_roles=realm_roles,
            client_roles=client_roles,
            groups=groups,
            tenant_id=tenant_id,
            raw_claims=payload,
        )


class HybridAuthenticator:
    """Hybrid authenticator supporting both Keycloak and local JWT tokens.

    Keycloak is enabled only when a config is supplied; local HS256 tokens are
    accepted only when a secret is supplied.
    """

    def __init__(
        self,
        keycloak_config: Optional[KeycloakConfig] = None,
        local_jwt_secret: Optional[str] = None,
        environment: str = "development",
    ):
        self.environment = environment
        self.keycloak_enabled = keycloak_config is not None
        self.local_jwt_secret = local_jwt_secret
        self.keycloak_auth = (
            KeycloakAuthenticator(keycloak_config) if keycloak_config else None
        )

    async def validate_token(self, token: str) -> Dict[str, Any]:
        """Validate a token with whichever method matches its issuer.

        Tokens whose unverified ``iss`` contains the configured Keycloak issuer
        URL go to Keycloak; everything else goes to the local secret.

        Returns:
            A user dict (see _keycloak_user_to_dict / _validate_local_token).

        Raises:
            TokenInvalidError: Empty/undecodable token or no usable method.
            TokenExpiredError: The selected validator found the token expired.
        """
        if not token:
            raise TokenInvalidError("No token provided")

        # The unverified payload is only used to pick a validator; the chosen
        # path then performs full signature and claim verification.
        try:
            unverified = jwt.decode(token, options={"verify_signature": False})
        except jwt.InvalidTokenError:
            raise TokenInvalidError("Cannot decode token") from None
        issuer = unverified.get("iss", "")
        if not isinstance(issuer, str):
            # A non-string "iss" would make the substring test raise TypeError.
            issuer = ""

        if self.keycloak_auth and self.keycloak_auth.config.issuer_url in issuer:
            kc_user = await self.keycloak_auth.validate_token(token)
            return self._keycloak_user_to_dict(kc_user)

        if self.local_jwt_secret:
            return self._validate_local_token(token)

        raise TokenInvalidError("No valid authentication method available")

    def _validate_local_token(self, token: str) -> Dict[str, Any]:
        """Validate a token signed with the local HS256 secret.

        Raises:
            KeycloakConfigError: No local secret configured.
            TokenExpiredError: The token's ``exp`` has passed.
            TokenInvalidError: Any other validation failure.
        """
        if not self.local_jwt_secret:
            raise KeycloakConfigError("Local JWT secret not configured")

        try:
            payload = jwt.decode(token, self.local_jwt_secret, algorithms=["HS256"])
        except jwt.ExpiredSignatureError as e:
            raise TokenExpiredError("Token has expired") from e
        except jwt.InvalidTokenError as e:
            raise TokenInvalidError(f"Invalid local token: {e}") from e

        role = payload.get("role", "user")
        return {
            "user_id": payload.get("user_id", payload.get("sub", "")),
            "email": payload.get("email", ""),
            "name": payload.get("name", ""),
            "role": role,
            "realm_roles": [role],
            "tenant_id": payload.get("tenant_id", payload.get("school_id")),
            "auth_method": "local_jwt",
        }

    def _keycloak_user_to_dict(self, user: KeycloakUser) -> Dict[str, Any]:
        """Convert a KeycloakUser to the dict shape used by existing code.

        The single ``role`` is "admin" if user.is_admin(), else "teacher" if
        user.is_teacher(), else "user".
        """
        role = "user"
        if user.is_admin():
            role = "admin"
        elif user.is_teacher():
            role = "teacher"

        return {
            "user_id": user.user_id,
            "email": user.email,
            "name": user.name
            or f"{user.given_name or ''} {user.family_name or ''}".strip(),
            "role": role,
            "realm_roles": user.realm_roles,
            "client_roles": user.client_roles,
            "groups": user.groups,
            "tenant_id": user.tenant_id,
            "email_verified": user.email_verified,
            "auth_method": "keycloak",
        }

    async def close(self):
        """Release the Keycloak authenticator's HTTP resources, if any."""
        if self.keycloak_auth:
            await self.keycloak_auth.close()


# =============================================
# FACTORY FUNCTIONS
# =============================================

def get_authenticator() -> HybridAuthenticator:
    """Build an authenticator from environment variables.

    Reads the Keycloak config via get_keycloak_config_from_env(), plus
    JWT_SECRET and ENVIRONMENT (default "development").

    Raises:
        KeycloakConfigError: JWT_SECRET is unset while ENVIRONMENT == "production".
    """
    keycloak_config = get_keycloak_config_from_env()
    jwt_secret = os.environ.get("JWT_SECRET")
    environment = os.environ.get("ENVIRONMENT", "development")

    if not jwt_secret and environment == "production":
        raise KeycloakConfigError(
            "JWT_SECRET environment variable is required in production"
        )

    return HybridAuthenticator(
        keycloak_config=keycloak_config,
        local_jwt_secret=jwt_secret,
        environment=environment,
    )


# =============================================
# FASTAPI DEPENDENCY
# =============================================

_authenticator: Optional[HybridAuthenticator] = None

# RFC 6750 asks 401 responses for bearer-protected resources to carry this header.
_WWW_AUTHENTICATE = {"WWW-Authenticate": "Bearer"}


def get_auth() -> HybridAuthenticator:
    """Get or create the process-wide authenticator."""
    global _authenticator
    if _authenticator is None:
        _authenticator = get_authenticator()
    return _authenticator


async def get_current_user(request: Request) -> Dict[str, Any]:
    """FastAPI dependency returning the authenticated user's dict.

    Raises:
        HTTPException: 401 when the header is missing (outside development)
            or the token fails validation.
    """
    auth_header = request.headers.get("authorization", "")
    # The auth scheme is case-insensitive (RFC 7235, section 2.1).
    scheme, sep, credentials = auth_header.partition(" ")

    if not sep or scheme.lower() != "bearer":
        environment = os.environ.get("ENVIRONMENT", "development")
        if environment == "development":
            # NOTE(review): ENVIRONMENT defaults to "development" when unset, so
            # a deployment that forgets to set it serves every unauthenticated
            # request as this demo admin. Consider requiring an explicit opt-in.
            return {
                "user_id": "10000000-0000-0000-0000-000000000024",
                "email": "demo@breakpilot.app",
                "role": "admin",
                "realm_roles": ["admin"],
                "tenant_id": "a0000000-0000-0000-0000-000000000001",
                "auth_method": "development_bypass",
            }
        raise HTTPException(
            status_code=401,
            detail="Missing authorization header",
            headers=_WWW_AUTHENTICATE,
        )

    # strip() tolerates extra whitespace such as "Bearer  <token>".
    token = credentials.strip()

    try:
        auth = get_auth()
        return await auth.validate_token(token)
    except TokenExpiredError:
        raise HTTPException(
            status_code=401, detail="Token expired", headers=_WWW_AUTHENTICATE
        ) from None
    except TokenInvalidError as e:
        raise HTTPException(
            status_code=401, detail=str(e), headers=_WWW_AUTHENTICATE
        ) from e
    except Exception as e:
        logger.exception("Authentication failed: %s", e)
        raise HTTPException(
            status_code=401, detail="Authentication failed", headers=_WWW_AUTHENTICATE
        ) from e


def require_role(required_role: str):
    """FastAPI dependency factory for role-based access.

    Usage: ``Depends(require_role("admin"))``. The returned checker passes if
    the user's ``role`` equals ``required_role`` or ``required_role`` appears
    in ``realm_roles``; otherwise it raises HTTP 403.

    This is a plain function (not ``async``) so it returns the checker itself
    rather than a coroutine that FastAPI cannot use as a dependency.
    """
    async def role_checker(user: dict = Depends(get_current_user)) -> dict:
        user_role = user.get("role", "user")
        realm_roles = user.get("realm_roles") or []
        if user_role == required_role or required_role in realm_roles:
            return user
        raise HTTPException(status_code=403, detail=f"Role '{required_role}' required")

    return role_checker