"""Shared embedding + sub-clustering utilities for the control pipeline."""

import logging
import os
import time
from collections import defaultdict

import httpx
import numpy as np
from sklearn.cluster import MiniBatchKMeans

logger = logging.getLogger(__name__)

EMBEDDING_URL = os.getenv(
    "EMBEDDING_SERVICE_URL", "http://embedding-service:8087"
)

# Expected width of each vector returned by the service's /embed endpoint.
_EMBEDDING_DIM = 1024
# Number of texts sent per /embed request.
_BATCH_SIZE = 64
# Total attempts per batch (first try included) and the pause between them.
_MAX_ATTEMPTS = 3
_RETRY_DELAY_S = 2.0


def _embed_batch(
    client: httpx.Client, batch: list[str], dim: int, offset: int
) -> np.ndarray | None:
    """POST one batch to ``/embed``, retrying failed attempts.

    Args:
        client: Open HTTP client used for the request.
        batch: Texts to embed in this request.
        dim: Required embedding width; any other reply shape is a failure.
        offset: Index of ``batch[0]`` within the full input (logging only).

    Returns:
        A ``(len(batch), dim)`` float32 array, or ``None`` once every attempt
        has failed.
    """
    for attempt in range(1, _MAX_ATTEMPTS + 1):
        try:
            resp = client.post(f"{EMBEDDING_URL}/embed", json={"texts": batch})
            resp.raise_for_status()
            embs = np.asarray(
                resp.json().get("embeddings", []), dtype=np.float32
            )
            # A short/long/wrong-width reply would otherwise leave zero or
            # misplaced rows in the result, which k-means would silently
            # lump together.
            if embs.shape != (len(batch), dim):
                raise ValueError(
                    f"expected embeddings of shape {(len(batch), dim)}, "
                    f"got {embs.shape}"
                )
            if not np.isfinite(embs).all():
                raise ValueError("embeddings contain NaN or inf")
            return embs
        except Exception as e:  # transport, HTTP status, or malformed payload
            if attempt == _MAX_ATTEMPTS:
                logger.error("Embed batch %d failed: %s", offset, e)
                return None
            logger.warning(
                "Embed batch %d attempt %d/%d failed: %s",
                offset,
                attempt,
                _MAX_ATTEMPTS,
                e,
            )
            # Pause only when another attempt will follow.
            time.sleep(_RETRY_DELAY_S)
    return None


def embed_texts(
    texts: list[str],
    *,
    batch_size: int = _BATCH_SIZE,
    dim: int = _EMBEDDING_DIM,
) -> np.ndarray | None:
    """Embed texts via the embedding-service in batches of 64.

    Args:
        texts: Strings to embed; row ``i`` of the result belongs to
            ``texts[i]``.
        batch_size: Texts per ``/embed`` request (default 64).
        dim: Expected embedding width (default 1024); replies of any other
            width count as failures.

    Returns:
        A ``(len(texts), dim)`` float32 array, or ``None`` if any batch still
        fails after retries. The result is all-or-nothing: a partially filled
        (zero-padded) matrix is never returned, so callers can fall back
        cleanly. An empty ``texts`` yields a ``(0, dim)`` array without
        contacting the service.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    result = np.zeros((len(texts), dim), dtype=np.float32)
    if not texts:
        return result
    try:
        # One client for the whole run so batches reuse pooled connections.
        with httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0)) as client:
            for start in range(0, len(texts), batch_size):
                batch = texts[start : start + batch_size]
                embs = _embed_batch(client, batch, dim, start)
                if embs is None:
                    # One missing batch makes the whole matrix unusable.
                    return None
                result[start : start + len(batch)] = embs
    except Exception as e:  # e.g. client construction/teardown failure
        logger.error("Embedding failed: %s", e)
        return None
    return result


def _chunk(controls: list[dict], size: int) -> list[list[dict]]:
    """Split ``controls`` into consecutive slices of at most ``size`` items."""
    return [controls[i : i + size] for i in range(0, len(controls), size)]


def subcluster_controls(
    controls: list[dict], target_size: int = 50, *, max_clusters: int = 30
) -> list[list[dict]]:
    """Sub-cluster controls by embedding similarity.

    Returns a list of clusters. Falls back to naive chunking if embedding fails.

    Each control is embedded by its ``"title"`` (or its ``"control_id"`` when
    the title is missing or empty). Vectors are L2-normalised and grouped with
    MiniBatchKMeans into ``len(controls) // target_size`` clusters, clamped to
    ``[2, max_clusters]``. ``target_size`` is a target, not a limit: k-means
    clusters can be unbalanced, and because of the cap large inputs produce
    clusters well above ``target_size``.

    Args:
        controls: Control dicts.
        target_size: Desired average cluster size; inputs no larger than this
            are returned unchanged as a single cluster.
        max_clusters: Upper bound on the number of k-means clusters (default
            30); at least 2 clusters are always requested.

    Returns:
        Non-empty clusters, ordered by the first control assigned to each;
        every control appears in exactly one. When
        ``len(controls) <= target_size`` the result is ``[controls]`` (so an
        empty input yields ``[[]]``). If embedding fails, consecutive chunks
        of ``target_size`` are returned instead.

    Raises:
        ValueError: If ``target_size`` is less than 1.
    """
    if target_size < 1:
        raise ValueError(f"target_size must be >= 1, got {target_size}")
    if len(controls) <= target_size:
        return [controls]

    texts = [c.get("title", "") or c.get("control_id", "") for c in controls]
    embeddings = embed_texts(texts)
    if embeddings is None:
        return _chunk(controls, target_size)

    # On unit vectors squared Euclidean distance equals 2 - 2*cosine, so
    # k-means effectively clusters by cosine similarity. Zero rows keep norm 1
    # to avoid dividing by zero.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    norms[norms == 0] = 1
    normalized = embeddings / norms

    # len(controls) > target_size >= 1 guarantees 2 <= k <= len(controls).
    k = max(2, min(len(controls) // target_size, max_clusters))
    kmeans = MiniBatchKMeans(
        n_clusters=k,
        batch_size=min(100, len(controls)),
        max_iter=50,
        random_state=42,
    )
    labels = kmeans.fit_predict(normalized)

    clusters: dict[int, list[dict]] = defaultdict(list)
    for label, ctrl in zip(labels, controls):
        clusters[int(label)].append(ctrl)
    return list(clusters.values())