From 628f346529f7405f074d990c6e817d1690abc9bc Mon Sep 17 00:00:00 2001
From: Sharang Parnerkar <30073382+mighty840@users.noreply.github.com>
Date: Thu, 18 Jun 2026 11:54:01 +0200
Subject: [PATCH] feat(m7.3): MCP tenant-scoped bearer tokens
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

LLM clients (Claude Desktop, Cursor, ChatGPT) can't run a Keycloak
OIDC flow, so the MCP server can't use JWTs for auth. This PR
introduces opaque static bearer tokens minted per-tenant via new agent
endpoints, validated by the MCP server, and used to route incoming MCP
requests to the caller's per-tenant database.

Until now, the MCP server connected to a single shared MongoDB DB with
no auth and no tenant awareness: every tool (list_findings,
list_sbom_packages, etc.) returned data across all tenants. After M7.2
made the agent per-tenant, MCP was the lone cross-tenant data leak.
This closes it.

Design summary

- Token format: `mcpt_<43 url-safe random chars>` (48 chars total).
  Opaque, never embeds tenant_id, never stored in plaintext.
- Storage: cross-tenant `<db_prefix>__admin.mcp_tokens` collection,
  keyed by SHA-256 hash. Each row carries the tenant_id, name,
  created_by, created_at, last_used_at, revoked flag.
- Agent endpoints (tenant-scoped via TenantCtx):
    POST   /api/v1/mcp-tokens      → mint (returns raw token ONCE)
    GET    /api/v1/mcp-tokens      → list (metadata + 12-char prefix,
                                     never the hash)
    DELETE /api/v1/mcp-tokens/{id} → soft revoke
- MCP middleware: extract `Authorization: Bearer mcpt_...`, sniff the
  prefix, SHA-256 → lookup in admin DB → reject if missing or revoked.
  Updates last_used_at fire-and-forget so it never blocks. Sets
  `tokio::task_local!` TENANT_ID for the inner service call; the rmcp
  tool handlers read it and resolve the per-tenant DB.
- task_local is scoped via TENANT_ID.scope(...) around next.run(req) so
  the rmcp tool handlers downstream see the tenant_id without
  modifying their (macro-generated) signatures.
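
The middleware steps above can be sketched with the standard library
alone. This is a minimal illustration under stated assumptions, not the
code in compliance-mcp/src/auth.rs: `TokenRow`, `authenticate`,
`stand_in_hash` and `base64_nopad_len` are hypothetical names, the
`HashMap` stands in for the admin-DB collection, and `DefaultHasher`
stands in for SHA-256 only so the sketch builds without extra crates.

```rust
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};

/// Hypothetical stand-in for a row in the admin token collection.
struct TokenRow {
    tenant_id: String,
    revoked: bool,
}

/// Stand-in for the SHA-256 hex digest (assumption: the real code uses
/// SHA-256; DefaultHasher is NOT a cryptographic hash).
fn stand_in_hash(token: &str) -> String {
    let mut h = DefaultHasher::new();
    token.hash(&mut h);
    format!("{:016x}", h.finish())
}

/// Length of unpadded base64 for `n` input bytes: ceil(8n / 6).
fn base64_nopad_len(n: usize) -> usize {
    (n * 8 + 5) / 6
}

/// Extract the bearer, sniff the prefix, hash, look up, reject revoked
/// rows. Returns the tenant id to route to, or `None` (a 401 in the
/// real server).
fn authenticate(header: Option<&str>, rows: &HashMap<String, TokenRow>) -> Option<String> {
    let token = header?.strip_prefix("Bearer ")?.trim();
    if !token.starts_with("mcpt_") {
        return None; // cheap format check before any lookup
    }
    let row = rows.get(&stand_in_hash(token))?;
    if row.revoked {
        return None;
    }
    Some(row.tenant_id.clone())
}

fn main() {
    let mut rows = HashMap::new();
    rows.insert(
        stand_in_hash("mcpt_live"),
        TokenRow { tenant_id: "tenant-a".to_string(), revoked: false },
    );
    rows.insert(
        stand_in_hash("mcpt_dead"),
        TokenRow { tenant_id: "tenant-b".to_string(), revoked: true },
    );

    assert_eq!(authenticate(Some("Bearer mcpt_live"), &rows), Some("tenant-a".to_string()));
    assert_eq!(authenticate(Some("Bearer mcpt_dead"), &rows), None);
    assert_eq!(authenticate(Some("Bearer jwt.like.value"), &rows), None);
    assert_eq!(authenticate(None, &rows), None);
    // 32 random bytes encode to 43 chars; with the 5-char prefix that is 48.
    assert_eq!("mcpt_".len() + base64_nopad_len(32), 48);
}
```

Keying the lookup by the hash means the raw token never has to be
stored, and the prefix check rejects non-MCP bearers before any lookup.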

Files

- compliance-core/src/models/mcp_token.rs (new): McpToken +
  McpTokenView (public projection without the hash).
- compliance-agent/src/database.rs: DatabasePool::admin_db() +
  admin_db_name(), cross-tenant access for token storage.
- compliance-agent/src/api/handlers/mcp_tokens.rs (new): three
  endpoints. Token generation: 32 random bytes → URL-safe base64, no
  padding. SHA-256 hex stored.
- compliance-mcp/src/database.rs: replaced single Database with
  DatabasePool. Tenant-scoped Database constructed per request. Same
  sanitization + 63-byte cap + hash fallback as the agent.
- compliance-mcp/src/auth.rs (new): bearer middleware + task_local.
  Includes a SHA-256 test against a known vector.
- compliance-mcp/src/main.rs: HTTP transport: bearer middleware layered
  on /mcp (not /health, so orca's container probe still works). stdio
  transport: falls back to STDIO_TENANT_ID env (defaults to "dev") so
  local development still works; logged loudly as not-for-production.
- compliance-mcp/src/server.rs: each of the 12 tool handlers resolves
  the per-tenant DB via task_local before calling its tool fn. Tool
  fns themselves are unchanged.

Token UX

- Generated by the dashboard (or curl + KC JWT). The user sees the raw
  token exactly once and copies it into their LLM client config.
- Dashboard UI for management is a follow-up; can use curl in the
  meantime:

    curl -X POST https://comp-dev.../api/v1/mcp-tokens \
      -H "Authorization: Bearer $KC_JWT" \
      -H "Content-Type: application/json" \
      -d '{"name":"Claude Desktop"}'

Test plan

- cargo fmt --all clean
- cargo clippy --workspace --exclude compliance-dashboard -- -D warnings
  clean
- cargo test -p compliance-core --lib: 7 pass
- cargo test -p compliance-agent --lib: 230 pass (+2 new for token
  generation + sha256 stability)
- cargo test -p compliance-agent --test tenant_isolation: 6 pass
- cargo test -p compliance-mcp: 34 pass (+1 new sha256 vector)

What's deferred

- Dashboard UI for managing tokens (page + create modal +
  list/revoke). Trivial once the API is live.
- Token expiry + per-tool scope (today every token grants access to
  all 12 tools for its tenant).
- Lifting DatabasePool into compliance-core (duplicated for now in
  compliance-mcp to keep this PR focused; lift if a third consumer
  appears).

Production

- The `<db_prefix>__admin` DB needs to NOT collide with a tenant DB.
  Sanitized tenant_id never starts with `_admin` for any current
  tenant_id shape (UUIDs); flagged in the database.rs docstring so
  tenant provisioning can reject `_admin*` ids proactively.
- orca-infra MCP service block already has MONGODB_URI /
  MONGODB_DATABASE, so no new env is needed. No KC creds since MCP
  doesn't use Keycloak for its own auth.
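
The collision risk in the first Production bullet can be made concrete
with a small sketch. It assumes the naming scheme in the patch
(`{prefix}_{sanitized tenant_id}` for tenant DBs, `{prefix}__admin` for
the admin DB) and leaves out the hash fallback for over-long names.
`is_reserved_tenant_id` is a hypothetical provisioning guard, not
something this patch adds.

```rust
/// Same character replacement as `sanitize_tenant_id` in the patch.
fn sanitize_tenant_id(tenant_id: &str) -> String {
    tenant_id
        .chars()
        .map(|c| match c {
            '/' | '\\' | '.' | '"' | '$' | ' ' | '\0' => '_',
            c => c,
        })
        .collect()
}

/// Natural tenant DB name (the 63-byte hash fallback is omitted here).
fn tenant_db_name(db_prefix: &str, tenant_id: &str) -> String {
    format!("{}_{}", db_prefix, sanitize_tenant_id(tenant_id))
}

/// Cross-tenant admin DB name, as in `admin_db_name()`.
fn admin_db_name(db_prefix: &str) -> String {
    format!("{}__admin", db_prefix)
}

/// Hypothetical guard: reject ids whose sanitized form starts with
/// `_admin`, since those can land on or next to the admin DB name.
fn is_reserved_tenant_id(tenant_id: &str) -> bool {
    sanitize_tenant_id(tenant_id).starts_with("_admin")
}

fn main() {
    let prefix = "compliance_scanner";

    // A tenant id of `_admin` maps exactly onto the admin DB.
    assert_eq!(tenant_db_name(prefix, "_admin"), admin_db_name(prefix));
    // So does `.admin`, because `.` is sanitized to `_`.
    assert_eq!(tenant_db_name(prefix, ".admin"), admin_db_name(prefix));

    assert!(is_reserved_tenant_id("_admin"));
    assert!(is_reserved_tenant_id(".admin"));
    // UUID-shaped ids are unaffected.
    assert!(!is_reserved_tenant_id("3f2b8c1e-0000-4000-8000-000000000000"));
}
```

Checking the sanitized form rather than the raw id matters here: an id
such as `.admin` looks harmless before sanitization but resolves to the
same database name.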
Co-Authored-By: Claude Opus 4.7 --- Cargo.lock | 4 + Cargo.toml | 2 + compliance-agent/Cargo.toml | 1 + .../src/api/handlers/mcp_tokens.rs | 186 ++++++++++++++++++ compliance-agent/src/api/handlers/mod.rs | 1 + compliance-agent/src/api/routes.rs | 9 + compliance-agent/src/database.rs | 19 ++ compliance-core/src/models/mcp_token.rs | 69 +++++++ compliance-core/src/models/mod.rs | 2 + compliance-mcp/Cargo.toml | 5 +- compliance-mcp/src/auth.rs | 129 ++++++++++++ compliance-mcp/src/database.rs | 122 +++++++++++- compliance-mcp/src/main.rs | 45 ++++- compliance-mcp/src/server.rs | 63 ++++-- 14 files changed, 622 insertions(+), 35 deletions(-) create mode 100644 compliance-agent/src/api/handlers/mcp_tokens.rs create mode 100644 compliance-core/src/models/mcp_token.rs create mode 100644 compliance-mcp/src/auth.rs diff --git a/Cargo.lock b/Cargo.lock index 200bbb0..d892731 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -676,6 +676,7 @@ dependencies = [ "jsonwebtoken", "mongodb", "octocrab", + "rand 0.9.2", "regex", "reqwest", "secrecy", @@ -818,12 +819,15 @@ dependencies = [ "bson", "chrono", "compliance-core", + "dashmap", "dotenvy", + "hex", "mongodb", "rmcp", "schemars 1.2.1", "serde", "serde_json", + "sha2", "thiserror 2.0.18", "tokio", "tower-http", diff --git a/Cargo.toml b/Cargo.toml index b09af4e..b1c0598 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,3 +34,5 @@ zip = { version = "2", features = ["aes-crypto", "deflate"] } dashmap = "6" tokio-stream = { version = "0.1", features = ["sync"] } aes-gcm = "0.10" +rand = "0.9" +base64 = "0.22" diff --git a/compliance-agent/Cargo.toml b/compliance-agent/Cargo.toml index e8b81d6..d005516 100644 --- a/compliance-agent/Cargo.toml +++ b/compliance-agent/Cargo.toml @@ -42,6 +42,7 @@ tokio-tungstenite = { version = "0.26", features = ["rustls-tls-webpki-roots"] } futures-core = "0.3" dashmap = { workspace = true } tokio-stream = { workspace = true } +rand = { workspace = true } [dev-dependencies] compliance-core = { workspace = 
true, features = ["mongodb", "axum"] } diff --git a/compliance-agent/src/api/handlers/mcp_tokens.rs b/compliance-agent/src/api/handlers/mcp_tokens.rs new file mode 100644 index 0000000..87db849 --- /dev/null +++ b/compliance-agent/src/api/handlers/mcp_tokens.rs @@ -0,0 +1,186 @@ +//! `/api/v1/mcp-tokens` — per-tenant API tokens for the MCP server. +//! +//! These are opaque static bearers issued via the dashboard (or a +//! direct curl with a KC JWT) and copied into LLM clients (Claude +//! Desktop / Cursor / ChatGPT). The MCP server hashes incoming bearers +//! and looks them up in the cross-tenant `__admin.mcp_tokens` +//! collection to derive the tenant_id for routing. +//! +//! The raw token is shown to the caller exactly once at creation; the +//! database only ever stores the SHA-256 hash. Revocation is a soft +//! delete (sets `revoked: true`) so the audit log keeps the record. + +use axum::extract::{Extension, Path}; +use axum::http::StatusCode; +use axum::Json; +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; +use compliance_core::models::{McpToken, McpTokenView}; +use compliance_core::tenant_ctx::TenantCtx; +use mongodb::bson::doc; +use rand::RngCore; +use sha2::{Digest, Sha256}; + +use super::dto::{AgentExt, ApiResponse}; + +/// Mongo collection name inside the admin DB. +const COLLECTION: &str = "mcp_tokens"; + +/// Token prefix the MCP server expects on every bearer. +const TOKEN_PREFIX: &str = "mcpt_"; + +/// Bytes of randomness behind each token. 32 → ~256 bits. +/// Encoded as URL-safe base64 without padding → 43 chars. +/// Combined with `mcpt_` → 48-char tokens. +const TOKEN_RAND_BYTES: usize = 32; + +#[derive(serde::Deserialize)] +pub struct CreateMcpTokenRequest { + pub name: String, +} + +/// Returned exactly once at creation. The `token` field is gone from +/// the listing endpoint — the user must save it now. 
+#[derive(serde::Serialize)] +pub struct CreateMcpTokenResponse { + pub token: String, + pub view: McpTokenView, +} + +/// `POST /api/v1/mcp-tokens` — mint a new token for the caller's tenant. +#[tracing::instrument(skip_all)] +pub async fn create_mcp_token( + Extension(agent): AgentExt, + tenant: TenantCtx, + Json(req): Json, +) -> Result, StatusCode> { + if req.name.trim().is_empty() { + return Err(StatusCode::BAD_REQUEST); + } + let raw = generate_token(); + let token_hash = sha256_hex(&raw); + let token_prefix: String = raw.chars().take(12).collect(); + + let mut token = McpToken { + id: None, + token_hash, + token_prefix, + tenant_id: tenant.0.tenant_id.clone(), + name: req.name.trim().to_string(), + created_by: tenant.0.user_id.clone(), + created_at: chrono::Utc::now(), + last_used_at: None, + revoked: false, + }; + + let col = agent.db_pool.admin_db().collection::(COLLECTION); + let res = col.insert_one(&token).await.map_err(|e| { + tracing::error!("Failed to insert MCP token: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + token.id = res.inserted_id.as_object_id(); + + Ok(Json(CreateMcpTokenResponse { + view: McpTokenView::from(&token), + token: raw, + })) +} + +/// `GET /api/v1/mcp-tokens` — list tokens for the caller's tenant. +/// Hash is never returned; only metadata + the 12-char prefix so the +/// user can identify which row is which. +#[tracing::instrument(skip_all)] +pub async fn list_mcp_tokens( + Extension(agent): AgentExt, + tenant: TenantCtx, +) -> Result>>, StatusCode> { + let col = agent.db_pool.admin_db().collection::(COLLECTION); + let mut cursor = col + .find(doc! { "tenant_id": &tenant.0.tenant_id }) + .sort(doc! { "created_at": -1 }) + .await + .map_err(|e| { + tracing::error!("Failed to list MCP tokens: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + let mut out = Vec::new(); + while cursor.advance().await.map_err(|e| { + tracing::warn!("MCP tokens cursor advance failed: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })? 
{ + match cursor.deserialize_current() { + Ok(t) => out.push(McpTokenView::from(&t)), + Err(e) => tracing::warn!("Failed to deserialize MCP token: {e}"), + } + } + Ok(Json(ApiResponse { + data: out, + total: None, + page: None, + })) +} + +/// `DELETE /api/v1/mcp-tokens/{id}` — revoke (soft delete). +/// Scoped to the caller's tenant: a user can't revoke another tenant's +/// token even if they guess its id. +#[tracing::instrument(skip_all, fields(id = %id))] +pub async fn revoke_mcp_token( + Extension(agent): AgentExt, + tenant: TenantCtx, + Path(id): Path, +) -> Result, StatusCode> { + let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + let col = agent.db_pool.admin_db().collection::(COLLECTION); + let result = col + .update_one( + doc! { "_id": oid, "tenant_id": &tenant.0.tenant_id }, + doc! { "$set": { "revoked": true } }, + ) + .await + .map_err(|e| { + tracing::error!("Failed to revoke MCP token: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + if result.matched_count == 0 { + return Err(StatusCode::NOT_FOUND); + } + Ok(Json(serde_json::json!({ "status": "revoked" }))) +} + +/// 32 bytes random → URL-safe base64 → 43 chars, no padding. +/// Prefixed with `mcpt_` so the MCP server can sniff the format +/// before bothering with the DB lookup. 
+fn generate_token() -> String { + let mut bytes = [0u8; TOKEN_RAND_BYTES]; + rand::rng().fill_bytes(&mut bytes); + format!("{TOKEN_PREFIX}{}", URL_SAFE_NO_PAD.encode(bytes)) +} + +fn sha256_hex(s: &str) -> String { + let mut h = Sha256::new(); + h.update(s.as_bytes()); + hex::encode(h.finalize()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn generated_tokens_are_unique_and_prefixed() { + let a = generate_token(); + let b = generate_token(); + assert_ne!(a, b); + assert!(a.starts_with(TOKEN_PREFIX)); + assert!(b.starts_with(TOKEN_PREFIX)); + // 5 + 43 = 48 chars + assert_eq!(a.len(), 5 + 43); + } + + #[test] + fn sha256_is_stable_and_64_hex() { + let h = sha256_hex("mcpt_abc"); + assert_eq!(h.len(), 64); + assert!(h.chars().all(|c| c.is_ascii_hexdigit())); + assert_eq!(sha256_hex("mcpt_abc"), h); + } +} diff --git a/compliance-agent/src/api/handlers/mod.rs b/compliance-agent/src/api/handlers/mod.rs index ea1d74d..3b58cc7 100644 --- a/compliance-agent/src/api/handlers/mod.rs +++ b/compliance-agent/src/api/handlers/mod.rs @@ -6,6 +6,7 @@ pub mod graph; pub mod health; pub mod help_chat; pub mod issues; +pub mod mcp_tokens; pub mod notifications; pub mod pentest_handlers; pub use pentest_handlers as pentest; diff --git a/compliance-agent/src/api/routes.rs b/compliance-agent/src/api/routes.rs index c715df4..d19742e 100644 --- a/compliance-agent/src/api/routes.rs +++ b/compliance-agent/src/api/routes.rs @@ -47,6 +47,15 @@ pub fn build_router() -> Router { .route("/api/v1/sbom/diff", get(handlers::sbom_diff)) .route("/api/v1/issues", get(handlers::list_issues)) .route("/api/v1/scan-runs", get(handlers::list_scan_runs)) + // MCP token management (per-tenant API tokens for the MCP server) + .route( + "/api/v1/mcp-tokens", + get(handlers::mcp_tokens::list_mcp_tokens).post(handlers::mcp_tokens::create_mcp_token), + ) + .route( + "/api/v1/mcp-tokens/{id}", + delete(handlers::mcp_tokens::revoke_mcp_token), + ) // Graph API endpoints 
.route("/api/v1/graph/{repo_id}", get(handlers::graph::get_graph)) .route( diff --git a/compliance-agent/src/database.rs b/compliance-agent/src/database.rs index 5532b18..ec279d6 100644 --- a/compliance-agent/src/database.rs +++ b/compliance-agent/src/database.rs @@ -141,6 +141,25 @@ impl DatabasePool { &self.client } + /// Cross-tenant admin database used by features that intentionally + /// span tenants (today: MCP bearer tokens — each token row carries + /// a `tenant_id` and the MCP server reads them to route requests). + /// + /// The name `<db_prefix>__admin` (double underscore) is reserved — + /// the sanitizer never produces it for a normal tenant DB because + /// the natural format is `<db_prefix>_<tenant_id>` (one + /// underscore) and tenant_ids would have to start with `_admin` to + /// collide. New tenant provisioning should reject such ids. + pub fn admin_db(&self) -> mongodb::Database { + self.client.database(&self.admin_db_name()) + } + + /// Name of the admin database — public so tests / operators can + /// drop it via the raw client. + pub fn admin_db_name(&self) -> String { + format!("{}__admin", self.db_prefix) + } + /// List every Mongo database currently belonging to this pool, /// identified by the `<db_prefix>_` prefix. The result is the raw /// database names — opening one for offboarding/cleanup goes diff --git a/compliance-core/src/models/mcp_token.rs b/compliance-core/src/models/mcp_token.rs new file mode 100644 index 0000000..d2f56eb --- /dev/null +++ b/compliance-core/src/models/mcp_token.rs @@ -0,0 +1,69 @@ +//! Per-tenant API tokens used by `compliance-mcp` to authenticate MCP +//! HTTP requests on behalf of LLM clients (Claude Desktop, Cursor, +//! ChatGPT, etc.) that can't run a Keycloak OIDC flow. +//! +//! Tokens are opaque strings of the form `mcpt_<43 url-safe random +//! chars>`. The raw value is shown to the user exactly once at +//! creation; the database only ever sees the SHA-256 hash. Lookups go +//! through the cross-tenant `<db_prefix>__admin.mcp_tokens` collection +//! 
and return the `tenant_id` the MCP server should route to. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +/// Persisted token metadata. `token_hash` is the SHA-256 hex of the +/// raw token; the raw token itself is never stored. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpToken { + #[serde(rename = "_id", skip_serializing_if = "Option::is_none")] + pub id: Option, + /// SHA-256 hex of the raw token. Unique index in the collection. + pub token_hash: String, + /// First 12 chars of the raw token — purely for UI display so users + /// can identify which token is which without re-issuing. + pub token_prefix: String, + /// Routes to `<db_prefix>_<tenant_id>` on MCP requests. + pub tenant_id: String, + /// User-given label, e.g. "Claude Desktop" or "Sharang's laptop". + pub name: String, + /// Keycloak `sub` of the user who created this token, for audit. + pub created_by: String, + #[serde(with = "super::serde_helpers::bson_datetime")] + pub created_at: DateTime<Utc>, + #[serde(default, with = "super::serde_helpers::opt_bson_datetime")] + pub last_used_at: Option<DateTime<Utc>>, + /// Soft-delete flag. A revoked token doc stays around for audit + /// but never authenticates. + #[serde(default)] + pub revoked: bool, +} + +/// Public projection of a token — never includes the hash. +/// Returned by `GET /api/v1/mcp-tokens`. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpTokenView { + pub id: String, + pub name: String, + /// `mcpt_xxxx…` so the user can identify which row is which. 
+ pub token_prefix: String, + pub created_by: String, + #[serde(with = "super::serde_helpers::bson_datetime")] + pub created_at: DateTime, + #[serde(default, with = "super::serde_helpers::opt_bson_datetime")] + pub last_used_at: Option>, + pub revoked: bool, +} + +impl From<&McpToken> for McpTokenView { + fn from(t: &McpToken) -> Self { + Self { + id: t.id.map(|o| o.to_hex()).unwrap_or_default(), + name: t.name.clone(), + token_prefix: t.token_prefix.clone(), + created_by: t.created_by.clone(), + created_at: t.created_at, + last_used_at: t.last_used_at, + revoked: t.revoked, + } + } +} diff --git a/compliance-core/src/models/mod.rs b/compliance-core/src/models/mod.rs index 41c2f63..f6130f1 100644 --- a/compliance-core/src/models/mod.rs +++ b/compliance-core/src/models/mod.rs @@ -7,6 +7,7 @@ pub mod finding; pub mod graph; pub mod issue; pub mod mcp; +pub mod mcp_token; pub mod notification; pub mod pentest; pub mod repository; @@ -28,6 +29,7 @@ pub use graph::{ }; pub use issue::{IssueStatus, TrackerIssue, TrackerType}; pub use mcp::{McpServerConfig, McpServerStatus, McpTransport}; +pub use mcp_token::{McpToken, McpTokenView}; pub use notification::{CveNotification, NotificationSeverity, NotificationStatus}; pub use pentest::{ AttackChainNode, AttackNodeStatus, AuthMode, CodeContextHint, Environment, IdentityProvider, diff --git a/compliance-mcp/Cargo.toml b/compliance-mcp/Cargo.toml index 723a902..aabecdf 100644 --- a/compliance-mcp/Cargo.toml +++ b/compliance-mcp/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] -compliance-core = { workspace = true, features = ["mongodb"] } +compliance-core = { workspace = true, features = ["mongodb", "axum"] } rmcp = { version = "0.16", features = ["server", "macros", "transport-io", "transport-streamable-http-server"] } tokio = { workspace = true } serde = { workspace = true } @@ -19,3 +19,6 @@ bson = { version = "2", features = ["chrono-0_4"] } schemars = "1.0" axum = "0.8" tower-http = { version = 
"0.6", features = ["cors"] } +sha2 = { workspace = true } +hex = { workspace = true } +dashmap = { workspace = true } diff --git a/compliance-mcp/src/auth.rs b/compliance-mcp/src/auth.rs new file mode 100644 index 0000000..c18c3f3 --- /dev/null +++ b/compliance-mcp/src/auth.rs @@ -0,0 +1,129 @@ +//! Bearer-token authentication for incoming MCP HTTP requests. +//! +//! LLM clients (Claude Desktop / Cursor / ChatGPT / etc.) can't run +//! Keycloak OIDC, so the MCP server uses opaque static tokens minted +//! per-tenant via the agent's `POST /api/v1/mcp-tokens` endpoint. +//! +//! Flow per request: +//! 1. Extract `Authorization: Bearer `. Missing → 401. +//! 2. SHA-256 hash the token. +//! 3. Look up the hash in `__admin.mcp_tokens`. Missing or +//! revoked → 401. +//! 4. Fire-and-forget update of `last_used_at` so the dashboard can +//! show staleness without blocking the handler. +//! 5. Stash the tenant_id in [`TENANT_ID`] (a `tokio::task_local`) so +//! the MCP tool handlers can read it without modifying rmcp's +//! handler signatures. +//! +//! The `task_local` is scoped around the inner service call via +//! [`bearer_auth`], so every handler invoked downstream sees the +//! tenant_id without us having to thread it through the macro- +//! generated tool router. + +use axum::body::Body; +use axum::extract::{Request, State}; +use axum::http::StatusCode; +use axum::middleware::Next; +use axum::response::{IntoResponse, Response}; +use mongodb::bson::doc; +use sha2::{Digest, Sha256}; + +use crate::database::DatabasePool; + +tokio::task_local! { + /// Tenant id resolved from the bearer for this request. Set by + /// [`bearer_auth`] before the inner service runs; read by the + /// MCP tool handlers via [`current_tenant_id`]. + pub static TENANT_ID: String; +} + +/// Mongo collection name in `__admin`. +const COLLECTION: &str = "mcp_tokens"; + +/// Returns the tenant_id set by the auth middleware. `None` outside a +/// request scope (e.g. 
unit tests that bypass the middleware). +pub fn current_tenant_id() -> Option { + TENANT_ID.try_with(|s| s.clone()).ok() +} + +/// Axum middleware: validate bearer → set [`TENANT_ID`] → call inner. +pub async fn bearer_auth( + State(pool): State, + request: Request, + next: Next, +) -> Response { + let Some(token) = extract_bearer(&request) else { + return (StatusCode::UNAUTHORIZED, "Missing bearer token").into_response(); + }; + if !token.starts_with("mcpt_") { + return (StatusCode::UNAUTHORIZED, "Invalid token format").into_response(); + } + let token_hash = sha256_hex(&token); + + let col = pool.admin_db().collection::(COLLECTION); + let found = match col + .find_one(doc! { "token_hash": &token_hash, "revoked": false }) + .await + { + Ok(Some(t)) => t, + Ok(None) => { + return (StatusCode::UNAUTHORIZED, "Invalid or revoked token").into_response(); + } + Err(e) => { + tracing::error!("MCP token lookup failed: {e}"); + return (StatusCode::INTERNAL_SERVER_ERROR, "Token lookup error").into_response(); + } + }; + + // Fire-and-forget last_used_at update — never block the handler. + let col2 = pool.admin_db().collection::(COLLECTION); + let hash_for_update = token_hash.clone(); + tokio::spawn(async move { + let _ = col2 + .update_one( + doc! { "token_hash": &hash_for_update }, + doc! { "$set": { "last_used_at": mongodb::bson::DateTime::now() } }, + ) + .await; + }); + + let tenant_id = found.tenant_id; + let inner = next.run(request); + TENANT_ID.scope(tenant_id, inner).await +} + +/// Bare-bones projection — we don't need the whole `McpToken` here, +/// just enough to route and confirm validity. 
+#[derive(serde::Deserialize)] +struct TokenLookup { + tenant_id: String, +} + +fn extract_bearer(req: &Request) -> Option { + req.headers() + .get(axum::http::header::AUTHORIZATION) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.strip_prefix("Bearer ")) + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) +} + +fn sha256_hex(s: &str) -> String { + let mut h = Sha256::new(); + h.update(s.as_bytes()); + hex::encode(h.finalize()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sha256_known_value() { + // python -c 'import hashlib; print(hashlib.sha256(b"mcpt_known").hexdigest())' + assert_eq!( + sha256_hex("mcpt_known"), + "27cf6cf678a44244106863c1c031be8e57b84c2b3019d742f755f8e7afa75dfd" + ); + } +} diff --git a/compliance-mcp/src/database.rs b/compliance-mcp/src/database.rs index 041d151..670e1e2 100644 --- a/compliance-mcp/src/database.rs +++ b/compliance-mcp/src/database.rs @@ -1,19 +1,127 @@ -use mongodb::{Client, Collection}; +//! Per-tenant Mongo broker for the MCP server. +//! +//! Mirror of the agent's `compliance_agent::database::DatabasePool` — +//! duplicated here rather than lifted into `compliance-core` to keep +//! this PR focused. If a third consumer ever needs it, lift then. +//! +//! Bearer tokens (validated by the auth middleware) carry a tenant_id +//! and the handler resolves the per-tenant database via +//! [`DatabasePool::for_tenant_id`]. The admin database +//! (`__admin`) holds the cross-tenant `mcp_tokens` +//! collection that the middleware queries on every request. + +use std::sync::Arc; + +use dashmap::DashMap; +use mongodb::{bson::doc, Client, Collection}; +use sha2::{Digest, Sha256}; use compliance_core::models::*; +/// 63-byte Mongo db-name cap; same invariant as the agent's pool. +const MAX_DB_NAME_LEN: usize = 63; +/// 16-byte SHA-256 truncation, hex-encoded → 32 chars. 
+const HASH_HEX_LEN: usize = 32; +const MAX_PREFIX_LEN: usize = MAX_DB_NAME_LEN - 1 - HASH_HEX_LEN; + +#[derive(Clone, Debug)] +pub struct DatabasePool { + client: Client, + db_prefix: String, + /// Tenants we've handed out a [`Database`] for. The MCP server + /// doesn't ensure indexes (the agent owns that side of the + /// schema), so the marker exists only to satisfy the parallel + /// shape — current code never reads it. + #[allow(dead_code)] + seen: Arc>, +} + +#[derive(Debug, thiserror::Error)] +pub enum DbError { + #[error("db_prefix '{prefix}' is {len} chars; max is {max} so the hash-fallback DB name fits Mongo's 63-byte cap")] + PrefixTooLong { + prefix: String, + len: usize, + max: usize, + }, + #[error(transparent)] + Mongo(#[from] mongodb::error::Error), +} + +impl DatabasePool { + pub async fn connect(uri: &str, db_prefix: &str) -> Result { + if db_prefix.len() > MAX_PREFIX_LEN { + return Err(DbError::PrefixTooLong { + prefix: db_prefix.to_string(), + len: db_prefix.len(), + max: MAX_PREFIX_LEN, + }); + } + let client = Client::with_uri_str(uri).await?; + client + .database("admin") + .run_command(doc! { "ping": 1 }) + .await?; + tracing::info!( + "MCP MongoDB cluster reachable; per-tenant pool ready (db prefix '{db_prefix}')" + ); + Ok(Self { + client, + db_prefix: db_prefix.to_string(), + seen: Arc::new(DashMap::new()), + }) + } + + /// Read-only handle to the tenant's database. No indexes are + /// ensured here — the agent owns writes, MCP only reads. + pub fn for_tenant_id(&self, tenant_id: &str) -> Database { + let db_name = self.tenant_db_name(tenant_id); + self.seen.insert(tenant_id.to_string(), ()); + Database::new(self.client.database(&db_name)) + } + + /// Cross-tenant admin DB — holds the `mcp_tokens` collection that + /// the auth middleware queries to map bearer → tenant_id. 
+ pub fn admin_db(&self) -> mongodb::Database { + self.client.database(&format!("{}__admin", self.db_prefix)) + } + + pub fn tenant_db_name(&self, tenant_id: &str) -> String { + let sanitized = sanitize_tenant_id(tenant_id); + let natural = format!("{}_{}", self.db_prefix, sanitized); + if natural.len() <= MAX_DB_NAME_LEN { + natural + } else { + let mut h = Sha256::new(); + h.update(tenant_id.as_bytes()); + let digest = h.finalize(); + let suffix = hex::encode(&digest[..HASH_HEX_LEN / 2]); + format!("{}_{}", self.db_prefix, suffix) + } + } +} + +fn sanitize_tenant_id(tenant_id: &str) -> String { + tenant_id + .chars() + .map(|c| match c { + '/' | '\\' | '.' | '"' | '$' | ' ' | '\0' => '_', + c => c, + }) + .collect() +} + +/// Typed accessors for the MCP-readable collections in a tenant DB. +/// Matches the agent's `Database` shape but only exposes what the MCP +/// tool handlers actually need. #[derive(Clone, Debug)] pub struct Database { inner: mongodb::Database, } impl Database { - pub async fn connect(uri: &str, db_name: &str) -> Result { - let client = Client::with_uri_str(uri).await?; - let db = client.database(db_name); - db.run_command(mongodb::bson::doc! 
{ "ping": 1 }).await?; - tracing::info!("MCP server connected to MongoDB '{db_name}'"); - Ok(Self { inner: db }) + pub(crate) fn new(inner: mongodb::Database) -> Self { + Self { inner } } pub fn findings(&self) -> Collection { diff --git a/compliance-mcp/src/main.rs b/compliance-mcp/src/main.rs index 7bcfd96..b35547f 100644 --- a/compliance-mcp/src/main.rs +++ b/compliance-mcp/src/main.rs @@ -1,10 +1,11 @@ +mod auth; mod database; mod server; mod tools; use std::sync::Arc; -use database::Database; +use database::DatabasePool; use rmcp::transport::{ streamable_http_server::session::local::LocalSessionManager, StreamableHttpServerConfig, StreamableHttpService, @@ -24,36 +25,60 @@ async fn main() -> Result<(), Box> { let mongo_uri = std::env::var("MONGODB_URI").unwrap_or_else(|_| "mongodb://localhost:27017".to_string()); - let db_name = + // MONGODB_DATABASE is reused as the per-tenant DB-name prefix — + // same convention as the agent so `__admin.mcp_tokens` + // and `_` line up across services. + let db_prefix = std::env::var("MONGODB_DATABASE").unwrap_or_else(|_| "compliance_scanner".to_string()); - let db = Database::connect(&mongo_uri, &db_name).await?; + let pool = DatabasePool::connect(&mongo_uri, &db_prefix).await?; - // If MCP_PORT is set, run as Streamable HTTP server; otherwise use stdio. + // HTTP transport: bind a small axum router with bearer-auth in + // front of the rmcp service. `/health` stays public for orca's + // container probe. 
if let Ok(port_str) = std::env::var("MCP_PORT") { let port: u16 = port_str.parse()?; tracing::info!("Starting MCP server on HTTP port {port}"); - let db_clone = db.clone(); + let pool_for_factory = pool.clone(); let service = StreamableHttpService::new( - move || Ok(ComplianceMcpServer::new(db_clone.clone())), + move || Ok(ComplianceMcpServer::new(pool_for_factory.clone())), Arc::new(LocalSessionManager::default()), StreamableHttpServerConfig::default(), ); let router = axum::Router::new() .route("/health", axum::routing::get(|| async { "ok" })) - .nest_service("/mcp", service); + .nest_service( + "/mcp", + axum::Router::new().fallback_service(service).layer( + axum::middleware::from_fn_with_state(pool.clone(), auth::bearer_auth), + ), + ); let listener = tokio::net::TcpListener::bind(("0.0.0.0", port)).await?; tracing::info!("MCP HTTP server listening on 0.0.0.0:{port}"); axum::serve(listener, router).await?; } else { + // stdio transport — used when run as a local MCP server next + // to the LLM client. There's no HTTP layer to do bearer auth, + // so we synthesize a tenant_id from STDIO_TENANT_ID for local + // development. NEVER use this in production. 
tracing::info!("Starting MCP server on stdio"); - let server = ComplianceMcpServer::new(db); + let synth_tenant = std::env::var("STDIO_TENANT_ID").unwrap_or_else(|_| "dev".to_string()); + tracing::warn!( + tenant_id = %synth_tenant, + "stdio transport — using synthetic tenant id; DO NOT use in production" + ); + let server = ComplianceMcpServer::new(pool); let transport = rmcp::transport::stdio(); use rmcp::ServiceExt; - let handle = server.serve(transport).await?; - handle.waiting().await?; + auth::TENANT_ID + .scope(synth_tenant, async { + let handle = server.serve(transport).await?; + handle.waiting().await?; + Ok::<_, Box>(()) + }) + .await?; } Ok(()) diff --git a/compliance-mcp/src/server.rs b/compliance-mcp/src/server.rs index d40ba82..bba442e 100644 --- a/compliance-mcp/src/server.rs +++ b/compliance-mcp/src/server.rs @@ -2,20 +2,37 @@ use rmcp::{ handler::server::wrapper::Parameters, model::*, tool, tool_handler, tool_router, ServerHandler, }; -use crate::database::Database; +use crate::auth::current_tenant_id; +use crate::database::{Database, DatabasePool}; use crate::tools::{dast, findings, pentest, sbom}; pub struct ComplianceMcpServer { - db: Database, + pool: DatabasePool, #[allow(dead_code)] tool_router: rmcp::handler::server::router::tool::ToolRouter, } +impl ComplianceMcpServer { + /// Resolve the per-tenant `Database` from the bearer-set + /// `task_local`. Every tool handler calls this; missing context + /// surfaces as `internal_error` because it means the auth + /// middleware was misconfigured (handler ran without scope). 
+    fn tenant_db(&self) -> Result {
+        let tenant_id = current_tenant_id().ok_or_else(|| {
+            rmcp::ErrorData::internal_error(
+                "no tenant context — bearer middleware not in chain".to_string(),
+                None,
+            )
+        })?;
+        Ok(self.pool.for_tenant_id(&tenant_id))
+    }
+}
+
 #[tool_router]
 impl ComplianceMcpServer {
-    pub fn new(db: Database) -> Self {
+    pub fn new(pool: DatabasePool) -> Self {
         Self {
-            db,
+            pool,
             tool_router: Self::tool_router(),
         }
     }
@@ -29,7 +46,8 @@ impl ComplianceMcpServer {
         &self,
         Parameters(params): Parameters,
     ) -> Result {
-        findings::list_findings(&self.db, params).await
+        let db = self.tenant_db()?;
+        findings::list_findings(&db, params).await
     }
 
     #[tool(description = "Get a single finding by its ID")]
@@ -37,7 +55,8 @@ impl ComplianceMcpServer {
         &self,
         Parameters(params): Parameters,
     ) -> Result {
-        findings::get_finding(&self.db, params).await
+        let db = self.tenant_db()?;
+        findings::get_finding(&db, params).await
     }
 
     #[tool(description = "Get a summary of findings counts grouped by severity and status")]
@@ -45,7 +64,8 @@ impl ComplianceMcpServer {
         &self,
         Parameters(params): Parameters,
     ) -> Result {
-        findings::findings_summary(&self.db, params).await
+        let db = self.tenant_db()?;
+        findings::findings_summary(&db, params).await
     }
 
     // ── SBOM ──────────────────────────────────────────────
@@ -57,7 +77,8 @@ impl ComplianceMcpServer {
         &self,
         Parameters(params): Parameters,
     ) -> Result {
-        sbom::list_sbom_packages(&self.db, params).await
+        let db = self.tenant_db()?;
+        sbom::list_sbom_packages(&db, params).await
     }
 
     #[tool(
@@ -67,7 +88,8 @@ impl ComplianceMcpServer {
         &self,
         Parameters(params): Parameters,
     ) -> Result {
-        sbom::sbom_vuln_report(&self.db, params).await
+        let db = self.tenant_db()?;
+        sbom::sbom_vuln_report(&db, params).await
     }
 
     // ── DAST ──────────────────────────────────────────────
@@ -79,7 +101,8 @@ impl ComplianceMcpServer {
         &self,
         Parameters(params): Parameters,
     ) -> Result {
-        dast::list_dast_findings(&self.db, params).await
+        let db = self.tenant_db()?;
+        dast::list_dast_findings(&db, params).await
     }
 
     #[tool(description = "Get a summary of recent DAST scan runs and finding counts")]
@@ -87,7 +110,8 @@ impl ComplianceMcpServer {
         &self,
         Parameters(params): Parameters,
     ) -> Result {
-        dast::dast_scan_summary(&self.db, params).await
+        let db = self.tenant_db()?;
+        dast::dast_scan_summary(&db, params).await
     }
 
     // ── Pentest ─────────────────────────────────────────────
@@ -99,7 +123,8 @@ impl ComplianceMcpServer {
         &self,
         Parameters(params): Parameters,
     ) -> Result {
-        pentest::list_pentest_sessions(&self.db, params).await
+        let db = self.tenant_db()?;
+        pentest::list_pentest_sessions(&db, params).await
     }
 
     #[tool(description = "Get a single AI pentest session by its ID")]
@@ -107,7 +132,8 @@ impl ComplianceMcpServer {
         &self,
         Parameters(params): Parameters,
     ) -> Result {
-        pentest::get_pentest_session(&self.db, params).await
+        let db = self.tenant_db()?;
+        pentest::get_pentest_session(&db, params).await
     }
 
     #[tool(
@@ -117,7 +143,8 @@ impl ComplianceMcpServer {
         &self,
         Parameters(params): Parameters,
     ) -> Result {
-        pentest::get_attack_chain(&self.db, params).await
+        let db = self.tenant_db()?;
+        pentest::get_attack_chain(&db, params).await
     }
 
     #[tool(description = "Get chat messages from a pentest session")]
@@ -125,7 +152,8 @@ impl ComplianceMcpServer {
         &self,
         Parameters(params): Parameters,
     ) -> Result {
-        pentest::get_pentest_messages(&self.db, params).await
+        let db = self.tenant_db()?;
+        pentest::get_pentest_messages(&db, params).await
     }
 
     #[tool(
@@ -135,7 +163,8 @@ impl ComplianceMcpServer {
         &self,
         Parameters(params): Parameters,
     ) -> Result {
-        pentest::pentest_stats(&self.db, params).await
+        let db = self.tenant_db()?;
+        pentest::pentest_stats(&db, params).await
     }
 }
 
@@ -149,7 +178,7 @@ impl ServerHandler for ComplianceMcpServer {
                 .build(),
             server_info: Implementation::from_build_env(),
             instructions: Some(
-                "Compliance Scanner MCP server. Query security findings, SBOM data, DAST results, and AI pentest sessions."
+                "Compliance Scanner MCP server. Query security findings, SBOM data, DAST results, and AI pentest sessions for your tenant."
                     .to_string(),
             ),
         }
-- 
2.52.0
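
Reviewer note (illustrative, not part of the patch): the handlers above depend on an ambient tenant id that is set for the duration of one call and read by `tenant_db()`. The sketch below models that contract with only the standard library. It swaps in `std::thread_local!` for the patch's `tokio::task_local!`, so it shows the scoping semantics rather than the real async wiring. `with_tenant`, `Restore`, `tenant_db_name`, and the `tenant_<id>` naming scheme are hypothetical names introduced here; only `TENANT_ID` and `current_tenant_id` mirror identifiers from the patch, and their real definitions in `auth.rs` may differ.

```rust
use std::cell::RefCell;

thread_local! {
    // Assumption: std-only stand-in for the patch's `tokio::task_local!` TENANT_ID.
    static TENANT_ID: RefCell<Option<String>> = RefCell::new(None);
}

/// Hypothetical guard that puts the previous tenant id back when dropped,
/// so the scope is unwound even if the closure panics.
struct Restore(Option<String>);

impl Drop for Restore {
    fn drop(&mut self) {
        let previous = self.0.take();
        TENANT_ID.with(|slot| *slot.borrow_mut() = previous);
    }
}

/// Hypothetical analogue of `TENANT_ID.scope(tenant, fut)`: run `f` with
/// `tenant` as the ambient tenant id, then restore whatever was set before.
fn with_tenant<R>(tenant: &str, f: impl FnOnce() -> R) -> R {
    let _restore = Restore(TENANT_ID.with(|slot| slot.replace(Some(tenant.to_string()))));
    f()
}

/// Read the ambient tenant id, or `None` when called outside any scope.
fn current_tenant_id() -> Option<String> {
    TENANT_ID.with(|slot| slot.borrow().clone())
}

/// Hypothetical stand-in for `ComplianceMcpServer::tenant_db`: it returns a
/// database *name* rather than a handle, and fails when no scope is active.
/// The `tenant_<id>` naming is made up for this sketch.
fn tenant_db_name() -> Result<String, String> {
    let tenant_id = current_tenant_id()
        .ok_or_else(|| "no tenant context: bearer middleware not in chain".to_string())?;
    Ok(format!("tenant_{tenant_id}"))
}
```

Calling `tenant_db_name()` outside `with_tenant` yields the error branch, which corresponds to a handler running without the bearer middleware in the chain. Inside `with_tenant("acme", ...)` it resolves a name for that tenant only, and nested scopes restore the outer tenant on exit.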