""" Banner A/B Testing Service — variant assignment, stats, significance. Deterministic variant assignment via device fingerprint hash ensures the same device always sees the same variant (sticky bucketing). """ import hashlib import math import uuid from datetime import datetime, timezone from typing import Any, Optional from sqlalchemy import text from sqlalchemy.orm import Session class BannerABService: """A/B testing for consent banner variants.""" def __init__(self, db: Session) -> None: self.db = db # ------------------------------------------------------------------ # Variant CRUD # ------------------------------------------------------------------ def list_variants(self, tenant_id: str, site_config_id: str) -> list[dict]: q = text(""" SELECT * FROM compliance_banner_variants WHERE tenant_id = :tid AND site_config_id = :scid ORDER BY variant_key """) rows = self.db.execute(q, {"tid": tenant_id, "scid": site_config_id}).fetchall() return [dict(r._mapping) for r in rows] def create_variant(self, tenant_id: str, site_config_id: str, data: dict) -> dict: q = text(""" INSERT INTO compliance_banner_variants (tenant_id, site_config_id, variant_name, variant_key, traffic_percent, is_control, banner_title, banner_description, position, style, primary_color, show_decline_all, theme_overrides) VALUES (:tid, :scid, :name, :key, :pct, :ctrl, :title, :desc, :pos, :style, :color, :decline, :theme) RETURNING * """) row = self.db.execute(q, { "tid": tenant_id, "scid": site_config_id, "name": data.get("variant_name", ""), "key": data.get("variant_key", "A"), "pct": data.get("traffic_percent", 50), "ctrl": data.get("is_control", False), "title": data.get("banner_title"), "desc": data.get("banner_description"), "pos": data.get("position"), "style": data.get("style"), "color": data.get("primary_color"), "decline": data.get("show_decline_all"), "theme": data.get("theme_overrides", "{}"), }).fetchone() self.db.commit() return dict(row._mapping) def update_variant(self, variant_id: str, data: dict) -> Optional[dict]: sets, params = [], {"vid": variant_id} for field in ["variant_name", "traffic_percent", "is_control", "banner_title", "banner_description", "position", "style", "primary_color", "show_decline_all", "is_active"]: if field in data and data[field] is not None: sets.append(f"{field} = :{field}") params[field] = data[field] if not sets: return None sets.append("updated_at = NOW()") q = text(f"UPDATE compliance_banner_variants SET {', '.join(sets)} WHERE id = :vid RETURNING *") row = self.db.execute(q, params).fetchone() self.db.commit() return dict(row._mapping) if row else None def delete_variant(self, variant_id: str) -> bool: q = text("DELETE FROM compliance_banner_variants WHERE id = :vid") result = self.db.execute(q, {"vid": variant_id}) self.db.commit() return result.rowcount > 0 # ------------------------------------------------------------------ # Variant Assignment (deterministic sticky bucketing) # ------------------------------------------------------------------ def assign_variant(self, site_config_id: str, device_fingerprint: str) -> Optional[dict]: """Assign a variant based on device fingerprint hash. Returns variant or None.""" variants = self.db.execute(text(""" SELECT * FROM compliance_banner_variants WHERE site_config_id = :scid AND is_active = TRUE ORDER BY variant_key """), {"scid": site_config_id}).fetchall() if not variants: return None # Deterministic bucket 0-99 from device fingerprint bucket = int(hashlib.md5(f"{site_config_id}:{device_fingerprint}".encode()).hexdigest(), 16) % 100 cumulative = 0 for v in variants: cumulative += v.traffic_percent if bucket < cumulative: return dict(v._mapping) # Fallback to last variant return dict(variants[-1]._mapping) # ------------------------------------------------------------------ # Stats with statistical significance # ------------------------------------------------------------------ def get_variant_stats(self, tenant_id: str, site_config_id: str) -> list[dict]: """Per-variant stats with chi-squared significance test.""" variants = self.list_variants(tenant_id, site_config_id) if not variants: return [] results = [] for v in variants: vid = str(v["id"]) vkey = v["variant_key"] q = text(""" SELECT COUNT(*) AS total, COUNT(*) FILTER (WHERE action = 'consent_given') AS accepted, COUNT(*) FILTER (WHERE action IN ('consent_withdrawn', 'consent_revoked')) AS rejected FROM compliance_banner_consent_audit_log WHERE tenant_id = :tid AND variant_key = :vkey """) row = self.db.execute(q, {"tid": tenant_id, "vkey": vkey}).fetchone() total = row.total if row else 0 accepted = row.accepted if row else 0 results.append({ "variant_id": vid, "variant_key": vkey, "variant_name": v["variant_name"], "traffic_percent": v["traffic_percent"], "is_control": v["is_control"], "total": total, "accepted": accepted, "opt_in_rate": round(accepted / total * 100, 1) if total > 0 else 0, }) # Chi-squared test between control and best variant control = next((r for r in results if r["is_control"]), None) if control and len(results) > 1: best = max((r for r in results if not r["is_control"]), key=lambda x: x["opt_in_rate"], default=None) if best and control["total"] > 0 and best["total"] > 0: sig = self._chi_squared_significance( control["accepted"], control["total"], best["accepted"], best["total"], ) best["is_winner"] = sig > 0.95 best["significance"] = round(sig * 100, 1) control["is_winner"] = False control["significance"] = round((1 - sig) * 100, 1) return results @staticmethod def _chi_squared_significance(a_success: int, a_total: int, b_success: int, b_total: int) -> float: """Simple chi-squared test for 2x2 contingency table. Returns confidence 0-1.""" a_fail = a_total - a_success b_fail = b_total - b_success n = a_total + b_total if n == 0: return 0.0 # Expected values exp_a_s = a_total * (a_success + b_success) / n exp_a_f = a_total * (a_fail + b_fail) / n exp_b_s = b_total * (a_success + b_success) / n exp_b_f = b_total * (a_fail + b_fail) / n chi2 = 0.0 for obs, exp in [(a_success, exp_a_s), (a_fail, exp_a_f), (b_success, exp_b_s), (b_fail, exp_b_f)]: if exp > 0: chi2 += (obs - exp) ** 2 / exp # Approximate p-value for 1 df using Wilson-Hilferty if chi2 < 0.001: return 0.0 if chi2 > 10.83: return 0.999 # Lookup table for common thresholds (1 df) thresholds = [(2.706, 0.90), (3.841, 0.95), (5.024, 0.975), (6.635, 0.99), (10.83, 0.999)] confidence = 0.0 for threshold, conf in thresholds: if chi2 >= threshold: confidence = conf return confidence