package store import ( "context" "encoding/json" "errors" "fmt" "strings" "time" "github.com/jackc/pgerrcode" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" ) // Postgres — pgxpool-backed Store. The schema this expects is produced by // the migrations/ package (M4.1 forward). type Postgres struct { pool *pgxpool.Pool } // NewPostgres opens a pool and pings. Caller must Close(). func NewPostgres(ctx context.Context, dsn string) (*Postgres, error) { cfg, err := pgxpool.ParseConfig(dsn) if err != nil { return nil, fmt.Errorf("parse dsn: %w", err) } cfg.MaxConns = 20 cfg.MinConns = 2 cfg.MaxConnLifetime = time.Hour pool, err := pgxpool.NewWithConfig(ctx, cfg) if err != nil { return nil, fmt.Errorf("create pool: %w", err) } if err := pool.Ping(ctx); err != nil { pool.Close() return nil, fmt.Errorf("ping: %w", err) } return &Postgres{pool: pool}, nil } func (p *Postgres) Close() { p.pool.Close() } func (p *Postgres) Ping(ctx context.Context) error { return p.pool.Ping(ctx) } // isUniqueViolation detects Postgres unique_violation (23505) so callers // can return ErrConflict cleanly. func isUniqueViolation(err error) bool { var pgErr *pgconn.PgError return errors.As(err, &pgErr) && pgErr.Code == pgerrcode.UniqueViolation } // isCheckViolation detects check_constraint_violation (23514) — used by the // slug regex check + plan/status enum guards. func isCheckViolation(err error) bool { var pgErr *pgconn.PgError return errors.As(err, &pgErr) && pgErr.Code == pgerrcode.CheckViolation } // ─── tenants ────────────────────────────────────────────────────────────── const tenantSelect = ` SELECT id::text, slug, name, status::text, kind::text, plan, COALESCE(erp_customer_id,''), COALESCE(stripe_cust_id,''), trial_ends_at, contract_start, contract_end, COALESCE(sales_owner,''), created_at, updated_at FROM tenants ` func scanTenant(row pgx.Row) (*Tenant, error) { var t Tenant var trialEnds, cStart, cEnd *time.Time err := row.Scan( &t.ID, &t.Slug, &t.Name, &t.Status, &t.Kind, &t.Plan, &t.ErpCustomerID, &t.StripeCustID, &trialEnds, &cStart, &cEnd, &t.SalesOwner, &t.CreatedAt, &t.UpdatedAt, ) if errors.Is(err, pgx.ErrNoRows) { return nil, ErrNotFound } if err != nil { return nil, err } t.TrialEndsAt = trialEnds t.ContractStart = cStart t.ContractEnd = cEnd return &t, nil } func (p *Postgres) CreateTenant(ctx context.Context, in TenantCreate) (*Tenant, error) { kind := firstNonEmpty(in.Kind, "customer") plan := firstNonEmpty(in.Plan, "starter") row := p.pool.QueryRow(ctx, `INSERT INTO tenants (slug, name, kind, plan, sales_owner) VALUES ($1, $2, $3::tenant_kind, $4, NULLIF($5, '')) RETURNING id::text, slug, name, status::text, kind::text, plan, COALESCE(erp_customer_id,''), COALESCE(stripe_cust_id,''), trial_ends_at, contract_start, contract_end, COALESCE(sales_owner,''), created_at, updated_at`, in.Slug, in.Name, kind, plan, in.SalesOwner, ) t, err := scanTenant(row) if err != nil { if isUniqueViolation(err) { return nil, ErrConflict } if isCheckViolation(err) { return nil, ErrInvalidInput } return nil, fmt.Errorf("create tenant: %w", err) } return t, nil } func (p *Postgres) GetTenant(ctx context.Context, id string) (*Tenant, error) { return scanTenant(p.pool.QueryRow(ctx, tenantSelect+`WHERE id = $1::uuid`, id)) } func (p *Postgres) GetTenantBySlug(ctx context.Context, slug string) (*Tenant, error) { return scanTenant(p.pool.QueryRow(ctx, tenantSelect+`WHERE slug = $1`, slug)) } func (p *Postgres) UpdateTenant(ctx context.Context, id string, in TenantUpdate) (*Tenant, error) { // Build a partial UPDATE via COALESCE on each nullable input. Reads each // field once; trivially type-safe. row := p.pool.QueryRow(ctx, ` UPDATE tenants SET status = COALESCE($2::tenant_status, status), plan = COALESCE($3, plan), erp_customer_id = COALESCE($4, erp_customer_id), stripe_cust_id = COALESCE($5, stripe_cust_id), trial_ends_at = COALESCE($6, trial_ends_at), contract_start = COALESCE($7, contract_start), contract_end = COALESCE($8, contract_end), sales_owner = COALESCE($9, sales_owner) WHERE id = $1::uuid RETURNING id::text, slug, name, status::text, kind::text, plan, COALESCE(erp_customer_id,''), COALESCE(stripe_cust_id,''), trial_ends_at, contract_start, contract_end, COALESCE(sales_owner,''), created_at, updated_at`, id, nullableStr(in.Status), nullableStr(in.Plan), nullableStr(in.ErpCustomerID), nullableStr(in.StripeCustID), nullableTime(in.TrialEndsAt), nullableTime(in.ContractStart), nullableTime(in.ContractEnd), nullableStr(in.SalesOwner), ) t, err := scanTenant(row) if err != nil { if isCheckViolation(err) { return nil, ErrInvalidInput } return nil, err } return t, nil } // ─── entitlements ───────────────────────────────────────────────────────── func (p *Postgres) UpsertTenantProduct(ctx context.Context, tp TenantProduct) (*TenantProduct, error) { cfg, err := json.Marshal(tp.Config) if err != nil { return nil, fmt.Errorf("marshal config: %w", err) } if cfg == nil || string(cfg) == "null" { cfg = []byte("{}") } row := p.pool.QueryRow(ctx, ` INSERT INTO tenant_products (tenant_id, product, enabled, config, expires_at) VALUES ($1::uuid, $2, $3, $4::jsonb, $5) ON CONFLICT (tenant_id, product) DO UPDATE SET enabled = EXCLUDED.enabled, config = EXCLUDED.config, expires_at = EXCLUDED.expires_at RETURNING tenant_id::text, product, enabled, config, expires_at, created_at, updated_at`, tp.TenantID, tp.Product, tp.Enabled, cfg, tp.ExpiresAt, ) var out TenantProduct var rawCfg []byte var expires *time.Time err = row.Scan(&out.TenantID, &out.Product, &out.Enabled, &rawCfg, &expires, &out.CreatedAt, &out.UpdatedAt) if err != nil { // FK violation on tenant_id → not found var pgErr *pgconn.PgError if errors.As(err, &pgErr) && pgErr.Code == pgerrcode.ForeignKeyViolation { return nil, ErrNotFound } return nil, err } out.ExpiresAt = expires if err := json.Unmarshal(rawCfg, &out.Config); err != nil { out.Config = map[string]interface{}{} } return &out, nil } func (p *Postgres) ListTenantProducts(ctx context.Context, tenantID string) ([]TenantProduct, error) { // First confirm tenant exists so we can return ErrNotFound consistent with Memory. if _, err := p.GetTenant(ctx, tenantID); err != nil { return nil, err } rows, err := p.pool.Query(ctx, ` SELECT tenant_id::text, product, enabled, config, expires_at, created_at, updated_at FROM tenant_products WHERE tenant_id = $1::uuid ORDER BY product`, tenantID) if err != nil { return nil, err } defer rows.Close() out := []TenantProduct{} for rows.Next() { var tp TenantProduct var rawCfg []byte var expires *time.Time if err := rows.Scan(&tp.TenantID, &tp.Product, &tp.Enabled, &rawCfg, &expires, &tp.CreatedAt, &tp.UpdatedAt); err != nil { return nil, err } tp.ExpiresAt = expires if err := json.Unmarshal(rawCfg, &tp.Config); err != nil { tp.Config = map[string]interface{}{} } out = append(out, tp) } return out, rows.Err() } // ─── api keys ───────────────────────────────────────────────────────────── func (p *Postgres) CreateAPIKey(ctx context.Context, in APIKeyCreate) (*APIKey, error) { var product any if in.Product != "" { product = in.Product } var createdBy any if in.CreatedBy != "" { createdBy = in.CreatedBy } // Coerce nil to an empty slice — the schema's NOT NULL DEFAULT only // fires when the column is omitted, not when an explicit NULL is sent. scopes := in.Scopes if scopes == nil { scopes = []string{} } row := p.pool.QueryRow(ctx, ` INSERT INTO api_keys (tenant_id, product, name, scopes, hash, prefix, created_by) VALUES ($1::uuid, $2, $3, $4, $5, $6, $7) RETURNING id::text, tenant_id::text, COALESCE(product,''), name, scopes, prefix, COALESCE(created_by,''), last_used_at, revoked_at, created_at`, in.TenantID, product, in.Name, scopes, in.Hash, in.Prefix, createdBy, ) k, err := scanAPIKey(row) if err != nil { if isUniqueViolation(err) { return nil, ErrConflict } var pgErr *pgconn.PgError if errors.As(err, &pgErr) && pgErr.Code == pgerrcode.ForeignKeyViolation { return nil, ErrNotFound } return nil, err } return k, nil } func scanAPIKey(row pgx.Row) (*APIKey, error) { var k APIKey var lastUsed, revoked *time.Time err := row.Scan(&k.ID, &k.TenantID, &k.Product, &k.Name, &k.Scopes, &k.Prefix, &k.CreatedBy, &lastUsed, &revoked, &k.CreatedAt) if errors.Is(err, pgx.ErrNoRows) { return nil, ErrNotFound } if err != nil { return nil, err } k.LastUsedAt = lastUsed k.RevokedAt = revoked return &k, nil } func (p *Postgres) FindAPIKeyByPrefix(ctx context.Context, prefix string) (*APIKey, string, error) { row := p.pool.QueryRow(ctx, ` SELECT id::text, tenant_id::text, COALESCE(product,''), name, scopes, prefix, COALESCE(created_by,''), last_used_at, revoked_at, created_at, hash FROM api_keys WHERE prefix = $1 AND revoked_at IS NULL`, prefix) var k APIKey var lastUsed, revoked *time.Time var hash string err := row.Scan(&k.ID, &k.TenantID, &k.Product, &k.Name, &k.Scopes, &k.Prefix, &k.CreatedBy, &lastUsed, &revoked, &k.CreatedAt, &hash) if errors.Is(err, pgx.ErrNoRows) { return nil, "", ErrNotFound } if err != nil { return nil, "", err } k.LastUsedAt = lastUsed k.RevokedAt = revoked return &k, hash, nil } func (p *Postgres) TouchAPIKeyUsed(ctx context.Context, id string) error { tag, err := p.pool.Exec(ctx, `UPDATE api_keys SET last_used_at = NOW() WHERE id = $1::uuid`, id) if err != nil { return err } if tag.RowsAffected() == 0 { return ErrNotFound } return nil } func (p *Postgres) RevokeAPIKey(ctx context.Context, id string) error { tag, err := p.pool.Exec(ctx, `UPDATE api_keys SET revoked_at = NOW() WHERE id = $1::uuid AND revoked_at IS NULL`, id) if err != nil { return err } if tag.RowsAffected() == 0 { return ErrNotFound } return nil } func (p *Postgres) ListAPIKeys(ctx context.Context, tenantID string) ([]APIKey, error) { if _, err := p.GetTenant(ctx, tenantID); err != nil { return nil, err } rows, err := p.pool.Query(ctx, ` SELECT id::text, tenant_id::text, COALESCE(product,''), name, scopes, prefix, COALESCE(created_by,''), last_used_at, revoked_at, created_at FROM api_keys WHERE tenant_id = $1::uuid ORDER BY created_at DESC`, tenantID) if err != nil { return nil, err } defer rows.Close() out := []APIKey{} for rows.Next() { var k APIKey var lastUsed, revoked *time.Time if err := rows.Scan(&k.ID, &k.TenantID, &k.Product, &k.Name, &k.Scopes, &k.Prefix, &k.CreatedBy, &lastUsed, &revoked, &k.CreatedAt); err != nil { return nil, err } k.LastUsedAt = lastUsed k.RevokedAt = revoked out = append(out, k) } return out, rows.Err() } // ─── audit ──────────────────────────────────────────────────────────────── func (p *Postgres) AppendAudit(ctx context.Context, ev AuditEvent) (*AuditEvent, error) { meta, _ := json.Marshal(ev.Metadata) if meta == nil || string(meta) == "null" { meta = []byte("{}") } row := p.pool.QueryRow(ctx, ` INSERT INTO audit_log (tenant_id, project_id, actor_id, actor_name, actor_type, action, target_id, target_type, target_name, product, metadata, source_ip, user_agent) VALUES (NULLIF($1,'')::uuid, NULLIF($2,'')::uuid, NULLIF($3,''), NULLIF($4,''), NULLIF($5,''), $6, NULLIF($7,''), NULLIF($8,''), NULLIF($9,''), NULLIF($10,''), $11::jsonb, NULLIF($12,'')::inet, NULLIF($13,'')) RETURNING id, COALESCE(tenant_id::text,''), COALESCE(project_id::text,''), COALESCE(actor_id,''), COALESCE(actor_name,''), COALESCE(actor_type,''), action, COALESCE(target_id,''), COALESCE(target_type,''), COALESCE(target_name,''), COALESCE(product,''), metadata, COALESCE(host(source_ip),''), COALESCE(user_agent,''), created_at`, ev.TenantID, ev.ProjectID, ev.ActorID, ev.ActorName, ev.ActorType, ev.Action, ev.TargetID, ev.TargetType, ev.TargetName, ev.Product, meta, ev.SourceIP, ev.UserAgent, ) out, err := scanAudit(row) if err != nil { return nil, err } return out, nil } func scanAudit(row pgx.Row) (*AuditEvent, error) { var ev AuditEvent var rawMeta []byte err := row.Scan( &ev.ID, &ev.TenantID, &ev.ProjectID, &ev.ActorID, &ev.ActorName, &ev.ActorType, &ev.Action, &ev.TargetID, &ev.TargetType, &ev.TargetName, &ev.Product, &rawMeta, &ev.SourceIP, &ev.UserAgent, &ev.CreatedAt, ) if errors.Is(err, pgx.ErrNoRows) { return nil, ErrNotFound } if err != nil { return nil, err } if len(rawMeta) > 0 { if err := json.Unmarshal(rawMeta, &ev.Metadata); err != nil { ev.Metadata = map[string]interface{}{} } } return &ev, nil } func (p *Postgres) ListAudit(ctx context.Context, f AuditFilter) ([]AuditEvent, int64, error) { limit := f.Limit if limit <= 0 || limit > 500 { limit = 100 } // Build WHERE clauses dynamically — keep param indices stable. where := []string{"id > $1"} args := []any{f.Cursor} add := func(clause string, v any) { args = append(args, v) where = append(where, fmt.Sprintf(clause, len(args))) } if f.TenantID != "" { add("tenant_id = $%d::uuid", f.TenantID) } if f.Product != "" { add("product = $%d", f.Product) } if f.ActorID != "" { add("actor_id = $%d", f.ActorID) } if f.Action != "" { add("action = $%d", f.Action) } if f.Since != nil { add("created_at >= $%d", *f.Since) } if f.Until != nil { add("created_at <= $%d", *f.Until) } args = append(args, limit) sql := ` SELECT id, COALESCE(tenant_id::text,''), COALESCE(project_id::text,''), COALESCE(actor_id,''), COALESCE(actor_name,''), COALESCE(actor_type,''), action, COALESCE(target_id,''), COALESCE(target_type,''), COALESCE(target_name,''), COALESCE(product,''), metadata, COALESCE(host(source_ip),''), COALESCE(user_agent,''), created_at FROM audit_log WHERE ` + strings.Join(where, " AND ") + ` ORDER BY id ASC LIMIT $` + fmt.Sprintf("%d", len(args)) rows, err := p.pool.Query(ctx, sql, args...) if err != nil { return nil, 0, err } defer rows.Close() out := []AuditEvent{} for rows.Next() { ev, err := scanAudit(rows) if err != nil { return nil, 0, err } out = append(out, *ev) } if err := rows.Err(); err != nil { return nil, 0, err } var nextCursor int64 if len(out) == limit && len(out) > 0 { nextCursor = out[len(out)-1].ID } return out, nextCursor, nil } func nullableStr(p *string) any { if p == nil { return nil } return *p } func nullableTime(p *time.Time) any { if p == nil { return nil } return *p }