Compare commits

..

2 Commits

Author SHA1 Message Date
Sharang Parnerkar f474699279 fix(core): JWKS refresh-on-failure in M7.1 auth middleware
CI / Check (pull_request) Successful in 8m17s
CI / Detect Changes (pull_request) Has been skipped
CI / Deploy Agent (pull_request) Has been skipped
CI / Deploy Dashboard (pull_request) Has been skipped
CI / Deploy Docs (pull_request) Has been skipped
CI / Deploy MCP (pull_request) Has been skipped
Without this, every Keycloak signing-key rotation produces a silent
401 storm against every request until the agent restarts — the cached
JWKS is held forever and never reconciled against KC.

Now: when `kid` isn't in the cached JWKS or the matching key fails
signature verification, we classify the failure as Stale, force a JWKS
refresh, and retry once. Anything else (expired, malformed, missing
tenant_id) is Permanent and short-circuits straight to 401.

* Splits the path into a pure `try_validate(token, header, kid, jwks)`
  helper returning a `ValidationError { Stale | Permanent }` enum.
* `fetch_or_get_jwks(state, force)` takes a force flag and holds the
  write lock across the network fetch so concurrent refreshers don't
  all hammer Keycloak when keys rotate (the second writer reuses what
  the first put in cache).
* Adds a unit test for the kid-not-found Stale classification.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-06-04 16:40:55 +02:00
sharang 116293519d M7.1 smoke harness: lift auth to compliance-core + compliance-smoke service (#83)
CI / Check (push) Has been cancelled
CI / Detect Changes (push) Has been cancelled
CI / Deploy Agent (push) Has been cancelled
CI / Deploy Dashboard (push) Has been cancelled
CI / Deploy Docs (push) Has been cancelled
CI / Deploy MCP (push) Has been cancelled
2026-06-04 14:38:35 +00:00
15 changed files with 722 additions and 349 deletions
Generated
+18 -1
View File
@@ -687,7 +687,6 @@ dependencies = [
"tokio-cron-scheduler", "tokio-cron-scheduler",
"tokio-stream", "tokio-stream",
"tokio-tungstenite 0.26.2", "tokio-tungstenite 0.26.2",
"tower",
"tower-http", "tower-http",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
@@ -701,19 +700,23 @@ dependencies = [
name = "compliance-core" name = "compliance-core"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"axum",
"bson", "bson",
"chrono", "chrono",
"hex", "hex",
"jsonwebtoken",
"mongodb", "mongodb",
"opentelemetry", "opentelemetry",
"opentelemetry-appender-tracing", "opentelemetry-appender-tracing",
"opentelemetry-otlp", "opentelemetry-otlp",
"opentelemetry_sdk", "opentelemetry_sdk",
"reqwest",
"secrecy", "secrecy",
"serde", "serde",
"serde_json", "serde_json",
"sha2", "sha2",
"thiserror 2.0.18", "thiserror 2.0.18",
"tokio",
"tracing", "tracing",
"tracing-opentelemetry", "tracing-opentelemetry",
"tracing-subscriber", "tracing-subscriber",
@@ -827,6 +830,20 @@ dependencies = [
"tracing-subscriber", "tracing-subscriber",
] ]
[[package]]
name = "compliance-smoke"
version = "0.1.0"
dependencies = [
"axum",
"compliance-core",
"reqwest",
"serde",
"serde_json",
"tokio",
"tracing",
"tracing-subscriber",
]
[[package]] [[package]]
name = "console_error_panic_hook" name = "console_error_panic_hook"
version = "0.1.7" version = "0.1.7"
+1
View File
@@ -6,6 +6,7 @@ members = [
"compliance-graph", "compliance-graph",
"compliance-dast", "compliance-dast",
"compliance-mcp", "compliance-mcp",
"compliance-smoke",
] ]
resolver = "2" resolver = "2"
-1
View File
@@ -52,5 +52,4 @@ mongodb = { workspace = true }
uuid = { workspace = true } uuid = { workspace = true }
secrecy = { workspace = true } secrecy = { workspace = true }
axum = "0.8" axum = "0.8"
tower = { version = "0.5", features = ["util"] }
tower-http = { version = "0.6", features = ["cors"] } tower-http = { version = "0.6", features = ["cors"] }
+10 -202
View File
@@ -1,29 +1,10 @@
//! M7.1 — JWT validation + tenant context propagation.
//!
//! `require_jwt_auth` validates a Bearer JWT against Keycloak's JWKS and
//! attaches a `TenantContext` to the request extensions. Downstream
//! middleware (`require_tenant_status`) and Axum extractors (`TenantCtx`)
//! read it from there.
//!
//! Skipped paths:
//! * `/api/v1/health` — Kubernetes liveness; never authenticated.
//!
//! Failure modes:
//! * No `JwksState` extension → pass-through (single-tenant dev mode).
//! * Missing / malformed Bearer header → 401.
//! * Signature / expiry invalid → 401.
//! * Claims present but tenant_id missing → 401 (treated as a malformed
//! token; the realm must always issue tenant_id).
use std::sync::Arc; use std::sync::Arc;
use axum::{ use axum::{
extract::Request, extract::Request,
http::Method,
middleware::Next, middleware::Next,
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use compliance_core::{OrgRole, TenantContext, TenantStatus};
use jsonwebtoken::{decode, decode_header, jwk::JwkSet, DecodingKey, Validation}; use jsonwebtoken::{decode, decode_header, jwk::JwkSet, DecodingKey, Validation};
use reqwest::StatusCode; use reqwest::StatusCode;
use serde::Deserialize; use serde::Deserialize;
@@ -36,39 +17,20 @@ pub struct JwksState {
pub jwks_url: String, pub jwks_url: String,
} }
/// Raw shape of the JWT payload — matches the breakpilot-dev realm's
/// protocol-mapper output. Missing fields default to "" / empty so a
/// realm that hasn't been fully wired yet still validates.
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct Claims { struct Claims {
#[allow(dead_code)]
sub: String, sub: String,
#[serde(default)]
name: Option<String>,
#[serde(default)]
preferred_username: Option<String>,
#[serde(default)]
tenant_id: String,
#[serde(default)]
tenant_slug: String,
#[serde(default)]
org_roles: Vec<String>,
#[serde(default)]
products: Vec<String>,
#[serde(default)]
plan: String,
#[serde(default)]
tenant_status: Option<TenantStatus>,
} }
const PUBLIC_ENDPOINTS: &[&str] = &["/api/v1/health"]; const PUBLIC_ENDPOINTS: &[&str] = &["/api/v1/health"];
/// Middleware that validates Bearer JWT tokens against Keycloak's JWKS /// Middleware that validates Bearer JWT tokens against Keycloak's JWKS.
/// and attaches a `TenantContext` extension on success.
/// ///
/// Skips validation for the health endpoint. /// Skips validation for health check endpoints.
/// If `JwksState` is not present (Keycloak not configured), requests /// If `JwksState` is not present as an extension (keycloak not configured),
/// pass through and downstream code must handle the missing context. /// all requests pass through.
pub async fn require_jwt_auth(mut request: Request, next: Next) -> Response { pub async fn require_jwt_auth(request: Request, next: Next) -> Response {
let path = request.uri().path(); let path = request.uri().path();
if PUBLIC_ENDPOINTS.contains(&path) { if PUBLIC_ENDPOINTS.contains(&path) {
@@ -91,10 +53,7 @@ pub async fn require_jwt_auth(mut request: Request, next: Next) -> Response {
}; };
match validate_token(token, &jwks_state).await { match validate_token(token, &jwks_state).await {
Ok(ctx) => { Ok(()) => next.run(request).await,
request.extensions_mut().insert(ctx);
next.run(request).await
}
Err(e) => { Err(e) => {
tracing::warn!("JWT validation failed: {e}"); tracing::warn!("JWT validation failed: {e}");
(StatusCode::UNAUTHORIZED, "Invalid token").into_response() (StatusCode::UNAUTHORIZED, "Invalid token").into_response()
@@ -102,47 +61,7 @@ pub async fn require_jwt_auth(mut request: Request, next: Next) -> Response {
} }
} }
/// Middleware that enforces the M7.1 `tenant_status` contract. async fn validate_token(token: &str, state: &JwksState) -> Result<(), String> {
///
/// * `Active` / `Trial` / `Demo` — pass through.
/// * `Frozen` — read-only after cancel / non-payment. Writes return 402.
/// * `Archived` — data-retention window closed. Every request returns 410.
///
/// Pass-through when no `TenantContext` is present (single-tenant dev or
/// the upstream JWT middleware ran without `JwksState`).
pub async fn require_tenant_status(request: Request, next: Next) -> Response {
let ctx = match request.extensions().get::<TenantContext>() {
Some(c) => c.clone(),
None => return next.run(request).await,
};
if ctx.status.is_archived() {
return (
StatusCode::GONE,
"Tenant archived — data retention window closed",
)
.into_response();
}
if ctx.status.is_frozen() && is_write(request.method()) {
return (
StatusCode::PAYMENT_REQUIRED,
"Tenant frozen — read-only. Re-activate to resume writes.",
)
.into_response();
}
next.run(request).await
}
/// Treat anything other than GET/HEAD/OPTIONS as a write. Good enough for
/// REST. The few exceptions (e.g. read-side POSTs) can opt out at the
/// handler level once we have them.
fn is_write(m: &Method) -> bool {
!matches!(m, &Method::GET | &Method::HEAD | &Method::OPTIONS)
}
async fn validate_token(token: &str, state: &JwksState) -> Result<TenantContext, String> {
let header = decode_header(token).map_err(|e| format!("failed to decode JWT header: {e}"))?; let header = decode_header(token).map_err(|e| format!("failed to decode JWT header: {e}"))?;
let kid = header let kid = header
@@ -164,37 +83,10 @@ async fn validate_token(token: &str, state: &JwksState) -> Result<TenantContext,
validation.validate_exp = true; validation.validate_exp = true;
validation.validate_aud = false; validation.validate_aud = false;
let data = decode::<Claims>(token, &decoding_key, &validation) decode::<Claims>(token, &decoding_key, &validation)
.map_err(|e| format!("token validation failed: {e}"))?; .map_err(|e| format!("token validation failed: {e}"))?;
claims_to_context(data.claims) Ok(())
}
/// Map the decoded JWT payload into the platform-wide `TenantContext`.
/// Pulled out for unit testing — no I/O.
fn claims_to_context(c: Claims) -> Result<TenantContext, String> {
if c.tenant_id.is_empty() {
return Err("JWT is missing tenant_id claim".to_string());
}
let status = c.tenant_status.unwrap_or_else(|| {
tracing::warn!(
"JWT missing tenant_status claim for tenant {} — defaulting to Trial",
c.tenant_id
);
TenantStatus::Trial
});
Ok(TenantContext {
tenant_id: c.tenant_id,
tenant_slug: c.tenant_slug,
org_roles: c.org_roles.iter().map(|r| OrgRole::parse(r)).collect(),
products: c.products,
plan: c.plan,
status,
user_id: c.sub,
user_name: c.name.or(c.preferred_username),
})
} }
async fn fetch_or_get_jwks(state: &JwksState) -> Result<JwkSet, String> { async fn fetch_or_get_jwks(state: &JwksState) -> Result<JwkSet, String> {
@@ -219,87 +111,3 @@ async fn fetch_or_get_jwks(state: &JwksState) -> Result<JwkSet, String> {
Ok(jwks) Ok(jwks)
} }
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
use super::*;
fn base_claims() -> Claims {
Claims {
sub: "user-123".to_string(),
name: Some("Alice Acme".to_string()),
preferred_username: None,
tenant_id: "00000000-0000-0000-0000-000000000001".to_string(),
tenant_slug: "acme".to_string(),
org_roles: vec!["IT_ADMIN".to_string()],
products: vec!["compliance".to_string()],
plan: "professional".to_string(),
tenant_status: Some(TenantStatus::Active),
}
}
#[test]
fn claims_to_context_happy_path() {
let ctx = claims_to_context(base_claims()).expect("should map");
assert_eq!(ctx.tenant_id, "00000000-0000-0000-0000-000000000001");
assert_eq!(ctx.tenant_slug, "acme");
assert_eq!(ctx.org_roles, vec![OrgRole::ItAdmin]);
assert_eq!(ctx.products, vec!["compliance"]);
assert_eq!(ctx.plan, "professional");
assert_eq!(ctx.status, TenantStatus::Active);
assert_eq!(ctx.user_id, "user-123");
assert_eq!(ctx.user_name.as_deref(), Some("Alice Acme"));
}
#[test]
fn claims_to_context_rejects_missing_tenant_id() {
let mut c = base_claims();
c.tenant_id = "".to_string();
let err = claims_to_context(c).expect_err("should reject");
assert!(err.contains("tenant_id"));
}
#[test]
fn claims_to_context_defaults_status_when_missing() {
let mut c = base_claims();
c.tenant_status = None;
let ctx = claims_to_context(c).expect("should map");
assert_eq!(ctx.status, TenantStatus::Trial);
}
#[test]
fn claims_to_context_falls_back_to_preferred_username() {
let mut c = base_claims();
c.name = None;
c.preferred_username = Some("alice@acme.dev".to_string());
let ctx = claims_to_context(c).expect("should map");
assert_eq!(ctx.user_name.as_deref(), Some("alice@acme.dev"));
}
#[test]
fn claims_to_context_parses_multiple_roles() {
let mut c = base_claims();
c.org_roles = vec![
"IT_ADMIN".to_string(),
"CXO".to_string(),
"GARBAGE".to_string(),
];
let ctx = claims_to_context(c).expect("should map");
assert_eq!(
ctx.org_roles,
vec![OrgRole::ItAdmin, OrgRole::Cxo, OrgRole::Unknown]
);
}
#[test]
fn is_write_detects_methods() {
assert!(!is_write(&Method::GET));
assert!(!is_write(&Method::HEAD));
assert!(!is_write(&Method::OPTIONS));
assert!(is_write(&Method::POST));
assert!(is_write(&Method::PUT));
assert!(is_write(&Method::PATCH));
assert!(is_write(&Method::DELETE));
}
}
-2
View File
@@ -2,7 +2,5 @@ pub mod auth_middleware;
pub mod handlers; pub mod handlers;
pub mod routes; pub mod routes;
pub mod server; pub mod server;
pub mod tenant_ctx;
pub use server::start_api_server; pub use server::start_api_server;
pub use tenant_ctx::TenantCtx;
+3 -8
View File
@@ -8,7 +8,7 @@ use tower_http::set_header::SetResponseHeaderLayer;
use tower_http::trace::TraceLayer; use tower_http::trace::TraceLayer;
use crate::agent::ComplianceAgent; use crate::agent::ComplianceAgent;
use crate::api::auth_middleware::{require_jwt_auth, require_tenant_status, JwksState}; use crate::api::auth_middleware::{require_jwt_auth, JwksState};
use crate::api::routes; use crate::api::routes;
use crate::error::AgentError; use crate::error::AgentError;
@@ -44,14 +44,9 @@ pub async fn start_api_server(agent: ComplianceAgent, port: u16) -> Result<(), A
jwks_url, jwks_url,
}; };
tracing::info!("Keycloak JWT auth enabled for realm '{kc_realm}'"); tracing::info!("Keycloak JWT auth enabled for realm '{kc_realm}'");
// Layers execute outermost-first. The Extension must run before
// require_jwt_auth so that middleware can read JwksState from
// request extensions, and the status gate must run after the
// JWT auth so TenantContext is in extensions.
app = app app = app
.layer(middleware::from_fn(require_tenant_status)) .layer(Extension(jwks_state))
.layer(middleware::from_fn(require_jwt_auth)) .layer(middleware::from_fn(require_jwt_auth));
.layer(Extension(jwks_state));
} else { } else {
tracing::warn!("Keycloak not configured - API endpoints are unprotected"); tracing::warn!("Keycloak not configured - API endpoints are unprotected");
} }
@@ -1,123 +0,0 @@
//! M7.1 — integration tests for `require_tenant_status`.
//!
//! Exercises the middleware end-to-end through an Axum router so we
//! catch wiring bugs (extension propagation, method matching) that pure
//! unit tests would miss.
#![allow(clippy::expect_used, clippy::unwrap_used)]
use axum::{
body::Body,
extract::Request,
http::{Method, StatusCode},
middleware::{from_fn, Next},
response::Response,
routing::{get, post},
Router,
};
use compliance_agent::api::auth_middleware::require_tenant_status;
use compliance_core::{TenantContext, TenantStatus};
use tower::ServiceExt;
fn ctx_with(status: TenantStatus) -> TenantContext {
TenantContext {
tenant_id: "t-1".to_string(),
tenant_slug: "acme".to_string(),
org_roles: vec![],
products: vec![],
plan: "starter".to_string(),
status,
user_id: "u-1".to_string(),
user_name: None,
}
}
fn router_with_ctx(ctx: Option<TenantContext>) -> Router {
let injector = move |mut req: Request, next: Next| {
let ctx = ctx.clone();
async move {
if let Some(c) = ctx {
req.extensions_mut().insert(c);
}
next.run(req).await
}
};
Router::new()
.route("/r", get(|| async { "read" }))
.route("/w", post(|| async { "write" }))
.layer(from_fn(require_tenant_status))
.layer(from_fn(injector))
}
async fn call(router: Router, method: Method, path: &str) -> Response {
let req = Request::builder()
.method(method)
.uri(path)
.body(Body::empty())
.expect("request build");
router.oneshot(req).await.expect("oneshot")
}
#[tokio::test]
async fn active_tenant_can_read_and_write() {
let r = router_with_ctx(Some(ctx_with(TenantStatus::Active)));
assert_eq!(
call(r.clone(), Method::GET, "/r").await.status(),
StatusCode::OK
);
assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::OK);
}
#[tokio::test]
async fn trial_tenant_can_read_and_write() {
let r = router_with_ctx(Some(ctx_with(TenantStatus::Trial)));
assert_eq!(
call(r.clone(), Method::GET, "/r").await.status(),
StatusCode::OK
);
assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::OK);
}
#[tokio::test]
async fn demo_tenant_can_read_and_write() {
let r = router_with_ctx(Some(ctx_with(TenantStatus::Demo)));
assert_eq!(
call(r.clone(), Method::GET, "/r").await.status(),
StatusCode::OK
);
assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::OK);
}
#[tokio::test]
async fn frozen_tenant_can_read_but_not_write() {
let r = router_with_ctx(Some(ctx_with(TenantStatus::Frozen)));
assert_eq!(
call(r.clone(), Method::GET, "/r").await.status(),
StatusCode::OK
);
assert_eq!(
call(r, Method::POST, "/w").await.status(),
StatusCode::PAYMENT_REQUIRED
);
}
#[tokio::test]
async fn archived_tenant_is_gone_on_every_method() {
let r = router_with_ctx(Some(ctx_with(TenantStatus::Archived)));
assert_eq!(
call(r.clone(), Method::GET, "/r").await.status(),
StatusCode::GONE
);
assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::GONE);
}
#[tokio::test]
async fn no_context_passes_through() {
let r = router_with_ctx(None);
assert_eq!(
call(r.clone(), Method::GET, "/r").await.status(),
StatusCode::OK
);
assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::OK);
}
+13
View File
@@ -18,6 +18,15 @@ telemetry = [
"dep:tracing-subscriber", "dep:tracing-subscriber",
"dep:tracing", "dep:tracing",
] ]
# Pulls in the M7.1 Axum middleware + extractor. Consumers that don't
# embed an HTTP server (e.g. the wasm dashboard frontend) leave it off.
axum = [
"dep:axum",
"dep:jsonwebtoken",
"dep:reqwest",
"dep:tokio",
"dep:tracing",
]
[dependencies] [dependencies]
serde = { workspace = true } serde = { workspace = true }
@@ -37,3 +46,7 @@ opentelemetry-appender-tracing = { version = "0.29", optional = true }
tracing-opentelemetry = { version = "0.30", optional = true } tracing-opentelemetry = { version = "0.30", optional = true }
tracing-subscriber = { workspace = true, optional = true } tracing-subscriber = { workspace = true, optional = true }
tracing = { workspace = true, optional = true } tracing = { workspace = true, optional = true }
axum = { version = "0.8", optional = true }
jsonwebtoken = { version = "9", optional = true }
reqwest = { workspace = true, optional = true }
tokio = { workspace = true, optional = true }
+390
View File
@@ -0,0 +1,390 @@
//! M7.1 — JWT validation + tenant context propagation.
//!
//! `require_jwt_auth` validates a Bearer JWT against Keycloak's JWKS and
//! attaches a [`TenantContext`] to the request extensions. Downstream
//! middleware ([`require_tenant_status`]) and Axum extractors
//! ([`crate::tenant_ctx::TenantCtx`]) read it from there.
//!
//! Skipped paths:
//! * `/api/v1/health` — Kubernetes liveness; never authenticated.
//!
//! Failure modes:
//! * No `JwksState` extension → pass-through (single-tenant dev mode).
//! * Missing / malformed Bearer header → 401.
//! * Signature / expiry invalid → 401.
//! * Claims present but tenant_id missing → 401 (treated as a malformed
//! token; the realm must always issue tenant_id).
use std::sync::Arc;
use axum::{
extract::Request,
http::Method,
middleware::Next,
response::{IntoResponse, Response},
};
use jsonwebtoken::{decode, decode_header, jwk::JwkSet, DecodingKey, Validation};
use reqwest::StatusCode;
use serde::Deserialize;
use tokio::sync::RwLock;
use crate::{OrgRole, TenantContext, TenantStatus};
/// Cached JWKS from Keycloak for token validation.
#[derive(Clone)]
pub struct JwksState {
pub jwks: Arc<RwLock<Option<JwkSet>>>,
pub jwks_url: String,
}
/// Raw shape of the JWT payload — matches the breakpilot-dev realm's
/// protocol-mapper output. Missing fields default to "" / empty so a
/// realm that hasn't been fully wired yet still validates.
#[derive(Debug, Deserialize)]
struct Claims {
sub: String,
#[serde(default)]
name: Option<String>,
#[serde(default)]
preferred_username: Option<String>,
#[serde(default)]
tenant_id: String,
#[serde(default)]
tenant_slug: String,
#[serde(default)]
org_roles: Vec<String>,
#[serde(default)]
products: Vec<String>,
#[serde(default)]
plan: String,
#[serde(default)]
tenant_status: Option<TenantStatus>,
}
const PUBLIC_ENDPOINTS: &[&str] = &["/api/v1/health"];
/// Middleware that validates Bearer JWT tokens against Keycloak's JWKS
/// and attaches a `TenantContext` extension on success.
///
/// Skips validation for the health endpoint.
/// If `JwksState` is not present (Keycloak not configured), requests
/// pass through and downstream code must handle the missing context.
pub async fn require_jwt_auth(mut request: Request, next: Next) -> Response {
let path = request.uri().path();
if PUBLIC_ENDPOINTS.contains(&path) {
return next.run(request).await;
}
let jwks_state = match request.extensions().get::<JwksState>() {
Some(s) => s.clone(),
None => return next.run(request).await,
};
let auth_header = match request.headers().get("authorization") {
Some(h) => h,
None => return (StatusCode::UNAUTHORIZED, "Missing authorization header").into_response(),
};
let token = match auth_header.to_str() {
Ok(s) if s.starts_with("Bearer ") => &s[7..],
_ => return (StatusCode::UNAUTHORIZED, "Invalid authorization header").into_response(),
};
match validate_token(token, &jwks_state).await {
Ok(ctx) => {
request.extensions_mut().insert(ctx);
next.run(request).await
}
Err(e) => {
tracing::warn!("JWT validation failed: {e}");
(StatusCode::UNAUTHORIZED, "Invalid token").into_response()
}
}
}
/// Middleware that enforces the M7.1 `tenant_status` contract.
///
/// * `Active` / `Trial` / `Demo` — pass through.
/// * `Frozen` — read-only after cancel / non-payment. Writes return 402.
/// * `Archived` — data-retention window closed. Every request returns 410.
///
/// Pass-through when no `TenantContext` is present (single-tenant dev or
/// the upstream JWT middleware ran without `JwksState`).
pub async fn require_tenant_status(request: Request, next: Next) -> Response {
let ctx = match request.extensions().get::<TenantContext>() {
Some(c) => c.clone(),
None => return next.run(request).await,
};
if ctx.status.is_archived() {
return (
StatusCode::GONE,
"Tenant archived — data retention window closed",
)
.into_response();
}
if ctx.status.is_frozen() && is_write(request.method()) {
return (
StatusCode::PAYMENT_REQUIRED,
"Tenant frozen — read-only. Re-activate to resume writes.",
)
.into_response();
}
next.run(request).await
}
/// Treat anything other than GET/HEAD/OPTIONS as a write. Good enough for
/// REST. The few exceptions (e.g. read-side POSTs) can opt out at the
/// handler level once we have them.
fn is_write(m: &Method) -> bool {
!matches!(m, &Method::GET | &Method::HEAD | &Method::OPTIONS)
}
async fn validate_token(token: &str, state: &JwksState) -> Result<TenantContext, String> {
let header = decode_header(token).map_err(|e| format!("failed to decode JWT header: {e}"))?;
let kid = header
.kid
.clone()
.ok_or_else(|| "JWT missing kid header".to_string())?;
// First try against whatever's currently cached. If the kid isn't
// there or the signature doesn't verify, the cached JWKS is most
// likely stale (KC rotated keys) — refresh once and retry before
// giving up. Without this every key rotation produces a silent 401
// storm that only goes away when the agent restarts.
let jwks = fetch_or_get_jwks(state, false).await?;
match try_validate(token, &header, &kid, &jwks) {
Ok(ctx) => Ok(ctx),
Err(ValidationError::Permanent(e)) => Err(e),
Err(ValidationError::Stale(reason)) => {
tracing::info!(
kid = %kid,
reason = %reason,
"JWKS appears stale — forcing refresh and retrying"
);
let jwks = fetch_or_get_jwks(state, true).await?;
try_validate(token, &header, &kid, &jwks).map_err(|e| match e {
ValidationError::Stale(s) | ValidationError::Permanent(s) => s,
})
}
}
}
#[derive(Debug)]
enum ValidationError {
/// Refresh-eligible: cached JWKS may be stale.
Stale(String),
/// Refusing the token regardless of JWKS freshness.
Permanent(String),
}
fn try_validate(
token: &str,
header: &jsonwebtoken::Header,
kid: &str,
jwks: &JwkSet,
) -> Result<TenantContext, ValidationError> {
let jwk = match jwks
.keys
.iter()
.find(|k| k.common.key_id.as_deref() == Some(kid))
{
Some(j) => j,
None => {
return Err(ValidationError::Stale(
"no matching key found in JWKS".to_string(),
))
}
};
let decoding_key = DecodingKey::from_jwk(jwk)
.map_err(|e| ValidationError::Permanent(format!("failed to create decoding key: {e}")))?;
let mut validation = Validation::new(header.alg);
validation.validate_exp = true;
validation.validate_aud = false;
let data = match decode::<Claims>(token, &decoding_key, &validation) {
Ok(d) => d,
Err(e) => {
// Signature mismatch is the other refresh-eligible failure:
// the matching kid is present but the key bytes don't match.
// Everything else (expired, malformed, etc.) is permanent.
return Err(
if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::InvalidSignature) {
ValidationError::Stale(format!("token validation failed: {e}"))
} else {
ValidationError::Permanent(format!("token validation failed: {e}"))
},
);
}
};
claims_to_context(data.claims).map_err(ValidationError::Permanent)
}
/// Map the decoded JWT payload into the platform-wide `TenantContext`.
/// Pulled out for unit testing — no I/O.
fn claims_to_context(c: Claims) -> Result<TenantContext, String> {
if c.tenant_id.is_empty() {
return Err("JWT is missing tenant_id claim".to_string());
}
let status = c.tenant_status.unwrap_or_else(|| {
tracing::warn!(
"JWT missing tenant_status claim for tenant {} — defaulting to Trial",
c.tenant_id
);
TenantStatus::Trial
});
Ok(TenantContext {
tenant_id: c.tenant_id,
tenant_slug: c.tenant_slug,
org_roles: c.org_roles.iter().map(|r| OrgRole::parse(r)).collect(),
products: c.products,
plan: c.plan,
status,
user_id: c.sub,
user_name: c.name.or(c.preferred_username),
})
}
async fn fetch_or_get_jwks(state: &JwksState, force: bool) -> Result<JwkSet, String> {
if !force {
let cached = state.jwks.read().await;
if let Some(ref jwks) = *cached {
return Ok(jwks.clone());
}
}
// Hold the write lock across the fetch so concurrent refreshers
// don't all hammer Keycloak when keys rotate. If another writer
// already populated a fresh JWKS while we were waiting (and we
// weren't asked to force), use theirs.
let mut cached = state.jwks.write().await;
if !force {
if let Some(ref jwks) = *cached {
return Ok(jwks.clone());
}
}
let resp = reqwest::get(&state.jwks_url)
.await
.map_err(|e| format!("failed to fetch JWKS: {e}"))?;
let jwks: JwkSet = resp
.json()
.await
.map_err(|e| format!("failed to parse JWKS: {e}"))?;
*cached = Some(jwks.clone());
Ok(jwks)
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
use super::*;
fn base_claims() -> Claims {
Claims {
sub: "user-123".to_string(),
name: Some("Alice Acme".to_string()),
preferred_username: None,
tenant_id: "00000000-0000-0000-0000-000000000001".to_string(),
tenant_slug: "acme".to_string(),
org_roles: vec!["IT_ADMIN".to_string()],
products: vec!["compliance".to_string()],
plan: "professional".to_string(),
tenant_status: Some(TenantStatus::Active),
}
}
#[test]
fn claims_to_context_happy_path() {
let ctx = claims_to_context(base_claims()).expect("should map");
assert_eq!(ctx.tenant_id, "00000000-0000-0000-0000-000000000001");
assert_eq!(ctx.tenant_slug, "acme");
assert_eq!(ctx.org_roles, vec![OrgRole::ItAdmin]);
assert_eq!(ctx.products, vec!["compliance"]);
assert_eq!(ctx.plan, "professional");
assert_eq!(ctx.status, TenantStatus::Active);
assert_eq!(ctx.user_id, "user-123");
assert_eq!(ctx.user_name.as_deref(), Some("Alice Acme"));
}
#[test]
fn claims_to_context_rejects_missing_tenant_id() {
let mut c = base_claims();
c.tenant_id = "".to_string();
let err = claims_to_context(c).expect_err("should reject");
assert!(err.contains("tenant_id"));
}
#[test]
fn claims_to_context_defaults_status_when_missing() {
let mut c = base_claims();
c.tenant_status = None;
let ctx = claims_to_context(c).expect("should map");
assert_eq!(ctx.status, TenantStatus::Trial);
}
#[test]
fn claims_to_context_falls_back_to_preferred_username() {
let mut c = base_claims();
c.name = None;
c.preferred_username = Some("alice@acme.dev".to_string());
let ctx = claims_to_context(c).expect("should map");
assert_eq!(ctx.user_name.as_deref(), Some("alice@acme.dev"));
}
#[test]
fn claims_to_context_parses_multiple_roles() {
let mut c = base_claims();
c.org_roles = vec![
"IT_ADMIN".to_string(),
"CXO".to_string(),
"GARBAGE".to_string(),
];
let ctx = claims_to_context(c).expect("should map");
assert_eq!(
ctx.org_roles,
vec![OrgRole::ItAdmin, OrgRole::Cxo, OrgRole::Unknown]
);
}
#[test]
fn try_validate_returns_stale_when_kid_missing_from_jwks() {
// Empty JWKS — the kid we ask for can't possibly match. The error
// must classify as Stale so the caller refreshes JWKS and retries.
let jwks = JwkSet { keys: vec![] };
let header = jsonwebtoken::Header {
alg: jsonwebtoken::Algorithm::RS256,
kid: Some("kid-rotated-out".to_string()),
..Default::default()
};
let err = try_validate("ignored.token.value", &header, "kid-rotated-out", &jwks)
.expect_err("should fail");
match err {
ValidationError::Stale(s) => assert!(s.contains("no matching key")),
ValidationError::Permanent(s) => panic!("must be Stale, got Permanent: {s}"),
}
}
#[test]
fn is_write_detects_methods() {
assert!(!is_write(&Method::GET));
assert!(!is_write(&Method::HEAD));
assert!(!is_write(&Method::OPTIONS));
assert!(is_write(&Method::POST));
assert!(is_write(&Method::PUT));
assert!(is_write(&Method::PATCH));
assert!(is_write(&Method::DELETE));
}
}
+5
View File
@@ -7,6 +7,11 @@ pub mod telemetry;
pub mod tenant; pub mod tenant;
pub mod traits; pub mod traits;
#[cfg(feature = "axum")]
pub mod auth;
#[cfg(feature = "axum")]
pub mod tenant_ctx;
pub use config::{AgentConfig, DashboardConfig}; pub use config::{AgentConfig, DashboardConfig};
pub use error::CoreError; pub use error::CoreError;
pub use tenant::{OrgRole, TenantContext, TenantStatus}; pub use tenant::{OrgRole, TenantContext, TenantStatus};
+6 -6
View File
@@ -1,9 +1,9 @@
//! Tenant context propagated through every authenticated request. //! Tenant context propagated through every authenticated request.
//! //!
//! This module is the M7.1 single source of truth for "who is this request //! M7.1 single source of truth for "who is this request for". Claims come
//! for". Claims come from a Keycloak-issued JWT and land here via //! from a Keycloak-issued JWT and land here via [`crate::auth::require_jwt_auth`]
//! `compliance-agent`'s `require_jwt_auth` middleware. Handlers reach into //! (enabled with the `axum` feature). Handlers reach into the request
//! the request extensions with the `TenantCtx` Axum extractor. //! extensions with the [`crate::tenant_ctx::TenantCtx`] extractor.
//! //!
//! The shape mirrors the JWT claim names the breakpilot-platform realm //! The shape mirrors the JWT claim names the breakpilot-platform realm
//! emits (see `platform/orca-platform/dev/keycloak/realm-export.json`). //! emits (see `platform/orca-platform/dev/keycloak/realm-export.json`).
@@ -87,8 +87,8 @@ impl OrgRole {
} }
} }
/// Everything `compliance-agent` knows about the requesting tenant at the /// Everything we know about the requesting tenant at the moment a request
/// moment a request lands. Cheap to clone (every field is owned + small). /// lands. Cheap to clone (every field is owned + small).
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TenantContext { pub struct TenantContext {
/// `tenants.id` from the platform's tenant-registry (UUID). /// `tenants.id` from the platform's tenant-registry (UUID).
@@ -9,17 +9,18 @@
//! } //! }
//! ``` //! ```
//! //!
//! The middleware (`require_jwt_auth`) is responsible for inserting the //! The middleware ([`crate::auth::require_jwt_auth`]) is responsible for
//! context into the request extensions. If it's missing on a route that //! inserting the context into the request extensions. If it's missing on
//! uses this extractor, that's a bug in the wiring — we return 401 so the //! a route that uses this extractor, that's a bug in the wiring — we
//! caller sees an auth failure rather than a 500. //! return 401 so the caller sees an auth failure rather than a 500.
use axum::{ use axum::{
extract::FromRequestParts, extract::FromRequestParts,
http::{request::Parts, StatusCode}, http::{request::Parts, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use compliance_core::TenantContext;
use crate::TenantContext;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct TenantCtx(pub TenantContext); pub struct TenantCtx(pub TenantContext);
@@ -57,8 +58,8 @@ where
#[allow(clippy::expect_used, clippy::unwrap_used)] #[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests { mod tests {
use super::*; use super::*;
use crate::TenantStatus;
use axum::http::Request; use axum::http::Request;
use compliance_core::TenantStatus;
fn ctx() -> TenantContext { fn ctx() -> TenantContext {
TenantContext { TenantContext {
+22
View File
@@ -0,0 +1,22 @@
[package]
name = "compliance-smoke"
version = "0.1.0"
edition = "2021"
description = "Tiny Axum service exercising compliance-core M7.1 tenant gating. Run smoke.sh against it before merging anything that touches the auth/tenant path."
[lints]
workspace = true
[[bin]]
name = "compliance-smoke"
path = "src/main.rs"
[dependencies]
compliance-core = { workspace = true, features = ["axum"] }
axum = "0.8"
tokio = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
reqwest = { workspace = true }
+111
View File
@@ -0,0 +1,111 @@
//! M7.1 smoke service.
//!
//! A standalone Axum binary whose only job is to host the
//! [`compliance_core::auth`] middleware + [`compliance_core::tenant_ctx`]
//! extractor on three endpoints, so `scripts/smoke.sh` can prove the
//! tenant-gating contract end-to-end before any auth-path PR merges.
//!
//! Endpoints:
//! * `GET /api/v1/health` — public, never authenticated.
//! * `GET /api/v1/echo` — protected read; returns the [`TenantContext`].
//! * `POST /api/v1/echo` — protected write; exercises the `Frozen → 402`
//! gate on the same handler.
//!
//! Configuration (env):
//! * `KEYCLOAK_URL` — e.g. `http://localhost:8080`. Required.
//! * `KEYCLOAK_REALM` — e.g. `certifai`. Required.
//! * `SMOKE_PORT` — defaults to `3010`.
use std::sync::Arc;
use axum::{middleware, routing::get, Extension, Json, Router};
use compliance_core::{
auth::{require_jwt_auth, require_tenant_status, JwksState},
tenant_ctx::TenantCtx,
};
use serde::Serialize;
use tokio::sync::RwLock;
#[derive(Serialize)]
struct EchoResponse {
method: &'static str,
tenant_id: String,
tenant_slug: String,
plan: String,
status: String,
products: Vec<String>,
org_roles: Vec<String>,
user_id: String,
user_name: Option<String>,
}
async fn health() -> Json<serde_json::Value> {
Json(serde_json::json!({ "ok": true }))
}
async fn echo_read(TenantCtx(ctx): TenantCtx) -> Json<EchoResponse> {
Json(echo(ctx, "GET"))
}
async fn echo_write(TenantCtx(ctx): TenantCtx) -> Json<EchoResponse> {
Json(echo(ctx, "POST"))
}
fn echo(ctx: compliance_core::TenantContext, method: &'static str) -> EchoResponse {
EchoResponse {
method,
tenant_id: ctx.tenant_id,
tenant_slug: ctx.tenant_slug,
plan: ctx.plan,
status: ctx.status.to_string(),
products: ctx.products,
org_roles: ctx.org_roles.iter().map(|r| format!("{r:?}")).collect(),
user_id: ctx.user_id,
user_name: ctx.user_name,
}
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
)
.init();
let kc_url = std::env::var("KEYCLOAK_URL")
.map_err(|_| "KEYCLOAK_URL is required (e.g. http://localhost:8080)")?;
let kc_realm = std::env::var("KEYCLOAK_REALM")
.map_err(|_| "KEYCLOAK_REALM is required (e.g. certifai)")?;
let port: u16 = std::env::var("SMOKE_PORT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(3010);
let jwks_url = format!("{kc_url}/realms/{kc_realm}/protocol/openid-connect/certs");
let jwks_state = JwksState {
jwks: Arc::new(RwLock::new(None)),
jwks_url: jwks_url.clone(),
};
// Layers execute outermost-first. The Extension must be registered
// before `require_jwt_auth` so the middleware can read JwksState; the
// status gate must run after JWT so `TenantContext` is in extensions.
let app = Router::new()
.route("/api/v1/health", get(health))
.route("/api/v1/echo", get(echo_read).post(echo_write))
.layer(middleware::from_fn(require_tenant_status))
.layer(middleware::from_fn(require_jwt_auth))
.layer(Extension(jwks_state));
let addr = format!("0.0.0.0:{port}");
let listener = tokio::net::TcpListener::bind(&addr).await?;
tracing::info!(
port,
jwks = %jwks_url,
"compliance-smoke listening — try `scripts/smoke.sh`"
);
axum::serve(listener, app).await?;
Ok(())
}
+136
View File
@@ -0,0 +1,136 @@
#!/usr/bin/env bash
# M7.1 tenant-gating smoke test.
#
# Drives compliance-smoke against a live Keycloak realm with five test
# users (one per tenant_status), asserts the response code on each
# endpoint, and exits non-zero on any mismatch.
#
# Pre-reqs (one-time):
# * KC up at $KC_URL with realm $KC_REALM
# * Client $KC_CLIENT has direct-access-grants enabled
# * Users + tenant_status mappers per certifai/keycloak/realm-export.json
# * compliance-smoke binary running and reachable at $SMOKE_URL
#
# Usage:
# scripts/smoke.sh # uses defaults below
# SMOKE_URL=... scripts/smoke.sh
set -euo pipefail
KC_URL="${KC_URL:-http://localhost:8080}"
KC_REALM="${KC_REALM:-certifai}"
KC_CLIENT="${KC_CLIENT:-certifai-dashboard}"
SMOKE_URL="${SMOKE_URL:-http://localhost:3010}"
readonly TOKEN_ENDPOINT="${KC_URL}/realms/${KC_REALM}/protocol/openid-connect/token"
PASS=0
FAIL=0
red() { printf '\033[31m%s\033[0m' "$*"; }
green() { printf '\033[32m%s\033[0m' "$*"; }
yellow() { printf '\033[33m%s\033[0m' "$*"; }
# Fetches an access token via direct access grant. Echoes the raw token.
get_token() {
local user="$1" pass="$2"
curl -sS -X POST "$TOKEN_ENDPOINT" \
-H 'Content-Type: application/x-www-form-urlencoded' \
-d "grant_type=password" \
-d "client_id=${KC_CLIENT}" \
-d "username=${user}" \
-d "password=${pass}" \
-d "scope=openid" \
| sed -n 's/.*"access_token":"\([^"]*\)".*/\1/p'
}
# Hits SMOKE_URL$path with the given method and (optional) bearer token,
# asserts the response status code matches $want.
assert_status() {
local label="$1" method="$2" path="$3" want="$4" token="${5:-}"
local args=(-sS -o /dev/null -w '%{http_code}' -X "$method" "${SMOKE_URL}${path}")
if [[ -n "$token" ]]; then
args+=(-H "Authorization: Bearer ${token}")
fi
local got
got=$(curl "${args[@]}")
if [[ "$got" == "$want" ]]; then
printf ' %s %s %-4s %-15s → %s\n' "$(green PASS)" "$label" "$method" "$path" "$got"
PASS=$((PASS + 1))
else
printf ' %s %s %-4s %-15s → got %s, want %s\n' "$(red FAIL)" "$label" "$method" "$path" "$got" "$want"
FAIL=$((FAIL + 1))
fi
}
header() {
printf '\n%s %s\n' "$(yellow '##')" "$1"
}
# ---- Pre-flight ----------------------------------------------------------
header "Pre-flight"
if ! curl -sS -o /dev/null -w '%{http_code}\n' "${SMOKE_URL}/api/v1/health" | grep -q '^200$'; then
printf ' %s smoke service not reachable at %s\n' "$(red ERR)" "$SMOKE_URL"
exit 2
fi
if ! curl -sS -o /dev/null -w '%{http_code}\n' "${KC_URL}/realms/${KC_REALM}/.well-known/openid-configuration" | grep -q '^200$'; then
printf ' %s Keycloak realm %s not reachable at %s\n' "$(red ERR)" "$KC_REALM" "$KC_URL"
exit 2
fi
printf ' %s smoke service + Keycloak both up\n' "$(green OK)"
# ---- Public endpoint --------------------------------------------------
header "Public endpoint (no auth required)"
assert_status anon GET /api/v1/health 200
# ---- Anonymous access to protected endpoints ----------------------------
header "Anonymous → 401 on protected endpoints"
assert_status anon GET /api/v1/echo 401
assert_status anon POST /api/v1/echo 401
# ---- Bad token ----------------------------------------------------------
header "Bad token → 401"
assert_status bogus GET /api/v1/echo 401 "not-a-real-jwt"
assert_status bogus POST /api/v1/echo 401 "not-a-real-jwt"
# ---- Active tenant (admin user) -----------------------------------------
header "admin@certifai.local (active) → full access"
TOKEN=$(get_token admin@certifai.local admin)
if [[ -z "$TOKEN" ]]; then
printf ' %s failed to fetch token for admin\n' "$(red ERR)"
exit 2
fi
assert_status active GET /api/v1/echo 200 "$TOKEN"
assert_status active POST /api/v1/echo 200 "$TOKEN"
# ---- Active tenant (USER role) ------------------------------------------
header "user@certifai.local (active) → full access"
TOKEN=$(get_token user@certifai.local user)
assert_status active GET /api/v1/echo 200 "$TOKEN"
assert_status active POST /api/v1/echo 200 "$TOKEN"
# ---- Trial tenant -------------------------------------------------------
header "trial@acme.local (trial) → full access"
TOKEN=$(get_token trial@acme.local trial)
assert_status trial GET /api/v1/echo 200 "$TOKEN"
assert_status trial POST /api/v1/echo 200 "$TOKEN"
# ---- Frozen tenant ------------------------------------------------------
header "frozen@acme.local (frozen) → read-only, writes 402"
TOKEN=$(get_token frozen@acme.local frozen)
assert_status frozen GET /api/v1/echo 200 "$TOKEN"
assert_status frozen POST /api/v1/echo 402 "$TOKEN"
# ---- Archived tenant ----------------------------------------------------
header "archived@acme.local (archived) → 410 everywhere"
TOKEN=$(get_token archived@acme.local archived)
assert_status archived GET /api/v1/echo 410 "$TOKEN"
assert_status archived POST /api/v1/echo 410 "$TOKEN"
# ---- Summary ------------------------------------------------------------
printf '\n'
if [[ "$FAIL" -gt 0 ]]; then
printf '%s %d passed, %d failed\n' "$(red FAIL)" "$PASS" "$FAIL"
exit 1
fi
printf '%s %d/%d assertions passed\n' "$(green PASS)" "$PASS" "$PASS"