Files
tenant-registry/internal/store/postgres.go
T
sharang ffab866c87
ci / shared (push) Successful in 6s
ci / test (push) Successful in 1m15s
ci / image (push) Has been skipped
feat(api): M4.2 — REST surface + pgx Postgres store + OpenAPI 3.1
Full M4.2 deliverable: 16 endpoints (tenants CRUD + lifecycle, catalog, entitlements, API keys with argon2 hashing, audit append + filter), Store interface with pgx-backed Postgres + in-memory parallel implementations exercised by the same eachStore harness, openapi.yaml at 3.1 with kin-openapi contract test. M4.3 adds auth.

Refs: M4.2
2026-05-19 10:51:59 +00:00

479 lines
15 KiB
Go

package store
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
)
// Postgres is the pgxpool-backed implementation of Store. It relies on the
// database schema produced by the migrations/ package (M4.1 forward).
type Postgres struct {
	pool *pgxpool.Pool // shared connection pool; shut down by Close
}
// NewPostgres parses dsn, applies this store's fixed pool limits (at most 20
// connections, a minimum of 2, each connection reused for at most one hour),
// opens the pool and pings the server once. If the ping fails the pool is
// closed before returning, so nothing leaks on the error path. On success the
// caller owns the returned store and must call Close.
//
// NOTE(review): the limits are assigned after ParseConfig, so any
// pool_max_conns / pool_min_conns / pool_max_conn_lifetime given in the DSN
// are silently overridden — confirm that is intended.
func NewPostgres(ctx context.Context, dsn string) (*Postgres, error) {
	poolCfg, parseErr := pgxpool.ParseConfig(dsn)
	if parseErr != nil {
		return nil, fmt.Errorf("parse dsn: %w", parseErr)
	}
	poolCfg.MaxConns, poolCfg.MinConns = 20, 2
	poolCfg.MaxConnLifetime = time.Hour

	pool, openErr := pgxpool.NewWithConfig(ctx, poolCfg)
	if openErr != nil {
		return nil, fmt.Errorf("create pool: %w", openErr)
	}

	// Surface an unreachable server here rather than on the first query.
	if pingErr := pool.Ping(ctx); pingErr != nil {
		pool.Close()
		return nil, fmt.Errorf("ping: %w", pingErr)
	}
	return &Postgres{pool: pool}, nil
}
// Close shuts down the underlying connection pool.
func (p *Postgres) Close() {
	p.pool.Close()
}

// Ping checks database reachability through the pool.
func (p *Postgres) Ping(ctx context.Context) error {
	return p.pool.Ping(ctx)
}
// isUniqueViolation reports whether err, or any error it wraps, is a Postgres
// unique_violation (SQLSTATE 23505). Callers use it to translate duplicate
// keys into ErrConflict.
func isUniqueViolation(err error) bool {
	var pgErr *pgconn.PgError
	if !errors.As(err, &pgErr) {
		return false
	}
	return pgErr.Code == pgerrcode.UniqueViolation
}
// isCheckViolation reports whether err, or any error it wraps, is a Postgres
// check_violation (SQLSTATE 23514). Tenant writes map it to ErrInvalidInput.
// Which columns carry CHECK constraints is defined by the migrations; the
// slug format and plan values are presumably among them — confirm there.
func isCheckViolation(err error) bool {
	var pgErr *pgconn.PgError
	if !errors.As(err, &pgErr) {
		return false
	}
	return pgErr.Code == pgerrcode.CheckViolation
}
// ─── tenants ──────────────────────────────────────────────────────────────
// tenantSelect is the shared projection for loading a Tenant. Its column order
// must match the Scan order in scanTenant. Nullable text columns are COALESCEd
// to '' so they scan into plain strings. The literal ends with a space so
// callers can append a WHERE clause directly (tenantSelect+`WHERE ...`).
const tenantSelect = `
SELECT id::text, slug, name, status::text, kind::text, plan,
COALESCE(erp_customer_id,''), COALESCE(stripe_cust_id,''),
trial_ends_at, contract_start, contract_end, COALESCE(sales_owner,''),
created_at, updated_at
FROM tenants `
// scanTenant reads one row shaped like tenantSelect (same column order) into
// a Tenant. pgx.ErrNoRows becomes ErrNotFound; any other scan error is
// returned unchanged. The optional timestamps are scanned through pointers,
// so a SQL NULL leaves the corresponding field nil.
func scanTenant(row pgx.Row) (*Tenant, error) {
	var out Tenant
	var trialEndsAt, contractStart, contractEnd *time.Time
	if err := row.Scan(
		&out.ID, &out.Slug, &out.Name, &out.Status, &out.Kind, &out.Plan,
		&out.ErpCustomerID, &out.StripeCustID,
		&trialEndsAt, &contractStart, &contractEnd, &out.SalesOwner,
		&out.CreatedAt, &out.UpdatedAt,
	); err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return nil, ErrNotFound
		}
		return nil, err
	}
	out.TrialEndsAt, out.ContractStart, out.ContractEnd = trialEndsAt, contractStart, contractEnd
	return &out, nil
}
// CreateTenant inserts a tenant and returns the row as stored.
//
// An empty Kind defaults to "customer" and an empty Plan to "starter". An
// empty SalesOwner is stored as NULL (and read back as "").
//
// Errors:
//   - ErrConflict when a unique constraint rejects the row.
//   - ErrInvalidInput when a CHECK constraint fails, or when Kind is not a
//     valid tenant_kind value. The $3::tenant_kind enum cast rejects unknown
//     labels with invalid_text_representation (22P02) before any CHECK runs,
//     so that code must be mapped too.
//   - Anything else is wrapped as "create tenant: ...".
func (p *Postgres) CreateTenant(ctx context.Context, in TenantCreate) (*Tenant, error) {
	kind := firstNonEmpty(in.Kind, "customer")
	plan := firstNonEmpty(in.Plan, "starter")
	row := p.pool.QueryRow(ctx,
		`INSERT INTO tenants (slug, name, kind, plan, sales_owner)
VALUES ($1, $2, $3::tenant_kind, $4, NULLIF($5, ''))
RETURNING id::text, slug, name, status::text, kind::text, plan,
COALESCE(erp_customer_id,''), COALESCE(stripe_cust_id,''),
trial_ends_at, contract_start, contract_end, COALESCE(sales_owner,''),
created_at, updated_at`,
		in.Slug, in.Name, kind, plan, in.SalesOwner,
	)
	t, err := scanTenant(row)
	if err == nil {
		return t, nil
	}
	var pgErr *pgconn.PgError
	switch {
	case isUniqueViolation(err):
		return nil, ErrConflict
	case isCheckViolation(err),
		errors.As(err, &pgErr) && pgErr.Code == pgerrcode.InvalidTextRepresentation:
		return nil, ErrInvalidInput
	default:
		return nil, fmt.Errorf("create tenant: %w", err)
	}
}
// GetTenant loads a tenant by its UUID primary key.
//
// It returns ErrNotFound when no row matches. That includes ids that are not
// syntactically valid UUIDs: the $1::uuid cast would otherwise fail with
// invalid_text_representation (22P02) and surface as an opaque database
// error, although such an id can never identify a tenant. Callers such as
// ListTenantProducts and ListAPIKeys rely on this to report unknown tenants
// consistently.
func (p *Postgres) GetTenant(ctx context.Context, id string) (*Tenant, error) {
	t, err := scanTenant(p.pool.QueryRow(ctx, tenantSelect+`WHERE id = $1::uuid`, id))
	if err != nil {
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) && pgErr.Code == pgerrcode.InvalidTextRepresentation {
			return nil, ErrNotFound
		}
		return nil, err
	}
	return t, nil
}
// GetTenantBySlug loads a tenant by its slug. It returns ErrNotFound when no
// row matches.
func (p *Postgres) GetTenantBySlug(ctx context.Context, slug string) (*Tenant, error) {
	const bySlug = tenantSelect + `WHERE slug = $1`
	row := p.pool.QueryRow(ctx, bySlug, slug)
	return scanTenant(row)
}
// UpdateTenant applies a partial update to the tenant with the given id and
// returns the updated row.
//
// Each field of in is optional. A nil pointer is sent as SQL NULL, and
// COALESCE then keeps the column's current value. As a consequence, a
// column cannot be cleared back to NULL through this method. updated_at is
// not set here; presumably a trigger maintains it — confirm in migrations.
//
// Errors:
//   - ErrNotFound when no tenant has that id.
//   - ErrInvalidInput when a CHECK constraint fails, or when Postgres reports
//     invalid_text_representation (22P02). That code comes from the
//     $2::tenant_status enum cast on an unknown status label. It also comes
//     from a malformed (non-UUID) id, which the two cannot be told apart
//     without parsing the server message.
//   - Anything else is wrapped as "update tenant: ...".
func (p *Postgres) UpdateTenant(ctx context.Context, id string, in TenantUpdate) (*Tenant, error) {
	row := p.pool.QueryRow(ctx, `
UPDATE tenants SET
status = COALESCE($2::tenant_status, status),
plan = COALESCE($3, plan),
erp_customer_id = COALESCE($4, erp_customer_id),
stripe_cust_id = COALESCE($5, stripe_cust_id),
trial_ends_at = COALESCE($6, trial_ends_at),
contract_start = COALESCE($7, contract_start),
contract_end = COALESCE($8, contract_end),
sales_owner = COALESCE($9, sales_owner)
WHERE id = $1::uuid
RETURNING id::text, slug, name, status::text, kind::text, plan,
COALESCE(erp_customer_id,''), COALESCE(stripe_cust_id,''),
trial_ends_at, contract_start, contract_end, COALESCE(sales_owner,''),
created_at, updated_at`,
		id,
		nullableStr(in.Status), nullableStr(in.Plan),
		nullableStr(in.ErpCustomerID), nullableStr(in.StripeCustID),
		nullableTime(in.TrialEndsAt), nullableTime(in.ContractStart), nullableTime(in.ContractEnd),
		nullableStr(in.SalesOwner),
	)
	t, err := scanTenant(row)
	if err == nil {
		return t, nil
	}
	var pgErr *pgconn.PgError
	switch {
	case errors.Is(err, ErrNotFound):
		// No row matched the WHERE clause; keep the sentinel unwrapped.
		return nil, ErrNotFound
	case isCheckViolation(err),
		errors.As(err, &pgErr) && pgErr.Code == pgerrcode.InvalidTextRepresentation:
		return nil, ErrInvalidInput
	default:
		return nil, fmt.Errorf("update tenant: %w", err)
	}
}
// ─── entitlements ─────────────────────────────────────────────────────────
// UpsertTenantProduct creates or replaces the entitlement row for
// (tp.TenantID, tp.Product) and returns it as stored.
//
// Enabled, Config and ExpiresAt overwrite any existing values. A nil Config
// is stored as the empty JSON object {}. updated_at is not set by this
// statement; presumably a trigger maintains it — confirm in migrations.
//
// Errors:
//   - ErrNotFound on a foreign-key violation, presumably an unknown
//     tenant_id. If product is also an FK to the catalog, an unknown product
//     lands here too — confirm.
//   - ErrNotFound when TenantID is not a syntactically valid UUID. The
//     $1::uuid cast rejects it with invalid_text_representation (22P02), and
//     such an id can never name an existing tenant.
//   - Config marshal failures and any other error are wrapped with context.
//
// If the stored config cannot be decoded back, Config is set to an empty map
// (best-effort) instead of failing the call.
func (p *Postgres) UpsertTenantProduct(ctx context.Context, tp TenantProduct) (*TenantProduct, error) {
	cfg, err := json.Marshal(tp.Config)
	if err != nil {
		return nil, fmt.Errorf("marshal config: %w", err)
	}
	// A nil map marshals to "null"; keep the column a JSON object instead.
	if cfg == nil || string(cfg) == "null" {
		cfg = []byte("{}")
	}
	row := p.pool.QueryRow(ctx, `
INSERT INTO tenant_products (tenant_id, product, enabled, config, expires_at)
VALUES ($1::uuid, $2, $3, $4::jsonb, $5)
ON CONFLICT (tenant_id, product) DO UPDATE SET
enabled = EXCLUDED.enabled,
config = EXCLUDED.config,
expires_at = EXCLUDED.expires_at
RETURNING tenant_id::text, product, enabled, config, expires_at, created_at, updated_at`,
		tp.TenantID, tp.Product, tp.Enabled, cfg, tp.ExpiresAt,
	)
	var out TenantProduct
	var rawCfg []byte
	var expires *time.Time
	err = row.Scan(&out.TenantID, &out.Product, &out.Enabled, &rawCfg, &expires, &out.CreatedAt, &out.UpdatedAt)
	if err != nil {
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) {
			switch pgErr.Code {
			case pgerrcode.ForeignKeyViolation, pgerrcode.InvalidTextRepresentation:
				return nil, ErrNotFound
			}
		}
		return nil, fmt.Errorf("upsert tenant product: %w", err)
	}
	out.ExpiresAt = expires
	if err := json.Unmarshal(rawCfg, &out.Config); err != nil {
		out.Config = map[string]interface{}{}
	}
	return &out, nil
}
// ListTenantProducts returns every entitlement row for tenantID, ordered by
// product name.
//
// The tenant is looked up first so that an unknown tenant yields the same
// error GetTenant reports, rather than an empty list. The result is never
// nil (an empty slice when the tenant has no products). A config value that
// fails to decode is replaced with an empty map. If iteration fails midway,
// the rows collected so far are returned together with the error.
func (p *Postgres) ListTenantProducts(ctx context.Context, tenantID string) ([]TenantProduct, error) {
	if _, err := p.GetTenant(ctx, tenantID); err != nil {
		return nil, err
	}
	rows, err := p.pool.Query(ctx, `
SELECT tenant_id::text, product, enabled, config, expires_at, created_at, updated_at
FROM tenant_products WHERE tenant_id = $1::uuid ORDER BY product`, tenantID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	products := []TenantProduct{}
	for rows.Next() {
		var item TenantProduct
		var config []byte
		var expiry *time.Time
		scanErr := rows.Scan(&item.TenantID, &item.Product, &item.Enabled, &config, &expiry, &item.CreatedAt, &item.UpdatedAt)
		if scanErr != nil {
			return nil, scanErr
		}
		item.ExpiresAt = expiry
		if decodeErr := json.Unmarshal(config, &item.Config); decodeErr != nil {
			item.Config = map[string]interface{}{}
		}
		products = append(products, item)
	}
	return products, rows.Err()
}
// ─── api keys ─────────────────────────────────────────────────────────────
// CreateAPIKey stores a new API key record (only its hash and prefix, never
// the plaintext) and returns it without the hash.
//
// Empty Product and CreatedBy are sent as SQL NULL (an untyped nil argument)
// and read back as "". A nil Scopes is sent as an empty array. The schema's
// NOT NULL DEFAULT only applies when the column is omitted, not when an
// explicit NULL is sent.
//
// Errors:
//   - ErrConflict when a unique constraint rejects the row.
//   - ErrNotFound on a foreign-key violation (presumably an unknown tenant).
//   - ErrNotFound when TenantID is not a valid UUID. The $1::uuid cast fails
//     with invalid_text_representation (22P02), and such an id cannot name a
//     tenant.
//   - Anything else is wrapped as "create api key: ...".
func (p *Postgres) CreateAPIKey(ctx context.Context, in APIKeyCreate) (*APIKey, error) {
	var product, createdBy any
	if in.Product != "" {
		product = in.Product
	}
	if in.CreatedBy != "" {
		createdBy = in.CreatedBy
	}
	scopes := in.Scopes
	if scopes == nil {
		scopes = []string{}
	}
	row := p.pool.QueryRow(ctx, `
INSERT INTO api_keys (tenant_id, product, name, scopes, hash, prefix, created_by)
VALUES ($1::uuid, $2, $3, $4, $5, $6, $7)
RETURNING id::text, tenant_id::text, COALESCE(product,''), name, scopes, prefix,
COALESCE(created_by,''), last_used_at, revoked_at, created_at`,
		in.TenantID, product, in.Name, scopes, in.Hash, in.Prefix, createdBy,
	)
	k, err := scanAPIKey(row)
	if err == nil {
		return k, nil
	}
	var pgErr *pgconn.PgError
	switch {
	case isUniqueViolation(err):
		return nil, ErrConflict
	case errors.As(err, &pgErr) &&
		(pgErr.Code == pgerrcode.ForeignKeyViolation || pgErr.Code == pgerrcode.InvalidTextRepresentation):
		return nil, ErrNotFound
	default:
		return nil, fmt.Errorf("create api key: %w", err)
	}
}
// scanAPIKey reads one api_keys row, in the column order used by CreateAPIKey
// and ListAPIKeys, into an APIKey. pgx.ErrNoRows becomes ErrNotFound; other
// scan errors are returned unchanged. A NULL last_used_at or revoked_at
// leaves the corresponding field nil.
func scanAPIKey(row pgx.Row) (*APIKey, error) {
	var key APIKey
	var lastUsedAt, revokedAt *time.Time
	if err := row.Scan(
		&key.ID, &key.TenantID, &key.Product, &key.Name, &key.Scopes, &key.Prefix,
		&key.CreatedBy, &lastUsedAt, &revokedAt, &key.CreatedAt,
	); err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return nil, ErrNotFound
		}
		return nil, err
	}
	key.LastUsedAt, key.RevokedAt = lastUsedAt, revokedAt
	return &key, nil
}
// FindAPIKeyByPrefix looks up a non-revoked API key by its prefix. It returns
// the key together with its stored hash so the caller can verify a presented
// secret. It returns ErrNotFound when no active key has that prefix.
//
// NOTE(review): QueryRow keeps only the first row. If prefixes are not unique
// among active keys, which key is returned is unspecified — confirm a unique
// constraint exists.
func (p *Postgres) FindAPIKeyByPrefix(ctx context.Context, prefix string) (*APIKey, string, error) {
	const query = `
SELECT id::text, tenant_id::text, COALESCE(product,''), name, scopes, prefix,
COALESCE(created_by,''), last_used_at, revoked_at, created_at, hash
FROM api_keys WHERE prefix = $1 AND revoked_at IS NULL`

	var key APIKey
	var lastUsedAt, revokedAt *time.Time
	var storedHash string
	err := p.pool.QueryRow(ctx, query, prefix).Scan(
		&key.ID, &key.TenantID, &key.Product, &key.Name, &key.Scopes, &key.Prefix,
		&key.CreatedBy, &lastUsedAt, &revokedAt, &key.CreatedAt, &storedHash,
	)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return nil, "", ErrNotFound
	case err != nil:
		return nil, "", err
	}
	key.LastUsedAt, key.RevokedAt = lastUsedAt, revokedAt
	return &key, storedHash, nil
}
// TouchAPIKeyUsed sets last_used_at = NOW() on the API key with the given id.
// The statement does not check revoked_at, so revoked keys are stamped too.
//
// It returns ErrNotFound when no key has that id. That includes ids that are
// not valid UUIDs, which the $1::uuid cast rejects with
// invalid_text_representation (22P02). Other errors are wrapped as
// "touch api key: ...".
func (p *Postgres) TouchAPIKeyUsed(ctx context.Context, id string) error {
	tag, err := p.pool.Exec(ctx, `UPDATE api_keys SET last_used_at = NOW() WHERE id = $1::uuid`, id)
	if err != nil {
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) && pgErr.Code == pgerrcode.InvalidTextRepresentation {
			return ErrNotFound
		}
		return fmt.Errorf("touch api key: %w", err)
	}
	if tag.RowsAffected() == 0 {
		return ErrNotFound
	}
	return nil
}
// RevokeAPIKey sets revoked_at = NOW() on the API key with the given id.
//
// It returns ErrNotFound in three cases:
//   - no key has that id;
//   - the key is already revoked (the WHERE clause requires revoked_at IS
//     NULL);
//   - id is not a valid UUID, which the $1::uuid cast rejects with
//     invalid_text_representation (22P02).
//
// Other errors are wrapped as "revoke api key: ...".
func (p *Postgres) RevokeAPIKey(ctx context.Context, id string) error {
	tag, err := p.pool.Exec(ctx, `UPDATE api_keys SET revoked_at = NOW() WHERE id = $1::uuid AND revoked_at IS NULL`, id)
	if err != nil {
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) && pgErr.Code == pgerrcode.InvalidTextRepresentation {
			return ErrNotFound
		}
		return fmt.Errorf("revoke api key: %w", err)
	}
	if tag.RowsAffected() == 0 {
		return ErrNotFound
	}
	return nil
}
// ListAPIKeys returns all API keys of tenantID, including revoked ones,
// newest first. Hashes are not selected.
//
// The tenant is looked up first so an unknown tenant yields GetTenant's error
// instead of an empty list. The result is never nil. If iteration fails
// midway, the keys collected so far are returned together with the error.
func (p *Postgres) ListAPIKeys(ctx context.Context, tenantID string) ([]APIKey, error) {
	if _, err := p.GetTenant(ctx, tenantID); err != nil {
		return nil, err
	}
	rows, err := p.pool.Query(ctx, `
SELECT id::text, tenant_id::text, COALESCE(product,''), name, scopes, prefix,
COALESCE(created_by,''), last_used_at, revoked_at, created_at
FROM api_keys WHERE tenant_id = $1::uuid ORDER BY created_at DESC`, tenantID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	keys := []APIKey{}
	for rows.Next() {
		// pgx.Rows satisfies pgx.Row, so the single-row scanner is reused here.
		key, scanErr := scanAPIKey(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		keys = append(keys, *key)
	}
	return keys, rows.Err()
}
// ─── audit ────────────────────────────────────────────────────────────────
// AppendAudit inserts one audit_log row built from ev and returns the stored
// event, including the database-assigned ID and created_at. The ID and
// CreatedAt fields of ev are not part of the INSERT and are ignored.
//
// Empty string fields are stored as NULL (via NULLIF) and read back as "".
// A nil Metadata is stored as the empty JSON object {}. SourceIP is read back
// through host(), which drops any netmask.
//
// Errors:
//   - Metadata that encoding/json cannot marshal (e.g. NaN floats, channel
//     values) is reported. Previously it was silently replaced by {}, losing
//     the audit payload.
//   - ErrInvalidInput when TenantID/ProjectID is not a UUID or SourceIP is
//     not an inet literal. The ::uuid / ::inet casts reject these with
//     invalid_text_representation (22P02).
//   - Anything else is wrapped as "append audit: ...".
func (p *Postgres) AppendAudit(ctx context.Context, ev AuditEvent) (*AuditEvent, error) {
	meta, err := json.Marshal(ev.Metadata)
	if err != nil {
		return nil, fmt.Errorf("marshal metadata: %w", err)
	}
	// A nil map marshals to "null"; keep the column a JSON object instead.
	if string(meta) == "null" {
		meta = []byte("{}")
	}
	row := p.pool.QueryRow(ctx, `
INSERT INTO audit_log
(tenant_id, project_id, actor_id, actor_name, actor_type, action,
target_id, target_type, target_name, product, metadata, source_ip, user_agent)
VALUES
(NULLIF($1,'')::uuid, NULLIF($2,'')::uuid, NULLIF($3,''), NULLIF($4,''), NULLIF($5,''), $6,
NULLIF($7,''), NULLIF($8,''), NULLIF($9,''), NULLIF($10,''), $11::jsonb,
NULLIF($12,'')::inet, NULLIF($13,''))
RETURNING id, COALESCE(tenant_id::text,''), COALESCE(project_id::text,''),
COALESCE(actor_id,''), COALESCE(actor_name,''), COALESCE(actor_type,''),
action, COALESCE(target_id,''), COALESCE(target_type,''), COALESCE(target_name,''),
COALESCE(product,''), metadata, COALESCE(host(source_ip),''), COALESCE(user_agent,''),
created_at`,
		ev.TenantID, ev.ProjectID, ev.ActorID, ev.ActorName, ev.ActorType, ev.Action,
		ev.TargetID, ev.TargetType, ev.TargetName, ev.Product, meta, ev.SourceIP, ev.UserAgent,
	)
	out, err := scanAudit(row)
	if err != nil {
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) && pgErr.Code == pgerrcode.InvalidTextRepresentation {
			return nil, ErrInvalidInput
		}
		return nil, fmt.Errorf("append audit: %w", err)
	}
	return out, nil
}
// scanAudit reads one audit_log row, in the column order used by AppendAudit
// and ListAudit, into an AuditEvent. pgx.ErrNoRows becomes ErrNotFound; other
// scan errors are returned unchanged.
//
// Metadata handling:
//   - empty metadata leaves Metadata untouched (nil);
//   - metadata that fails to decode is replaced with an empty map.
func scanAudit(row pgx.Row) (*AuditEvent, error) {
	var (
		ev   AuditEvent
		meta []byte
	)
	if err := row.Scan(
		&ev.ID, &ev.TenantID, &ev.ProjectID,
		&ev.ActorID, &ev.ActorName, &ev.ActorType,
		&ev.Action, &ev.TargetID, &ev.TargetType, &ev.TargetName,
		&ev.Product, &meta, &ev.SourceIP, &ev.UserAgent,
		&ev.CreatedAt,
	); err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return nil, ErrNotFound
		}
		return nil, err
	}
	if len(meta) == 0 {
		return &ev, nil
	}
	if decodeErr := json.Unmarshal(meta, &ev.Metadata); decodeErr != nil {
		ev.Metadata = map[string]interface{}{}
	}
	return &ev, nil
}
// ListAudit returns audit events matching f, in ascending id order, using
// keyset pagination. Only rows with id > f.Cursor are considered. Each
// non-empty filter field adds an AND condition.
//
// A Limit outside 1..500 falls back to 100. The second return value is the
// cursor for the next page: the last returned id when the page came back
// full, otherwise 0. All filter values are bound as query parameters; only
// fixed column names are spliced into the SQL text.
func (p *Postgres) ListAudit(ctx context.Context, f AuditFilter) ([]AuditEvent, int64, error) {
	limit := f.Limit
	if !(limit > 0 && limit <= 500) {
		limit = 100
	}

	conds := []string{"id > $1"}
	params := []any{f.Cursor}
	// bind appends v to params and returns its positional placeholder ("$N").
	bind := func(v any) string {
		params = append(params, v)
		return fmt.Sprintf("$%d", len(params))
	}

	if f.TenantID != "" {
		conds = append(conds, "tenant_id = "+bind(f.TenantID)+"::uuid")
	}
	if f.Product != "" {
		conds = append(conds, "product = "+bind(f.Product))
	}
	if f.ActorID != "" {
		conds = append(conds, "actor_id = "+bind(f.ActorID))
	}
	if f.Action != "" {
		conds = append(conds, "action = "+bind(f.Action))
	}
	if f.Since != nil {
		conds = append(conds, "created_at >= "+bind(*f.Since))
	}
	if f.Until != nil {
		conds = append(conds, "created_at <= "+bind(*f.Until))
	}
	limitRef := bind(limit)

	query := `
SELECT id, COALESCE(tenant_id::text,''), COALESCE(project_id::text,''),
COALESCE(actor_id,''), COALESCE(actor_name,''), COALESCE(actor_type,''),
action, COALESCE(target_id,''), COALESCE(target_type,''), COALESCE(target_name,''),
COALESCE(product,''), metadata, COALESCE(host(source_ip),''), COALESCE(user_agent,''),
created_at
FROM audit_log
WHERE ` + strings.Join(conds, " AND ") + `
ORDER BY id ASC
LIMIT ` + limitRef

	rows, err := p.pool.Query(ctx, query, params...)
	if err != nil {
		return nil, 0, err
	}
	defer rows.Close()

	events := []AuditEvent{}
	for rows.Next() {
		ev, scanErr := scanAudit(rows)
		if scanErr != nil {
			return nil, 0, scanErr
		}
		events = append(events, *ev)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, 0, iterErr
	}

	var next int64
	if n := len(events); n > 0 && n == limit {
		next = events[n-1].ID
	}
	return events, next, nil
}
// nullableStr turns an optional string into a query argument. A nil pointer
// becomes an untyped nil (sent as SQL NULL); otherwise the pointed-to string
// is returned, including "".
func nullableStr(p *string) any {
	if p != nil {
		return *p
	}
	return nil
}
// nullableTime turns an optional timestamp into a query argument. A nil
// pointer becomes an untyped nil (sent as SQL NULL); otherwise the
// dereferenced time.Time value is returned.
func nullableTime(p *time.Time) any {
	if p != nil {
		return *p
	}
	return nil
}