fix(control-pipeline): anchor finder uses direct Qdrant search instead of Go SDK
The Go SDK RAG proxy returns 401 (Qdrant API key mismatch). Switch AnchorFinder to use direct Qdrant vector search + embedding service, same approach as the main pipeline. No dependency on Go SDK anymore. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -2,19 +2,19 @@
|
||||
Anchor Finder — finds open-source references (OWASP, NIST, ENISA) for controls.
|
||||
|
||||
Two-stage search:
|
||||
Stage A: RAG-internal search for open-source chunks matching the control topic
|
||||
Stage A: Direct Qdrant vector search for open-source chunks matching the control topic
|
||||
Stage B: Web search via DuckDuckGo Instant Answer API (no API key needed)
|
||||
|
||||
Only open-source references (Rule 1+2) are accepted as anchors.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from .rag_client import ComplianceRAGClient, get_rag_client
|
||||
from .control_generator import (
|
||||
GeneratedControl,
|
||||
REGULATION_LICENSE_MAP,
|
||||
@@ -25,9 +25,15 @@ from .control_generator import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
QDRANT_URL = os.getenv("QDRANT_URL", "http://qdrant:6333")
|
||||
EMBEDDING_URL = os.getenv("EMBEDDING_URL", "http://embedding-service:8087")
|
||||
|
||||
# Regulation codes that are safe to reference as open anchors (Rule 1+2)
|
||||
_OPEN_SOURCE_RULES = {1, 2}
|
||||
|
||||
# Collections to search for anchors (open-source frameworks)
|
||||
_ANCHOR_COLLECTIONS = ["bp_compliance_ce"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAnchor:
|
||||
@@ -39,8 +45,9 @@ class OpenAnchor:
|
||||
class AnchorFinder:
|
||||
"""Finds open-source references to anchor generated controls."""
|
||||
|
||||
def __init__(self, rag_client: Optional[ComplianceRAGClient] = None):
|
||||
self.rag = rag_client or get_rag_client()
|
||||
def __init__(self, rag_client=None):
|
||||
# rag_client kept for backwards compat but no longer used
|
||||
pass
|
||||
|
||||
async def find_anchors(
|
||||
self,
|
||||
@@ -49,8 +56,8 @@ class AnchorFinder:
|
||||
min_anchors: int = 2,
|
||||
) -> List[OpenAnchor]:
|
||||
"""Find open-source anchors for a control."""
|
||||
# Stage A: RAG-internal search
|
||||
anchors = await self._search_rag_for_open_anchors(control)
|
||||
# Stage A: Direct Qdrant vector search
|
||||
anchors = await self._search_qdrant_for_open_anchors(control)
|
||||
|
||||
# Stage B: Web search if not enough anchors
|
||||
if len(anchors) < min_anchors and not skip_web:
|
||||
@@ -63,45 +70,104 @@ class AnchorFinder:
|
||||
|
||||
return anchors
|
||||
|
||||
async def _search_rag_for_open_anchors(self, control: GeneratedControl) -> List[OpenAnchor]:
|
||||
"""Search RAG for chunks from open sources matching the control topic."""
|
||||
async def _get_embedding(self, text: str) -> list:
|
||||
"""Get embedding vector via embedding service."""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(
|
||||
f"{EMBEDDING_URL}/embed",
|
||||
json={"texts": [text]},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
embeddings = resp.json().get("embeddings", [])
|
||||
return embeddings[0] if embeddings else []
|
||||
except Exception as e:
|
||||
logger.warning("Embedding request failed: %s", e)
|
||||
return []
|
||||
|
||||
async def _search_qdrant_for_open_anchors(self, control: GeneratedControl) -> List[OpenAnchor]:
|
||||
"""Search Qdrant directly for chunks from open sources matching the control topic."""
|
||||
# Build search query from control title + first 3 tags
|
||||
tags_str = " ".join(control.tags[:3]) if control.tags else ""
|
||||
query = f"{control.title} {tags_str}".strip()
|
||||
|
||||
results = await self.rag.search_with_rerank(
|
||||
query=query,
|
||||
collection="bp_compliance_ce",
|
||||
top_k=15,
|
||||
)
|
||||
# Get embedding for query
|
||||
embedding = await self._get_embedding(query)
|
||||
if not embedding:
|
||||
return []
|
||||
|
||||
anchors: List[OpenAnchor] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
for r in results:
|
||||
if not r.regulation_code:
|
||||
for collection in _ANCHOR_COLLECTIONS:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||
resp = await client.post(
|
||||
f"{QDRANT_URL}/collections/{collection}/points/search",
|
||||
json={
|
||||
"vector": embedding,
|
||||
"limit": 20,
|
||||
"with_payload": True,
|
||||
"with_vector": False,
|
||||
},
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
logger.warning("Qdrant search %s failed: %d", collection, resp.status_code)
|
||||
continue
|
||||
|
||||
results = resp.json().get("result", [])
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Qdrant search error for %s: %s", collection, e)
|
||||
continue
|
||||
|
||||
# Only accept open-source references
|
||||
license_info = _classify_regulation(r.regulation_code)
|
||||
if license_info.get("rule") not in _OPEN_SOURCE_RULES:
|
||||
continue
|
||||
for hit in results:
|
||||
payload = hit.get("payload", {})
|
||||
regulation_code = (
|
||||
payload.get("regulation_code", "")
|
||||
or payload.get("metadata", {}).get("regulation_code", "")
|
||||
)
|
||||
if not regulation_code:
|
||||
continue
|
||||
|
||||
# Build reference key for dedup
|
||||
ref = r.article or r.category or ""
|
||||
key = f"{r.regulation_code}:{ref}"
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
# Only accept open-source references
|
||||
license_info = _classify_regulation(regulation_code)
|
||||
if license_info.get("rule") not in _OPEN_SOURCE_RULES:
|
||||
continue
|
||||
|
||||
framework_name = license_info.get("name", r.regulation_name or r.regulation_short or r.regulation_code)
|
||||
url = r.source_url or self._build_reference_url(r.regulation_code, ref)
|
||||
# Build reference key for dedup
|
||||
article = payload.get("article", "") or payload.get("metadata", {}).get("article", "")
|
||||
category = payload.get("category", "") or payload.get("metadata", {}).get("category", "")
|
||||
ref = article or category or ""
|
||||
key = f"{regulation_code}:{ref}"
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
|
||||
anchors.append(OpenAnchor(
|
||||
framework=framework_name,
|
||||
ref=ref,
|
||||
url=url,
|
||||
))
|
||||
reg_name = (
|
||||
payload.get("regulation_name", "")
|
||||
or payload.get("metadata", {}).get("regulation_name", "")
|
||||
)
|
||||
reg_short = (
|
||||
payload.get("regulation_short", "")
|
||||
or payload.get("metadata", {}).get("regulation_short", "")
|
||||
)
|
||||
source_url = (
|
||||
payload.get("source_url", "")
|
||||
or payload.get("metadata", {}).get("source_url", "")
|
||||
)
|
||||
|
||||
framework_name = license_info.get("name", reg_name or reg_short or regulation_code)
|
||||
url = source_url or self._build_reference_url(regulation_code, ref)
|
||||
|
||||
anchors.append(OpenAnchor(
|
||||
framework=framework_name,
|
||||
ref=ref,
|
||||
url=url,
|
||||
))
|
||||
|
||||
if len(anchors) >= 5:
|
||||
break
|
||||
|
||||
if len(anchors) >= 5:
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user