Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 079f913024 | |||
| f583d0788c |
Generated
-1
@@ -687,7 +687,6 @@ dependencies = [
|
||||
"tokio-cron-scheduler",
|
||||
"tokio-stream",
|
||||
"tokio-tungstenite 0.26.2",
|
||||
"tower",
|
||||
"tower-http",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
|
||||
@@ -7,7 +7,7 @@ edition = "2021"
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
compliance-core = { workspace = true, features = ["mongodb", "telemetry", "axum"] }
|
||||
compliance-core = { workspace = true, features = ["mongodb", "telemetry"] }
|
||||
compliance-graph = { path = "../compliance-graph" }
|
||||
compliance-dast = { path = "../compliance-dast" }
|
||||
serde = { workspace = true }
|
||||
@@ -44,8 +44,7 @@ dashmap = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
compliance-core = { workspace = true, features = ["mongodb", "axum"] }
|
||||
tower = { version = "0.5", features = ["util"] }
|
||||
compliance-core = { workspace = true, features = ["mongodb"] }
|
||||
reqwest = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
extract::Request,
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use jsonwebtoken::{decode, decode_header, jwk::JwkSet, DecodingKey, Validation};
|
||||
use reqwest::StatusCode;
|
||||
use serde::Deserialize;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Cached JWKS from Keycloak for token validation.
|
||||
#[derive(Clone)]
|
||||
pub struct JwksState {
|
||||
pub jwks: Arc<RwLock<Option<JwkSet>>>,
|
||||
pub jwks_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Claims {
|
||||
#[allow(dead_code)]
|
||||
sub: String,
|
||||
}
|
||||
|
||||
const PUBLIC_ENDPOINTS: &[&str] = &["/api/v1/health"];
|
||||
|
||||
/// Middleware that validates Bearer JWT tokens against Keycloak's JWKS.
|
||||
///
|
||||
/// Skips validation for health check endpoints.
|
||||
/// If `JwksState` is not present as an extension (keycloak not configured),
|
||||
/// all requests pass through.
|
||||
pub async fn require_jwt_auth(request: Request, next: Next) -> Response {
|
||||
let path = request.uri().path();
|
||||
|
||||
if PUBLIC_ENDPOINTS.contains(&path) {
|
||||
return next.run(request).await;
|
||||
}
|
||||
|
||||
let jwks_state = match request.extensions().get::<JwksState>() {
|
||||
Some(s) => s.clone(),
|
||||
None => return next.run(request).await,
|
||||
};
|
||||
|
||||
let auth_header = match request.headers().get("authorization") {
|
||||
Some(h) => h,
|
||||
None => return (StatusCode::UNAUTHORIZED, "Missing authorization header").into_response(),
|
||||
};
|
||||
|
||||
let token = match auth_header.to_str() {
|
||||
Ok(s) if s.starts_with("Bearer ") => &s[7..],
|
||||
_ => return (StatusCode::UNAUTHORIZED, "Invalid authorization header").into_response(),
|
||||
};
|
||||
|
||||
match validate_token(token, &jwks_state).await {
|
||||
Ok(()) => next.run(request).await,
|
||||
Err(e) => {
|
||||
tracing::warn!("JWT validation failed: {e}");
|
||||
(StatusCode::UNAUTHORIZED, "Invalid token").into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn validate_token(token: &str, state: &JwksState) -> Result<(), String> {
|
||||
let header = decode_header(token).map_err(|e| format!("failed to decode JWT header: {e}"))?;
|
||||
|
||||
let kid = header
|
||||
.kid
|
||||
.ok_or_else(|| "JWT missing kid header".to_string())?;
|
||||
|
||||
let jwks = fetch_or_get_jwks(state).await?;
|
||||
|
||||
let jwk = jwks
|
||||
.keys
|
||||
.iter()
|
||||
.find(|k| k.common.key_id.as_deref() == Some(&kid))
|
||||
.ok_or_else(|| "no matching key found in JWKS".to_string())?;
|
||||
|
||||
let decoding_key =
|
||||
DecodingKey::from_jwk(jwk).map_err(|e| format!("failed to create decoding key: {e}"))?;
|
||||
|
||||
let mut validation = Validation::new(header.alg);
|
||||
validation.validate_exp = true;
|
||||
validation.validate_aud = false;
|
||||
|
||||
decode::<Claims>(token, &decoding_key, &validation)
|
||||
.map_err(|e| format!("token validation failed: {e}"))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn fetch_or_get_jwks(state: &JwksState) -> Result<JwkSet, String> {
|
||||
{
|
||||
let cached = state.jwks.read().await;
|
||||
if let Some(ref jwks) = *cached {
|
||||
return Ok(jwks.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let resp = reqwest::get(&state.jwks_url)
|
||||
.await
|
||||
.map_err(|e| format!("failed to fetch JWKS: {e}"))?;
|
||||
|
||||
let jwks: JwkSet = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("failed to parse JWKS: {e}"))?;
|
||||
|
||||
let mut cached = state.jwks.write().await;
|
||||
*cached = Some(jwks.clone());
|
||||
|
||||
Ok(jwks)
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod auth_middleware;
|
||||
pub mod handlers;
|
||||
pub mod routes;
|
||||
pub mod server;
|
||||
|
||||
@@ -7,9 +7,8 @@ use tower_http::cors::CorsLayer;
|
||||
use tower_http::set_header::SetResponseHeaderLayer;
|
||||
use tower_http::trace::TraceLayer;
|
||||
|
||||
use compliance_core::auth::{require_jwt_auth, require_tenant_status, JwksState};
|
||||
|
||||
use crate::agent::ComplianceAgent;
|
||||
use crate::api::auth_middleware::{require_jwt_auth, JwksState};
|
||||
use crate::api::routes;
|
||||
use crate::error::AgentError;
|
||||
|
||||
@@ -45,13 +44,9 @@ pub async fn start_api_server(agent: ComplianceAgent, port: u16) -> Result<(), A
|
||||
jwks_url,
|
||||
};
|
||||
tracing::info!("Keycloak JWT auth enabled for realm '{kc_realm}'");
|
||||
// Layers execute outermost-first. Extension(jwks_state) must run
|
||||
// before require_jwt_auth so the middleware can read it; the
|
||||
// status gate runs after JWT so TenantContext is in extensions.
|
||||
app = app
|
||||
.layer(middleware::from_fn(require_tenant_status))
|
||||
.layer(middleware::from_fn(require_jwt_auth))
|
||||
.layer(Extension(jwks_state));
|
||||
.layer(Extension(jwks_state))
|
||||
.layer(middleware::from_fn(require_jwt_auth));
|
||||
} else {
|
||||
tracing::warn!("Keycloak not configured - API endpoints are unprotected");
|
||||
}
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
//! M7.1 — integration tests for `compliance_core::auth::require_tenant_status`.
|
||||
//!
|
||||
//! Exercises the middleware end-to-end through an Axum router so we
|
||||
//! catch wiring bugs (extension propagation, method matching) that pure
|
||||
//! unit tests would miss.
|
||||
|
||||
#![allow(clippy::expect_used, clippy::unwrap_used)]
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::Request,
|
||||
http::{Method, StatusCode},
|
||||
middleware::{from_fn, Next},
|
||||
response::Response,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use compliance_core::{auth::require_tenant_status, TenantContext, TenantStatus};
|
||||
use tower::ServiceExt;
|
||||
|
||||
fn ctx_with(status: TenantStatus) -> TenantContext {
|
||||
TenantContext {
|
||||
tenant_id: "t-1".to_string(),
|
||||
tenant_slug: "acme".to_string(),
|
||||
org_roles: vec![],
|
||||
products: vec![],
|
||||
plan: "starter".to_string(),
|
||||
status,
|
||||
user_id: "u-1".to_string(),
|
||||
user_name: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn router_with_ctx(ctx: Option<TenantContext>) -> Router {
|
||||
let injector = move |mut req: Request, next: Next| {
|
||||
let ctx = ctx.clone();
|
||||
async move {
|
||||
if let Some(c) = ctx {
|
||||
req.extensions_mut().insert(c);
|
||||
}
|
||||
next.run(req).await
|
||||
}
|
||||
};
|
||||
|
||||
Router::new()
|
||||
.route("/r", get(|| async { "read" }))
|
||||
.route("/w", post(|| async { "write" }))
|
||||
.layer(from_fn(require_tenant_status))
|
||||
.layer(from_fn(injector))
|
||||
}
|
||||
|
||||
async fn call(router: Router, method: Method, path: &str) -> Response {
|
||||
let req = Request::builder()
|
||||
.method(method)
|
||||
.uri(path)
|
||||
.body(Body::empty())
|
||||
.expect("request build");
|
||||
router.oneshot(req).await.expect("oneshot")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn active_tenant_can_read_and_write() {
|
||||
let r = router_with_ctx(Some(ctx_with(TenantStatus::Active)));
|
||||
assert_eq!(
|
||||
call(r.clone(), Method::GET, "/r").await.status(),
|
||||
StatusCode::OK
|
||||
);
|
||||
assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn trial_tenant_can_read_and_write() {
|
||||
let r = router_with_ctx(Some(ctx_with(TenantStatus::Trial)));
|
||||
assert_eq!(
|
||||
call(r.clone(), Method::GET, "/r").await.status(),
|
||||
StatusCode::OK
|
||||
);
|
||||
assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn demo_tenant_can_read_and_write() {
|
||||
let r = router_with_ctx(Some(ctx_with(TenantStatus::Demo)));
|
||||
assert_eq!(
|
||||
call(r.clone(), Method::GET, "/r").await.status(),
|
||||
StatusCode::OK
|
||||
);
|
||||
assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn frozen_tenant_can_read_but_not_write() {
|
||||
let r = router_with_ctx(Some(ctx_with(TenantStatus::Frozen)));
|
||||
assert_eq!(
|
||||
call(r.clone(), Method::GET, "/r").await.status(),
|
||||
StatusCode::OK
|
||||
);
|
||||
assert_eq!(
|
||||
call(r, Method::POST, "/w").await.status(),
|
||||
StatusCode::PAYMENT_REQUIRED
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn archived_tenant_is_gone_on_every_method() {
|
||||
let r = router_with_ctx(Some(ctx_with(TenantStatus::Archived)));
|
||||
assert_eq!(
|
||||
call(r.clone(), Method::GET, "/r").await.status(),
|
||||
StatusCode::GONE
|
||||
);
|
||||
assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::GONE);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn no_context_passes_through() {
|
||||
let r = router_with_ctx(None);
|
||||
assert_eq!(
|
||||
call(r.clone(), Method::GET, "/r").await.status(),
|
||||
StatusCode::OK
|
||||
);
|
||||
assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::OK);
|
||||
}
|
||||
+12
-96
@@ -148,83 +148,27 @@ async fn validate_token(token: &str, state: &JwksState) -> Result<TenantContext,
|
||||
|
||||
let kid = header
|
||||
.kid
|
||||
.clone()
|
||||
.ok_or_else(|| "JWT missing kid header".to_string())?;
|
||||
|
||||
// First try against whatever's currently cached. If the kid isn't
|
||||
// there or the signature doesn't verify, the cached JWKS is most
|
||||
// likely stale (KC rotated keys) — refresh once and retry before
|
||||
// giving up. Without this every key rotation produces a silent 401
|
||||
// storm that only goes away when the agent restarts.
|
||||
let jwks = fetch_or_get_jwks(state, false).await?;
|
||||
match try_validate(token, &header, &kid, &jwks) {
|
||||
Ok(ctx) => Ok(ctx),
|
||||
Err(ValidationError::Permanent(e)) => Err(e),
|
||||
Err(ValidationError::Stale(reason)) => {
|
||||
tracing::info!(
|
||||
kid = %kid,
|
||||
reason = %reason,
|
||||
"JWKS appears stale — forcing refresh and retrying"
|
||||
);
|
||||
let jwks = fetch_or_get_jwks(state, true).await?;
|
||||
try_validate(token, &header, &kid, &jwks).map_err(|e| match e {
|
||||
ValidationError::Stale(s) | ValidationError::Permanent(s) => s,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
let jwks = fetch_or_get_jwks(state).await?;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ValidationError {
|
||||
/// Refresh-eligible: cached JWKS may be stale.
|
||||
Stale(String),
|
||||
/// Refusing the token regardless of JWKS freshness.
|
||||
Permanent(String),
|
||||
}
|
||||
|
||||
fn try_validate(
|
||||
token: &str,
|
||||
header: &jsonwebtoken::Header,
|
||||
kid: &str,
|
||||
jwks: &JwkSet,
|
||||
) -> Result<TenantContext, ValidationError> {
|
||||
let jwk = match jwks
|
||||
let jwk = jwks
|
||||
.keys
|
||||
.iter()
|
||||
.find(|k| k.common.key_id.as_deref() == Some(kid))
|
||||
{
|
||||
Some(j) => j,
|
||||
None => {
|
||||
return Err(ValidationError::Stale(
|
||||
"no matching key found in JWKS".to_string(),
|
||||
))
|
||||
}
|
||||
};
|
||||
.find(|k| k.common.key_id.as_deref() == Some(&kid))
|
||||
.ok_or_else(|| "no matching key found in JWKS".to_string())?;
|
||||
|
||||
let decoding_key = DecodingKey::from_jwk(jwk)
|
||||
.map_err(|e| ValidationError::Permanent(format!("failed to create decoding key: {e}")))?;
|
||||
let decoding_key =
|
||||
DecodingKey::from_jwk(jwk).map_err(|e| format!("failed to create decoding key: {e}"))?;
|
||||
|
||||
let mut validation = Validation::new(header.alg);
|
||||
validation.validate_exp = true;
|
||||
validation.validate_aud = false;
|
||||
|
||||
let data = match decode::<Claims>(token, &decoding_key, &validation) {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
// Signature mismatch is the other refresh-eligible failure:
|
||||
// the matching kid is present but the key bytes don't match.
|
||||
// Everything else (expired, malformed, etc.) is permanent.
|
||||
return Err(
|
||||
if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::InvalidSignature) {
|
||||
ValidationError::Stale(format!("token validation failed: {e}"))
|
||||
} else {
|
||||
ValidationError::Permanent(format!("token validation failed: {e}"))
|
||||
},
|
||||
);
|
||||
}
|
||||
};
|
||||
let data = decode::<Claims>(token, &decoding_key, &validation)
|
||||
.map_err(|e| format!("token validation failed: {e}"))?;
|
||||
|
||||
claims_to_context(data.claims).map_err(ValidationError::Permanent)
|
||||
claims_to_context(data.claims)
|
||||
}
|
||||
|
||||
/// Map the decoded JWT payload into the platform-wide `TenantContext`.
|
||||
@@ -254,25 +198,14 @@ fn claims_to_context(c: Claims) -> Result<TenantContext, String> {
|
||||
})
|
||||
}
|
||||
|
||||
async fn fetch_or_get_jwks(state: &JwksState, force: bool) -> Result<JwkSet, String> {
|
||||
if !force {
|
||||
async fn fetch_or_get_jwks(state: &JwksState) -> Result<JwkSet, String> {
|
||||
{
|
||||
let cached = state.jwks.read().await;
|
||||
if let Some(ref jwks) = *cached {
|
||||
return Ok(jwks.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Hold the write lock across the fetch so concurrent refreshers
|
||||
// don't all hammer Keycloak when keys rotate. If another writer
|
||||
// already populated a fresh JWKS while we were waiting (and we
|
||||
// weren't asked to force), use theirs.
|
||||
let mut cached = state.jwks.write().await;
|
||||
if !force {
|
||||
if let Some(ref jwks) = *cached {
|
||||
return Ok(jwks.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let resp = reqwest::get(&state.jwks_url)
|
||||
.await
|
||||
.map_err(|e| format!("failed to fetch JWKS: {e}"))?;
|
||||
@@ -282,6 +215,7 @@ async fn fetch_or_get_jwks(state: &JwksState, force: bool) -> Result<JwkSet, Str
|
||||
.await
|
||||
.map_err(|e| format!("failed to parse JWKS: {e}"))?;
|
||||
|
||||
let mut cached = state.jwks.write().await;
|
||||
*cached = Some(jwks.clone());
|
||||
|
||||
Ok(jwks)
|
||||
@@ -359,24 +293,6 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn try_validate_returns_stale_when_kid_missing_from_jwks() {
|
||||
// Empty JWKS — the kid we ask for can't possibly match. The error
|
||||
// must classify as Stale so the caller refreshes JWKS and retries.
|
||||
let jwks = JwkSet { keys: vec![] };
|
||||
let header = jsonwebtoken::Header {
|
||||
alg: jsonwebtoken::Algorithm::RS256,
|
||||
kid: Some("kid-rotated-out".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
let err = try_validate("ignored.token.value", &header, "kid-rotated-out", &jwks)
|
||||
.expect_err("should fail");
|
||||
match err {
|
||||
ValidationError::Stale(s) => assert!(s.contains("no matching key")),
|
||||
ValidationError::Permanent(s) => panic!("must be Stale, got Permanent: {s}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_write_detects_methods() {
|
||||
assert!(!is_write(&Method::GET));
|
||||
|
||||
Reference in New Issue
Block a user