Move 23 sources (18 national data protection laws + 5 EDPB guidelines/SCC) from bp_dsfa_corpus to bp_legal_corpus with vector preservation. Extend REGULATIONS array with national_law and eu_guideline types. Mark migrated sources in dsfa_corpus_ingestion.py to prevent re-ingestion. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
"""
RAG chunk migration: bp_dsfa_corpus -> bp_legal_corpus.

Moves national data protection laws and EU documents out of bp_dsfa_corpus
into bp_legal_corpus. Vectors are carried over 1:1 (no re-embedding).

Usage:
    python migrate_rag_chunks.py            # Dry run (default)
    python migrate_rag_chunks.py --execute  # Actually migrate
    python migrate_rag_chunks.py --verify   # Verify after migration
"""
|
|
|
|
import os
|
|
import sys
|
|
import argparse
|
|
from datetime import datetime, timezone
|
|
from typing import List, Dict, Any
|
|
|
|
from qdrant_client import QdrantClient
|
|
from qdrant_client.models import (
|
|
PointStruct, Filter, FieldCondition, MatchAny, ScrollRequest
|
|
)
|
|
|
|
# Configuration
|
|
QDRANT_URL = os.getenv("QDRANT_URL", "http://qdrant:6333")
|
|
SOURCE_COLLECTION = "bp_dsfa_corpus"
|
|
TARGET_COLLECTION = "bp_legal_corpus"
|
|
|
|
# Source codes to migrate from bp_dsfa_corpus -> bp_legal_corpus
|
|
SOURCES_TO_MIGRATE = [
|
|
# Nationale Datenschutzgesetze
|
|
"AT_DSG",
|
|
"BDSG_FULL",
|
|
"BE_DPA_LAW",
|
|
"CH_DSG",
|
|
"CZ_ZOU",
|
|
"ES_LOPDGDD",
|
|
"FI_TIETOSUOJALAKI",
|
|
"FR_CNIL_GUIDE",
|
|
"HU_INFOTV",
|
|
"IE_DPA_2018",
|
|
"IT_CODICE_PRIVACY",
|
|
"LI_DSG",
|
|
"NL_UAVG",
|
|
"NO_PERSONOPPLYSNINGSLOVEN",
|
|
"PL_UODO",
|
|
"SE_DATASKYDDSLAG",
|
|
"UK_DPA_2018",
|
|
"UK_GDPR",
|
|
# EU-Dokumente
|
|
"SCC_FULL_TEXT",
|
|
"EDPB_GUIDELINES_2_2019",
|
|
"EDPB_GUIDELINES_3_2019",
|
|
"EDPB_GUIDELINES_5_2020",
|
|
"EDPB_GUIDELINES_7_2020",
|
|
]
|
|
|
|
# Mapping: source_code -> regulation_type for bp_legal_corpus
|
|
REGULATION_TYPE_MAP = {
|
|
"AT_DSG": "national_law",
|
|
"BDSG_FULL": "de_law",
|
|
"BE_DPA_LAW": "national_law",
|
|
"CH_DSG": "national_law",
|
|
"CZ_ZOU": "national_law",
|
|
"ES_LOPDGDD": "national_law",
|
|
"FI_TIETOSUOJALAKI": "national_law",
|
|
"FR_CNIL_GUIDE": "national_law",
|
|
"HU_INFOTV": "national_law",
|
|
"IE_DPA_2018": "national_law",
|
|
"IT_CODICE_PRIVACY": "national_law",
|
|
"LI_DSG": "national_law",
|
|
"NL_UAVG": "national_law",
|
|
"NO_PERSONOPPLYSNINGSLOVEN": "national_law",
|
|
"PL_UODO": "national_law",
|
|
"SE_DATASKYDDSLAG": "national_law",
|
|
"UK_DPA_2018": "national_law",
|
|
"UK_GDPR": "national_law",
|
|
"SCC_FULL_TEXT": "eu_regulation",
|
|
"EDPB_GUIDELINES_2_2019": "eu_guideline",
|
|
"EDPB_GUIDELINES_3_2019": "eu_guideline",
|
|
"EDPB_GUIDELINES_5_2020": "eu_guideline",
|
|
"EDPB_GUIDELINES_7_2020": "eu_guideline",
|
|
}
|
|
|
|
# Mapping: source_code -> regulation_name for bp_legal_corpus
|
|
REGULATION_NAME_MAP = {
|
|
"AT_DSG": "DSG Oesterreich",
|
|
"BDSG_FULL": "BDSG",
|
|
"BE_DPA_LAW": "Datenschutzgesetz Belgien",
|
|
"CH_DSG": "DSG Schweiz",
|
|
"CZ_ZOU": "Zakon Tschechien",
|
|
"ES_LOPDGDD": "LOPDGDD Spanien",
|
|
"FI_TIETOSUOJALAKI": "Tietosuojalaki Finnland",
|
|
"FR_CNIL_GUIDE": "CNIL Guide RGPD",
|
|
"HU_INFOTV": "Infotv. Ungarn",
|
|
"IE_DPA_2018": "DPA 2018 Ireland",
|
|
"IT_CODICE_PRIVACY": "Codice Privacy Italien",
|
|
"LI_DSG": "DSG Liechtenstein",
|
|
"NL_UAVG": "UAVG Niederlande",
|
|
"NO_PERSONOPPLYSNINGSLOVEN": "Personopplysningsloven",
|
|
"PL_UODO": "UODO Polen",
|
|
"SE_DATASKYDDSLAG": "Dataskyddslag Schweden",
|
|
"UK_DPA_2018": "DPA 2018 UK",
|
|
"UK_GDPR": "UK GDPR",
|
|
"SCC_FULL_TEXT": "Standardvertragsklauseln",
|
|
"EDPB_GUIDELINES_2_2019": "EDPB GL 2/2019",
|
|
"EDPB_GUIDELINES_3_2019": "EDPB GL 3/2019",
|
|
"EDPB_GUIDELINES_5_2020": "EDPB GL 5/2020",
|
|
"EDPB_GUIDELINES_7_2020": "EDPB GL 7/2020",
|
|
}
|
|
|
|
|
|
def transform_payload(dsfa_payload: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Transform bp_dsfa_corpus payload to bp_legal_corpus format."""
|
|
source_code = dsfa_payload.get("source_code", "")
|
|
now = datetime.now(timezone.utc).isoformat()
|
|
|
|
return {
|
|
"text": dsfa_payload.get("content", ""),
|
|
"regulation_code": source_code,
|
|
"regulation_name": REGULATION_NAME_MAP.get(source_code, dsfa_payload.get("source_name", "")),
|
|
"regulation_full_name": dsfa_payload.get("source_name", ""),
|
|
"regulation_type": REGULATION_TYPE_MAP.get(source_code, "national_law"),
|
|
"source_url": dsfa_payload.get("source_url", ""),
|
|
"chunk_index": dsfa_payload.get("chunk_index", 0),
|
|
"chunk_position": dsfa_payload.get("chunk_position", 0),
|
|
"article": dsfa_payload.get("article", None),
|
|
"paragraph": dsfa_payload.get("paragraph", None),
|
|
"language": dsfa_payload.get("language", "de"),
|
|
"indexed_at": now,
|
|
"training_allowed": False,
|
|
}
|
|
|
|
|
|
def scroll_all_points(client: QdrantClient, collection: str, source_codes: List[str]) -> List:
|
|
"""Scroll through all points matching the source codes."""
|
|
all_points = []
|
|
offset = None
|
|
batch_size = 100
|
|
|
|
scroll_filter = Filter(
|
|
must=[
|
|
FieldCondition(
|
|
key="source_code",
|
|
match=MatchAny(any=source_codes),
|
|
)
|
|
]
|
|
)
|
|
|
|
while True:
|
|
results, next_offset = client.scroll(
|
|
collection_name=collection,
|
|
scroll_filter=scroll_filter,
|
|
limit=batch_size,
|
|
offset=offset,
|
|
with_vectors=True,
|
|
with_payload=True,
|
|
)
|
|
|
|
all_points.extend(results)
|
|
|
|
if next_offset is None:
|
|
break
|
|
offset = next_offset
|
|
|
|
return all_points
|
|
|
|
|
|
def migrate(execute: bool = False):
|
|
"""Run the migration."""
|
|
print(f"{'=' * 60}")
|
|
print(f"RAG Chunk Migration: {SOURCE_COLLECTION} -> {TARGET_COLLECTION}")
|
|
print(f"Mode: {'EXECUTE' if execute else 'DRY RUN'}")
|
|
print(f"{'=' * 60}")
|
|
print()
|
|
|
|
client = QdrantClient(url=QDRANT_URL)
|
|
|
|
# Get initial counts
|
|
source_info = client.get_collection(SOURCE_COLLECTION)
|
|
target_info = client.get_collection(TARGET_COLLECTION)
|
|
print(f"Before migration:")
|
|
print(f" {SOURCE_COLLECTION}: {source_info.points_count} points")
|
|
print(f" {TARGET_COLLECTION}: {target_info.points_count} points")
|
|
print()
|
|
|
|
# Scroll all points to migrate
|
|
print(f"Scrolling points for {len(SOURCES_TO_MIGRATE)} source codes...")
|
|
points = scroll_all_points(client, SOURCE_COLLECTION, SOURCES_TO_MIGRATE)
|
|
print(f" Found {len(points)} points to migrate")
|
|
print()
|
|
|
|
if not points:
|
|
print("No points found to migrate. Exiting.")
|
|
return
|
|
|
|
# Group by source_code for reporting
|
|
by_source: Dict[str, int] = {}
|
|
for p in points:
|
|
sc = p.payload.get("source_code", "UNKNOWN")
|
|
by_source[sc] = by_source.get(sc, 0) + 1
|
|
|
|
print("Points per source:")
|
|
for sc in sorted(by_source.keys()):
|
|
print(f" {sc}: {by_source[sc]} chunks")
|
|
print()
|
|
|
|
if not execute:
|
|
print("DRY RUN complete. Use --execute to actually migrate.")
|
|
return
|
|
|
|
# Transform and upsert in batches
|
|
batch_size = 50
|
|
upserted = 0
|
|
for i in range(0, len(points), batch_size):
|
|
batch = points[i:i + batch_size]
|
|
new_points = []
|
|
for p in batch:
|
|
new_payload = transform_payload(p.payload)
|
|
new_points.append(PointStruct(
|
|
id=p.id,
|
|
vector=p.vector,
|
|
payload=new_payload,
|
|
))
|
|
|
|
client.upsert(
|
|
collection_name=TARGET_COLLECTION,
|
|
points=new_points,
|
|
)
|
|
upserted += len(new_points)
|
|
print(f" Upserted {upserted}/{len(points)} points...")
|
|
|
|
print(f"\nUpsert complete: {upserted} points added to {TARGET_COLLECTION}")
|
|
|
|
# Delete from source collection
|
|
point_ids = [p.id for p in points]
|
|
for i in range(0, len(point_ids), 100):
|
|
batch_ids = point_ids[i:i + 100]
|
|
client.delete(
|
|
collection_name=SOURCE_COLLECTION,
|
|
points_selector=batch_ids,
|
|
)
|
|
print(f" Deleted {min(i + 100, len(point_ids))}/{len(point_ids)} from {SOURCE_COLLECTION}...")
|
|
|
|
print(f"\nDelete complete: {len(point_ids)} points removed from {SOURCE_COLLECTION}")
|
|
|
|
# Final counts
|
|
source_info = client.get_collection(SOURCE_COLLECTION)
|
|
target_info = client.get_collection(TARGET_COLLECTION)
|
|
print(f"\nAfter migration:")
|
|
print(f" {SOURCE_COLLECTION}: {source_info.points_count} points")
|
|
print(f" {TARGET_COLLECTION}: {target_info.points_count} points")
|
|
print(f"\nMigration complete!")
|
|
|
|
|
|
def verify():
|
|
"""Verify migration results."""
|
|
print(f"Verifying migration...")
|
|
client = QdrantClient(url=QDRANT_URL)
|
|
|
|
source_info = client.get_collection(SOURCE_COLLECTION)
|
|
target_info = client.get_collection(TARGET_COLLECTION)
|
|
print(f" {SOURCE_COLLECTION}: {source_info.points_count} points")
|
|
print(f" {TARGET_COLLECTION}: {target_info.points_count} points")
|
|
|
|
# Check that migrated sources are gone from dsfa
|
|
remaining = scroll_all_points(client, SOURCE_COLLECTION, SOURCES_TO_MIGRATE)
|
|
if remaining:
|
|
print(f"\n WARNING: {len(remaining)} points still in {SOURCE_COLLECTION}!")
|
|
by_source: Dict[str, int] = {}
|
|
for p in remaining:
|
|
sc = p.payload.get("source_code", "UNKNOWN")
|
|
by_source[sc] = by_source.get(sc, 0) + 1
|
|
for sc, cnt in sorted(by_source.items()):
|
|
print(f" {sc}: {cnt}")
|
|
else:
|
|
print(f"\n OK: No migrated sources remaining in {SOURCE_COLLECTION}")
|
|
|
|
# Check that migrated sources exist in legal
|
|
for code in SOURCES_TO_MIGRATE:
|
|
results, _ = client.scroll(
|
|
collection_name=TARGET_COLLECTION,
|
|
scroll_filter=Filter(
|
|
must=[FieldCondition(key="regulation_code", match=MatchAny(any=[code]))]
|
|
),
|
|
limit=1,
|
|
with_payload=True,
|
|
with_vectors=False,
|
|
)
|
|
status = f"{len(results)}+ chunks" if results else "MISSING"
|
|
print(f" {TARGET_COLLECTION}/{code}: {status}")
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(description="Migrate RAG chunks between collections")
|
|
parser.add_argument("--execute", action="store_true", help="Actually execute the migration (default: dry run)")
|
|
parser.add_argument("--verify", action="store_true", help="Verify migration results")
|
|
args = parser.parse_args()
|
|
|
|
if args.verify:
|
|
verify()
|
|
else:
|
|
migrate(execute=args.execute)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|