fix(control-pipeline): anchor finder uses direct Qdrant search instead of Go SDK
The Go SDK RAG proxy returns 401 (Qdrant API key mismatch). Switch AnchorFinder to direct Qdrant vector search plus the embedding service, the same approach the main pipeline uses. AnchorFinder no longer depends on the Go SDK.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1125,8 +1125,7 @@ async def _run_anchor_backfill(req: AnchorBackfillRequest, backfill_id: str):
|
|||||||
|
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
rag_client = get_rag_client()
|
finder = AnchorFinder()
|
||||||
finder = AnchorFinder(rag_client=rag_client)
|
|
||||||
|
|
||||||
# Find controls without anchors
|
# Find controls without anchors
|
||||||
states = "('draft', 'needs_review')" if req.include_needs_review else "('draft',)"
|
states = "('draft', 'needs_review')" if req.include_needs_review else "('draft',)"
|
||||||
|
|||||||
@@ -2,19 +2,19 @@
|
|||||||
Anchor Finder — finds open-source references (OWASP, NIST, ENISA) for controls.
|
Anchor Finder — finds open-source references (OWASP, NIST, ENISA) for controls.
|
||||||
|
|
||||||
Two-stage search:
|
Two-stage search:
|
||||||
Stage A: RAG-internal search for open-source chunks matching the control topic
|
Stage A: Direct Qdrant vector search for open-source chunks matching the control topic
|
||||||
Stage B: Web search via DuckDuckGo Instant Answer API (no API key needed)
|
Stage B: Web search via DuckDuckGo Instant Answer API (no API key needed)
|
||||||
|
|
||||||
Only open-source references (Rule 1+2) are accepted as anchors.
|
Only open-source references (Rule 1+2) are accepted as anchors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from .rag_client import ComplianceRAGClient, get_rag_client
|
|
||||||
from .control_generator import (
|
from .control_generator import (
|
||||||
GeneratedControl,
|
GeneratedControl,
|
||||||
REGULATION_LICENSE_MAP,
|
REGULATION_LICENSE_MAP,
|
||||||
@@ -25,9 +25,15 @@ from .control_generator import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
QDRANT_URL = os.getenv("QDRANT_URL", "http://qdrant:6333")
|
||||||
|
EMBEDDING_URL = os.getenv("EMBEDDING_URL", "http://embedding-service:8087")
|
||||||
|
|
||||||
# Regulation codes that are safe to reference as open anchors (Rule 1+2)
|
# Regulation codes that are safe to reference as open anchors (Rule 1+2)
|
||||||
_OPEN_SOURCE_RULES = {1, 2}
|
_OPEN_SOURCE_RULES = {1, 2}
|
||||||
|
|
||||||
|
# Collections to search for anchors (open-source frameworks)
|
||||||
|
_ANCHOR_COLLECTIONS = ["bp_compliance_ce"]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OpenAnchor:
|
class OpenAnchor:
|
||||||
@@ -39,8 +45,9 @@ class OpenAnchor:
|
|||||||
class AnchorFinder:
|
class AnchorFinder:
|
||||||
"""Finds open-source references to anchor generated controls."""
|
"""Finds open-source references to anchor generated controls."""
|
||||||
|
|
||||||
def __init__(self, rag_client: Optional[ComplianceRAGClient] = None):
|
def __init__(self, rag_client=None):
|
||||||
self.rag = rag_client or get_rag_client()
|
# rag_client kept for backwards compat but no longer used
|
||||||
|
pass
|
||||||
|
|
||||||
async def find_anchors(
|
async def find_anchors(
|
||||||
self,
|
self,
|
||||||
@@ -49,8 +56,8 @@ class AnchorFinder:
|
|||||||
min_anchors: int = 2,
|
min_anchors: int = 2,
|
||||||
) -> List[OpenAnchor]:
|
) -> List[OpenAnchor]:
|
||||||
"""Find open-source anchors for a control."""
|
"""Find open-source anchors for a control."""
|
||||||
# Stage A: RAG-internal search
|
# Stage A: Direct Qdrant vector search
|
||||||
anchors = await self._search_rag_for_open_anchors(control)
|
anchors = await self._search_qdrant_for_open_anchors(control)
|
||||||
|
|
||||||
# Stage B: Web search if not enough anchors
|
# Stage B: Web search if not enough anchors
|
||||||
if len(anchors) < min_anchors and not skip_web:
|
if len(anchors) < min_anchors and not skip_web:
|
||||||
@@ -63,39 +70,95 @@ class AnchorFinder:
|
|||||||
|
|
||||||
return anchors
|
return anchors
|
||||||
|
|
||||||
async def _search_rag_for_open_anchors(self, control: GeneratedControl) -> List[OpenAnchor]:
|
async def _get_embedding(self, text: str) -> list:
|
||||||
"""Search RAG for chunks from open sources matching the control topic."""
|
"""Get embedding vector via embedding service."""
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{EMBEDDING_URL}/embed",
|
||||||
|
json={"texts": [text]},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
embeddings = resp.json().get("embeddings", [])
|
||||||
|
return embeddings[0] if embeddings else []
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Embedding request failed: %s", e)
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _search_qdrant_for_open_anchors(self, control: GeneratedControl) -> List[OpenAnchor]:
|
||||||
|
"""Search Qdrant directly for chunks from open sources matching the control topic."""
|
||||||
# Build search query from control title + first 3 tags
|
# Build search query from control title + first 3 tags
|
||||||
tags_str = " ".join(control.tags[:3]) if control.tags else ""
|
tags_str = " ".join(control.tags[:3]) if control.tags else ""
|
||||||
query = f"{control.title} {tags_str}".strip()
|
query = f"{control.title} {tags_str}".strip()
|
||||||
|
|
||||||
results = await self.rag.search_with_rerank(
|
# Get embedding for query
|
||||||
query=query,
|
embedding = await self._get_embedding(query)
|
||||||
collection="bp_compliance_ce",
|
if not embedding:
|
||||||
top_k=15,
|
return []
|
||||||
)
|
|
||||||
|
|
||||||
anchors: List[OpenAnchor] = []
|
anchors: List[OpenAnchor] = []
|
||||||
seen: set[str] = set()
|
seen: set[str] = set()
|
||||||
|
|
||||||
for r in results:
|
for collection in _ANCHOR_COLLECTIONS:
|
||||||
if not r.regulation_code:
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{QDRANT_URL}/collections/{collection}/points/search",
|
||||||
|
json={
|
||||||
|
"vector": embedding,
|
||||||
|
"limit": 20,
|
||||||
|
"with_payload": True,
|
||||||
|
"with_vector": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if resp.status_code != 200:
|
||||||
|
logger.warning("Qdrant search %s failed: %d", collection, resp.status_code)
|
||||||
|
continue
|
||||||
|
|
||||||
|
results = resp.json().get("result", [])
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Qdrant search error for %s: %s", collection, e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
for hit in results:
|
||||||
|
payload = hit.get("payload", {})
|
||||||
|
regulation_code = (
|
||||||
|
payload.get("regulation_code", "")
|
||||||
|
or payload.get("metadata", {}).get("regulation_code", "")
|
||||||
|
)
|
||||||
|
if not regulation_code:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Only accept open-source references
|
# Only accept open-source references
|
||||||
license_info = _classify_regulation(r.regulation_code)
|
license_info = _classify_regulation(regulation_code)
|
||||||
if license_info.get("rule") not in _OPEN_SOURCE_RULES:
|
if license_info.get("rule") not in _OPEN_SOURCE_RULES:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Build reference key for dedup
|
# Build reference key for dedup
|
||||||
ref = r.article or r.category or ""
|
article = payload.get("article", "") or payload.get("metadata", {}).get("article", "")
|
||||||
key = f"{r.regulation_code}:{ref}"
|
category = payload.get("category", "") or payload.get("metadata", {}).get("category", "")
|
||||||
|
ref = article or category or ""
|
||||||
|
key = f"{regulation_code}:{ref}"
|
||||||
if key in seen:
|
if key in seen:
|
||||||
continue
|
continue
|
||||||
seen.add(key)
|
seen.add(key)
|
||||||
|
|
||||||
framework_name = license_info.get("name", r.regulation_name or r.regulation_short or r.regulation_code)
|
reg_name = (
|
||||||
url = r.source_url or self._build_reference_url(r.regulation_code, ref)
|
payload.get("regulation_name", "")
|
||||||
|
or payload.get("metadata", {}).get("regulation_name", "")
|
||||||
|
)
|
||||||
|
reg_short = (
|
||||||
|
payload.get("regulation_short", "")
|
||||||
|
or payload.get("metadata", {}).get("regulation_short", "")
|
||||||
|
)
|
||||||
|
source_url = (
|
||||||
|
payload.get("source_url", "")
|
||||||
|
or payload.get("metadata", {}).get("source_url", "")
|
||||||
|
)
|
||||||
|
|
||||||
|
framework_name = license_info.get("name", reg_name or reg_short or regulation_code)
|
||||||
|
url = source_url or self._build_reference_url(regulation_code, ref)
|
||||||
|
|
||||||
anchors.append(OpenAnchor(
|
anchors.append(OpenAnchor(
|
||||||
framework=framework_name,
|
framework=framework_name,
|
||||||
@@ -106,6 +169,9 @@ class AnchorFinder:
|
|||||||
if len(anchors) >= 5:
|
if len(anchors) >= 5:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if len(anchors) >= 5:
|
||||||
|
break
|
||||||
|
|
||||||
return anchors
|
return anchors
|
||||||
|
|
||||||
async def _search_web(self, control: GeneratedControl) -> List[OpenAnchor]:
|
async def _search_web(self, control: GeneratedControl) -> List[OpenAnchor]:
|
||||||
|
|||||||
Reference in New Issue
Block a user