#!/usr/bin/env python3 """Sync canonical control tables between production and local DB. Modes: --pull Production → Local (initial sync, full table copy) --push Local → Production (incremental, only new obligation_candidates) --loop Run --push every N minutes (default 60) Usage: python3 sync_db.py --pull # Full sync production → local python3 sync_db.py --push # Push new obligations to production python3 sync_db.py --loop 60 # Push every 60 minutes python3 sync_db.py --pull --tables canonical_controls # Only one table """ import argparse import json import os import sys import time import urllib.parse import io import psycopg2 import psycopg2.extras import psycopg2.extensions # Register JSON adapter so dicts are automatically converted to JSONB psycopg2.extensions.register_adapter(dict, psycopg2.extras.Json) # ── DB Config ──────────────────────────────────────────────────────── PROD_URL = os.environ.get( "PROD_DATABASE_URL", "postgresql://postgres:GmyFD3wnU1NrKBdpU1nwLdE8MLts0A0eez8L5XXdvUCe05lWnWfVp3C6JJ8Yrmt2" "@46.225.100.82:54321/postgres?sslmode=require", ) LOCAL_URL = os.environ.get( "LOCAL_DATABASE_URL", "postgresql://breakpilot:breakpilot123@localhost:5432/breakpilot_db", ) SCHEMA = "compliance" # Tables to sync (production → local) SYNC_TABLES = [ "canonical_control_frameworks", "canonical_control_licenses", "canonical_control_sources", "canonical_control_categories", "canonical_blocked_sources", "canonical_controls", "canonical_control_mappings", "canonical_processed_chunks", "canonical_generation_jobs", "control_patterns", "crosswalk_matrix", "obligation_extractions", "obligation_candidates", ] def connect(url, label="DB"): parsed = urllib.parse.urlparse(url) params = dict(urllib.parse.parse_qsl(parsed.query)) conn = psycopg2.connect( host=parsed.hostname, port=parsed.port or 5432, user=parsed.username, password=parsed.password, dbname=parsed.path.lstrip("/"), sslmode=params.get("sslmode", "prefer"), options=f"-c search_path={SCHEMA},public", keepalives=1, keepalives_idle=30, keepalives_interval=10, keepalives_count=5, ) conn.autocommit = False print(f" Connected to {label} ({parsed.hostname}:{parsed.port or 5432})") return conn def get_columns(cur, table): cur.execute(f""" SELECT column_name FROM information_schema.columns WHERE table_schema = '{SCHEMA}' AND table_name = '{table}' ORDER BY ordinal_position """) return [r[0] for r in cur.fetchall()] def pull_table(prod_conn, local_conn, table): """Copy entire table from production to local via SELECT + INSERT.""" prod_cur = prod_conn.cursor() local_cur = local_conn.cursor() # Check table exists on production prod_cur.execute(f""" SELECT 1 FROM pg_tables WHERE schemaname = '{SCHEMA}' AND tablename = '{table}' """) if not prod_cur.fetchone(): print(f" SKIP {table} — not found on production") return 0 # Drop local table local_cur.execute(f"DROP TABLE IF EXISTS {SCHEMA}.{table} CASCADE") local_conn.commit() # Build simple CREATE TABLE (no constraints, no defaults — just for data) prod_cur.execute(f""" SELECT column_name, data_type, udt_name, character_maximum_length FROM information_schema.columns WHERE table_schema = '{SCHEMA}' AND table_name = '{table}' ORDER BY ordinal_position """) col_defs = prod_cur.fetchall() parts = [] col_names = [] jsonb_cols = set() for name, dtype, udt, max_len in col_defs: col_names.append(name) if dtype == "ARRAY": type_map = { "_text": "text[]", "_varchar": "varchar[]", "_int4": "integer[]", "_uuid": "uuid[]", "_jsonb": "jsonb[]", "_float8": "float8[]", } sql_type = type_map.get(udt, f"{udt.lstrip('_')}[]") elif dtype == "USER-DEFINED" and udt == "jsonb": sql_type = "jsonb" jsonb_cols.add(name) elif dtype == "USER-DEFINED": sql_type = udt elif dtype == "jsonb": sql_type = "jsonb" jsonb_cols.add(name) elif max_len: sql_type = f"{dtype}({max_len})" else: sql_type = dtype parts.append(f'"{name}" {sql_type}') ddl = f"CREATE TABLE {SCHEMA}.{table} ({', '.join(parts)})" local_cur.execute(ddl) local_conn.commit() # Fetch all rows from production col_list = ", ".join(f'"{c}"' for c in col_names) prod_cur.execute(f"SELECT {col_list} FROM {SCHEMA}.{table}") rows = prod_cur.fetchall() if rows: # Wrap dict/list values in Json for JSONB columns adapted_rows = [] for row in rows: adapted = [] for i, val in enumerate(row): if col_names[i] in jsonb_cols and isinstance(val, (dict, list)): adapted.append(psycopg2.extras.Json(val)) else: adapted.append(val) adapted_rows.append(tuple(adapted)) placeholders = ", ".join(["%s"] * len(col_names)) insert_sql = f'INSERT INTO {SCHEMA}.{table} ({col_list}) VALUES ({placeholders})' psycopg2.extras.execute_batch(local_cur, insert_sql, adapted_rows, page_size=500) local_conn.commit() print(f" {table}: {len(rows)} rows") return len(rows) def pull(tables=None): """Full sync: production → local.""" print("\n=== PULL: Production → Local ===\n") prod_conn = connect(PROD_URL, "Production") local_conn = connect(LOCAL_URL, "Local") # Ensure schema exists local_cur = local_conn.cursor() local_cur.execute(f"CREATE SCHEMA IF NOT EXISTS {SCHEMA}") local_conn.commit() sync_list = tables if tables else SYNC_TABLES total = 0 for table in sync_list: try: count = pull_table(prod_conn, local_conn, table) total += count except Exception as e: print(f" ERROR {table}: {e}") local_conn.rollback() prod_conn.rollback() print(f"\n Total: {total} rows synced") prod_conn.close() local_conn.close() def push(): """Incremental push: new obligation_candidates local → production.""" print(f"\n=== PUSH: Local → Production ({time.strftime('%H:%M:%S')}) ===\n") local_conn = connect(LOCAL_URL, "Local") prod_conn = connect(PROD_URL, "Production") local_cur = local_conn.cursor() prod_cur = prod_conn.cursor() # Find obligation_candidates in local that don't exist in production # Use candidate_id as the unique key local_cur.execute(f""" SELECT candidate_id FROM {SCHEMA}.obligation_candidates """) local_ids = {r[0] for r in local_cur.fetchall()} if not local_ids: print(" No obligation_candidates in local DB") local_conn.close() prod_conn.close() return 0 # Check which already exist on production prod_cur.execute(f""" SELECT candidate_id FROM {SCHEMA}.obligation_candidates """) prod_ids = {r[0] for r in prod_cur.fetchall()} new_ids = local_ids - prod_ids if not new_ids: print(f" All {len(local_ids)} obligations already on production") local_conn.close() prod_conn.close() return 0 print(f" {len(new_ids)} new obligations to push (local: {len(local_ids)}, prod: {len(prod_ids)})") # Get columns columns = get_columns(local_cur, "obligation_candidates") col_list = ", ".join(columns) placeholders = ", ".join(["%s"] * len(columns)) # Fetch new rows from local id_list = ", ".join(f"'{i}'" for i in new_ids) local_cur.execute(f""" SELECT {col_list} FROM {SCHEMA}.obligation_candidates WHERE candidate_id IN ({id_list}) """) rows = local_cur.fetchall() # Insert into production insert_sql = f"INSERT INTO {SCHEMA}.obligation_candidates ({col_list}) VALUES ({placeholders}) ON CONFLICT DO NOTHING" psycopg2.extras.execute_batch(prod_cur, insert_sql, rows, page_size=100) prod_conn.commit() print(f" Pushed {len(rows)} obligations to production") local_conn.close() prod_conn.close() return len(rows) def loop(interval_min): """Run push every N minutes.""" print(f"\n=== SYNC LOOP — Push every {interval_min} min ===") print(f" Started at {time.strftime('%Y-%m-%d %H:%M:%S')}") print(f" Press Ctrl+C to stop\n") while True: try: pushed = push() if pushed: print(f" Next sync in {interval_min} min...") except Exception as e: print(f" SYNC ERROR: {e}") time.sleep(interval_min * 60) def main(): parser = argparse.ArgumentParser(description="Sync canonical control tables") parser.add_argument("--pull", action="store_true", help="Production → Local (full copy)") parser.add_argument("--push", action="store_true", help="Local → Production (new obligations)") parser.add_argument("--loop", type=int, metavar="MIN", help="Push every N minutes") parser.add_argument("--tables", nargs="+", help="Only sync specific tables (with --pull)") args = parser.parse_args() if not any([args.pull, args.push, args.loop]): parser.print_help() return if args.pull: pull(args.tables) if args.push: push() if args.loop: loop(args.loop) if __name__ == "__main__": main()