Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b96dda11fb | |||
| e67a13535a |
@@ -35,11 +35,16 @@ impl ComplianceAgent {
|
||||
config.litellm_model.clone(),
|
||||
config.litellm_embed_model.clone(),
|
||||
));
|
||||
let http = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.connect_timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
.unwrap_or_default();
|
||||
Self {
|
||||
config,
|
||||
db,
|
||||
llm,
|
||||
http: reqwest::Client::new(),
|
||||
http,
|
||||
session_streams: Arc::new(DashMap::new()),
|
||||
session_pause: Arc::new(DashMap::new()),
|
||||
session_semaphore: Arc::new(Semaphore::new(DEFAULT_MAX_CONCURRENT_SESSIONS)),
|
||||
|
||||
@@ -19,12 +19,17 @@ impl LlmClient {
|
||||
model: String,
|
||||
embed_model: String,
|
||||
) -> Self {
|
||||
let http = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.connect_timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
.unwrap_or_default();
|
||||
Self {
|
||||
base_url,
|
||||
api_key,
|
||||
model,
|
||||
embed_model,
|
||||
http: reqwest::Client::new(),
|
||||
http,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -174,19 +174,26 @@ impl PipelineOrchestrator {
|
||||
k.expose_secret().to_string()
|
||||
}),
|
||||
);
|
||||
let cve_alerts = match async {
|
||||
let cve_alerts = match tokio::time::timeout(
|
||||
std::time::Duration::from_secs(600),
|
||||
async {
|
||||
cve_scanner
|
||||
.scan_dependencies(&repo_id, &mut sbom_entries)
|
||||
.await
|
||||
}
|
||||
.instrument(tracing::info_span!("stage_cve_scanning"))
|
||||
.instrument(tracing::info_span!("stage_cve_scanning")),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(alerts) => alerts,
|
||||
Err(e) => {
|
||||
Ok(Ok(alerts)) => alerts,
|
||||
Ok(Err(e)) => {
|
||||
tracing::warn!("[{repo_id}] CVE scanning failed: {e}");
|
||||
Vec::new()
|
||||
}
|
||||
Err(_) => {
|
||||
tracing::warn!("[{repo_id}] CVE scanning timed out after 10 minutes");
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
// Stage 4: Pattern Scanning (GDPR + OAuth)
|
||||
|
||||
@@ -6,11 +6,16 @@ use compliance_core::models::embedding::{CodeEmbedding, EmbeddingBuildRun, Embed
|
||||
use compliance_core::models::graph::CodeNode;
|
||||
use compliance_graph::graph::chunking::extract_chunks;
|
||||
use compliance_graph::graph::embedding_store::EmbeddingStore;
|
||||
use futures_util::stream::{FuturesUnordered, StreamExt};
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::error::AgentError;
|
||||
use crate::llm::LlmClient;
|
||||
|
||||
/// Maximum number of chunks grouped into one embedding batch (one `embed` call).
const EMBED_BATCH_SIZE: usize = 20;
|
||||
/// Number of embedding batches primed into the in-flight set; one more batch
/// is queued each time a batch completes successfully.
const EMBED_CONCURRENCY: usize = 4;
|
||||
/// Once this many embeddings are pending, they are written to the embedding
/// store and the pending buffer is cleared.
const EMBED_FLUSH_EVERY: usize = 200;
|
||||
|
||||
/// RAG pipeline for building embeddings and performing retrieval
|
||||
pub struct RagPipeline {
|
||||
llm: Arc<LlmClient>,
|
||||
@@ -77,25 +82,33 @@ impl RagPipeline {
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Failed to delete old embeddings: {e}")))?;
|
||||
|
||||
// Step 3: Batch embed (small batches to stay within model limits)
|
||||
let batch_size = 20;
|
||||
let mut all_embeddings = Vec::new();
|
||||
// Step 3: Batch embed with bounded concurrency. Flush to Mongo and
|
||||
// update progress periodically so the dashboard can show live status.
|
||||
let mut pending = Vec::with_capacity(EMBED_FLUSH_EVERY);
|
||||
let mut embedded_count = 0u32;
|
||||
|
||||
for batch_start in (0..chunks.len()).step_by(batch_size) {
|
||||
let batch_end = (batch_start + batch_size).min(chunks.len());
|
||||
let batch_chunks = &chunks[batch_start..batch_end];
|
||||
|
||||
// Prepare texts: context_header + content
|
||||
let texts: Vec<String> = batch_chunks
|
||||
.iter()
|
||||
.map(|c| format!("{}\n{}", c.context_header, c.content))
|
||||
// Build the list of batch indices to process.
|
||||
let batches: Vec<(usize, usize)> = (0..chunks.len())
|
||||
.step_by(EMBED_BATCH_SIZE)
|
||||
.map(|start| (start, (start + EMBED_BATCH_SIZE).min(chunks.len())))
|
||||
.collect();
|
||||
|
||||
match self.llm.embed(texts).await {
|
||||
Ok(vectors) => {
|
||||
let mut batch_iter = batches.into_iter();
|
||||
let mut in_flight = FuturesUnordered::new();
|
||||
|
||||
// Prime up to EMBED_CONCURRENCY batches.
|
||||
for _ in 0..EMBED_CONCURRENCY {
|
||||
if let Some((start, end)) = batch_iter.next() {
|
||||
in_flight.push(self.embed_batch(&chunks[start..end], start, end));
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(result) = in_flight.next().await {
|
||||
match result {
|
||||
Ok((start, end, vectors)) => {
|
||||
let batch_chunks = &chunks[start..end];
|
||||
for (chunk, embedding) in batch_chunks.iter().zip(vectors) {
|
||||
all_embeddings.push(CodeEmbedding {
|
||||
pending.push(CodeEmbedding {
|
||||
id: None,
|
||||
repo_id: repo_id.to_string(),
|
||||
graph_build_id: graph_build_id.to_string(),
|
||||
@@ -113,9 +126,45 @@ impl RagPipeline {
|
||||
});
|
||||
}
|
||||
embedded_count += batch_chunks.len() as u32;
|
||||
|
||||
// Flush pending embeddings to Mongo periodically and update progress.
|
||||
if pending.len() >= EMBED_FLUSH_EVERY {
|
||||
self.embedding_store
|
||||
.store_embeddings(&pending)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
AgentError::Other(format!("Failed to store embeddings: {e}"))
|
||||
})?;
|
||||
pending.clear();
|
||||
}
|
||||
|
||||
// Always update the progress counter on the build doc — even if
|
||||
// we haven't flushed embeddings yet — so the UI shows movement.
|
||||
if let Err(e) = self
|
||||
.embedding_store
|
||||
.update_build(
|
||||
repo_id,
|
||||
graph_build_id,
|
||||
EmbeddingBuildStatus::Running,
|
||||
embedded_count,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("[{repo_id}] Failed to update build progress: {e}");
|
||||
}
|
||||
|
||||
// Queue the next batch to keep concurrency saturated.
|
||||
if let Some((s, e)) = batch_iter.next() {
|
||||
in_flight.push(self.embed_batch(&chunks[s..e], s, e));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("[{repo_id}] Embedding batch failed: {e}");
|
||||
// Flush whatever we have so partial progress isn't lost.
|
||||
if !pending.is_empty() {
|
||||
let _ = self.embedding_store.store_embeddings(&pending).await;
|
||||
}
|
||||
build.status = EmbeddingBuildStatus::Failed;
|
||||
build.error_message = Some(e.to_string());
|
||||
build.completed_at = Some(Utc::now());
|
||||
@@ -134,11 +183,13 @@ impl RagPipeline {
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Store all embeddings
|
||||
// Step 4: Flush any remaining embeddings
|
||||
if !pending.is_empty() {
|
||||
self.embedding_store
|
||||
.store_embeddings(&all_embeddings)
|
||||
.store_embeddings(&pending)
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Failed to store embeddings: {e}")))?;
|
||||
}
|
||||
|
||||
// Step 5: Update build status
|
||||
build.status = EmbeddingBuildStatus::Completed;
|
||||
@@ -161,4 +212,21 @@ impl RagPipeline {
|
||||
);
|
||||
Ok(build)
|
||||
}
|
||||
|
||||
/// Embed one batch of chunks. Returns the (start, end, vectors) tuple so
|
||||
/// out-of-order completion from `FuturesUnordered` can still be reconciled
|
||||
/// against the original chunk slice.
|
||||
async fn embed_batch(
|
||||
&self,
|
||||
batch_chunks: &[compliance_graph::graph::chunking::CodeChunk],
|
||||
start: usize,
|
||||
end: usize,
|
||||
) -> Result<(usize, usize, Vec<Vec<f64>>), AgentError> {
|
||||
let texts: Vec<String> = batch_chunks
|
||||
.iter()
|
||||
.map(|c| format!("{}\n{}", c.context_header, c.content))
|
||||
.collect();
|
||||
let vectors = self.llm.embed(texts).await?;
|
||||
Ok((start, end, vectors))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user