From 324b1378624410dc3881aba22263a864614983e4 Mon Sep 17 00:00:00 2001 From: Sharang Parnerkar <30073382+mighty840@users.noreply.github.com> Date: Wed, 17 Jun 2026 11:07:56 +0200 Subject: [PATCH] feat(m7.1): wire compliance-agent to compliance-core auth + status gate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Supersedes #82. Deletes the inline JWT middleware in compliance-agent (now stale — missing JWKS refresh from #84 and tenant extraction from #83) and imports require_jwt_auth, require_tenant_status, JwksState from compliance_core::auth. Wires the status gate into the server's layer stack: Extension(jwks_state) → require_jwt_auth → require_tenant_status → handler. Adds the integration test from #82, retargeted to compliance_core::auth::require_tenant_status. Test plan - cargo fmt --all clean - cargo clippy --workspace --exclude compliance-dashboard -- -D warnings clean (matches baseline) - cargo test -p compliance-core --lib — 7 tests pass - cargo test -p compliance-agent --lib — 228 tests pass - cargo test -p compliance-agent --test tenant_status_middleware — 6 tests pass - scripts/smoke.sh against live certifai KC — 15/15 cells pass (anon, bogus, active×2, trial, frozen, archived × {GET/health, GET/echo, POST/echo}) Caveats - M7.1 only — status gate + claim extraction. Per-collection tenant_id scoping (M7.2) still pending; agent will still serve any Active/Trial tenant's data to any caller until the ~38 query call-sites use compliance_core::db::tenant_filter. Co-Authored-By: Claude Opus 4.7 --- Cargo.lock | 1 + compliance-agent/Cargo.toml | 5 +- compliance-agent/src/api/auth_middleware.rs | 113 ---------------- compliance-agent/src/api/mod.rs | 1 - compliance-agent/src/api/server.rs | 11 +- .../tests/tenant_status_middleware.rs | 122 ++++++++++++++++++ 6 files changed, 134 insertions(+), 119 deletions(-) delete mode 100644 compliance-agent/src/api/auth_middleware.rs create mode 100644 compliance-agent/tests/tenant_status_middleware.rs diff --git a/Cargo.lock b/Cargo.lock index efd22e7..200bbb0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -687,6 +687,7 @@ dependencies = [ "tokio-cron-scheduler", "tokio-stream", "tokio-tungstenite 0.26.2", + "tower", "tower-http", "tracing", "tracing-subscriber", diff --git a/compliance-agent/Cargo.toml b/compliance-agent/Cargo.toml index e0a129f..e8b81d6 100644 --- a/compliance-agent/Cargo.toml +++ b/compliance-agent/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" workspace = true [dependencies] -compliance-core = { workspace = true, features = ["mongodb", "telemetry"] } +compliance-core = { workspace = true, features = ["mongodb", "telemetry", "axum"] } compliance-graph = { path = "../compliance-graph" } compliance-dast = { path = "../compliance-dast" } serde = { workspace = true } @@ -44,7 +44,8 @@ dashmap = { workspace = true } tokio-stream = { workspace = true } [dev-dependencies] -compliance-core = { workspace = true, features = ["mongodb"] } +compliance-core = { workspace = true, features = ["mongodb", "axum"] } +tower = { version = "0.5", features = ["util"] } reqwest = { workspace = true } serde_json = { workspace = true } tokio = { workspace = true } diff --git a/compliance-agent/src/api/auth_middleware.rs b/compliance-agent/src/api/auth_middleware.rs deleted file mode 100644 index 90e3a49..0000000 --- a/compliance-agent/src/api/auth_middleware.rs +++ /dev/null @@ -1,113 +0,0 @@ -use std::sync::Arc; - -use axum::{ - extract::Request, - middleware::Next, - response::{IntoResponse, Response}, -}; -use jsonwebtoken::{decode, decode_header, jwk::JwkSet, DecodingKey, Validation}; -use reqwest::StatusCode; -use serde::Deserialize; -use tokio::sync::RwLock; - -/// Cached JWKS from Keycloak for token validation. -#[derive(Clone)] -pub struct JwksState { - pub jwks: Arc>>, - pub jwks_url: String, -} - -#[derive(Debug, Deserialize)] -struct Claims { - #[allow(dead_code)] - sub: String, -} - -const PUBLIC_ENDPOINTS: &[&str] = &["/api/v1/health"]; - -/// Middleware that validates Bearer JWT tokens against Keycloak's JWKS. -/// -/// Skips validation for health check endpoints. -/// If `JwksState` is not present as an extension (keycloak not configured), -/// all requests pass through. -pub async fn require_jwt_auth(request: Request, next: Next) -> Response { - let path = request.uri().path(); - - if PUBLIC_ENDPOINTS.contains(&path) { - return next.run(request).await; - } - - let jwks_state = match request.extensions().get::() { - Some(s) => s.clone(), - None => return next.run(request).await, - }; - - let auth_header = match request.headers().get("authorization") { - Some(h) => h, - None => return (StatusCode::UNAUTHORIZED, "Missing authorization header").into_response(), - }; - - let token = match auth_header.to_str() { - Ok(s) if s.starts_with("Bearer ") => &s[7..], - _ => return (StatusCode::UNAUTHORIZED, "Invalid authorization header").into_response(), - }; - - match validate_token(token, &jwks_state).await { - Ok(()) => next.run(request).await, - Err(e) => { - tracing::warn!("JWT validation failed: {e}"); - (StatusCode::UNAUTHORIZED, "Invalid token").into_response() - } - } -} - -async fn validate_token(token: &str, state: &JwksState) -> Result<(), String> { - let header = decode_header(token).map_err(|e| format!("failed to decode JWT header: {e}"))?; - - let kid = header - .kid - .ok_or_else(|| "JWT missing kid header".to_string())?; - - let jwks = fetch_or_get_jwks(state).await?; - - let jwk = jwks - .keys - .iter() - .find(|k| k.common.key_id.as_deref() == Some(&kid)) - .ok_or_else(|| "no matching key found in JWKS".to_string())?; - - let decoding_key = - DecodingKey::from_jwk(jwk).map_err(|e| format!("failed to create decoding key: {e}"))?; - - let mut validation = Validation::new(header.alg); - validation.validate_exp = true; - validation.validate_aud = false; - - decode::(token, &decoding_key, &validation) - .map_err(|e| format!("token validation failed: {e}"))?; - - Ok(()) -} - -async fn fetch_or_get_jwks(state: &JwksState) -> Result { - { - let cached = state.jwks.read().await; - if let Some(ref jwks) = *cached { - return Ok(jwks.clone()); - } - } - - let resp = reqwest::get(&state.jwks_url) - .await - .map_err(|e| format!("failed to fetch JWKS: {e}"))?; - - let jwks: JwkSet = resp - .json() - .await - .map_err(|e| format!("failed to parse JWKS: {e}"))?; - - let mut cached = state.jwks.write().await; - *cached = Some(jwks.clone()); - - Ok(jwks) -} diff --git a/compliance-agent/src/api/mod.rs b/compliance-agent/src/api/mod.rs index f969c3b..43a142c 100644 --- a/compliance-agent/src/api/mod.rs +++ b/compliance-agent/src/api/mod.rs @@ -1,4 +1,3 @@ -pub mod auth_middleware; pub mod handlers; pub mod routes; pub mod server; diff --git a/compliance-agent/src/api/server.rs b/compliance-agent/src/api/server.rs index 9b89714..8b65894 100644 --- a/compliance-agent/src/api/server.rs +++ b/compliance-agent/src/api/server.rs @@ -7,8 +7,9 @@ use tower_http::cors::CorsLayer; use tower_http::set_header::SetResponseHeaderLayer; use tower_http::trace::TraceLayer; +use compliance_core::auth::{require_jwt_auth, require_tenant_status, JwksState}; + use crate::agent::ComplianceAgent; -use crate::api::auth_middleware::{require_jwt_auth, JwksState}; use crate::api::routes; use crate::error::AgentError; @@ -44,9 +45,13 @@ pub async fn start_api_server(agent: ComplianceAgent, port: u16) -> Result<(), A jwks_url, }; tracing::info!("Keycloak JWT auth enabled for realm '{kc_realm}'"); + // Layers execute outermost-first. Extension(jwks_state) must run + // before require_jwt_auth so the middleware can read it; the + // status gate runs after JWT so TenantContext is in extensions. app = app - .layer(Extension(jwks_state)) - .layer(middleware::from_fn(require_jwt_auth)); + .layer(middleware::from_fn(require_tenant_status)) + .layer(middleware::from_fn(require_jwt_auth)) + .layer(Extension(jwks_state)); } else { tracing::warn!("Keycloak not configured - API endpoints are unprotected"); } diff --git a/compliance-agent/tests/tenant_status_middleware.rs b/compliance-agent/tests/tenant_status_middleware.rs new file mode 100644 index 0000000..5ce0d9a --- /dev/null +++ b/compliance-agent/tests/tenant_status_middleware.rs @@ -0,0 +1,122 @@ +//! M7.1 — integration tests for `compliance_core::auth::require_tenant_status`. +//! +//! Exercises the middleware end-to-end through an Axum router so we +//! catch wiring bugs (extension propagation, method matching) that pure +//! unit tests would miss. + +#![allow(clippy::expect_used, clippy::unwrap_used)] + +use axum::{ + body::Body, + extract::Request, + http::{Method, StatusCode}, + middleware::{from_fn, Next}, + response::Response, + routing::{get, post}, + Router, +}; +use compliance_core::{auth::require_tenant_status, TenantContext, TenantStatus}; +use tower::ServiceExt; + +fn ctx_with(status: TenantStatus) -> TenantContext { + TenantContext { + tenant_id: "t-1".to_string(), + tenant_slug: "acme".to_string(), + org_roles: vec![], + products: vec![], + plan: "starter".to_string(), + status, + user_id: "u-1".to_string(), + user_name: None, + } +} + +fn router_with_ctx(ctx: Option) -> Router { + let injector = move |mut req: Request, next: Next| { + let ctx = ctx.clone(); + async move { + if let Some(c) = ctx { + req.extensions_mut().insert(c); + } + next.run(req).await + } + }; + + Router::new() + .route("/r", get(|| async { "read" })) + .route("/w", post(|| async { "write" })) + .layer(from_fn(require_tenant_status)) + .layer(from_fn(injector)) +} + +async fn call(router: Router, method: Method, path: &str) -> Response { + let req = Request::builder() + .method(method) + .uri(path) + .body(Body::empty()) + .expect("request build"); + router.oneshot(req).await.expect("oneshot") +} + +#[tokio::test] +async fn active_tenant_can_read_and_write() { + let r = router_with_ctx(Some(ctx_with(TenantStatus::Active))); + assert_eq!( + call(r.clone(), Method::GET, "/r").await.status(), + StatusCode::OK + ); + assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::OK); +} + +#[tokio::test] +async fn trial_tenant_can_read_and_write() { + let r = router_with_ctx(Some(ctx_with(TenantStatus::Trial))); + assert_eq!( + call(r.clone(), Method::GET, "/r").await.status(), + StatusCode::OK + ); + assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::OK); +} + +#[tokio::test] +async fn demo_tenant_can_read_and_write() { + let r = router_with_ctx(Some(ctx_with(TenantStatus::Demo))); + assert_eq!( + call(r.clone(), Method::GET, "/r").await.status(), + StatusCode::OK + ); + assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::OK); +} + +#[tokio::test] +async fn frozen_tenant_can_read_but_not_write() { + let r = router_with_ctx(Some(ctx_with(TenantStatus::Frozen))); + assert_eq!( + call(r.clone(), Method::GET, "/r").await.status(), + StatusCode::OK + ); + assert_eq!( + call(r, Method::POST, "/w").await.status(), + StatusCode::PAYMENT_REQUIRED + ); +} + +#[tokio::test] +async fn archived_tenant_is_gone_on_every_method() { + let r = router_with_ctx(Some(ctx_with(TenantStatus::Archived))); + assert_eq!( + call(r.clone(), Method::GET, "/r").await.status(), + StatusCode::GONE + ); + assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::GONE); +} + +#[tokio::test] +async fn no_context_passes_through() { + let r = router_with_ctx(None); + assert_eq!( + call(r.clone(), Method::GET, "/r").await.status(), + StatusCode::OK + ); + assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::OK); +}