package migrations

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"testing"
	"time"

	"github.com/golang-migrate/migrate/v4"
	"github.com/golang-migrate/migrate/v4/database/postgres"
	"github.com/golang-migrate/migrate/v4/source/iofs"
	_ "github.com/jackc/pgx/v5/stdlib" // pgx stdlib driver for database/sql
	tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
)

// startPostgres spins up a fresh postgres:16-alpine container and returns its
// DSN plus a cleanup func that terminates the container. If the container
// cannot be started the test is skipped rather than failed, on the assumption
// that Docker is unreachable in this environment.
func startPostgres(t *testing.T) (string, func()) {
	t.Helper()
	ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
	defer cancel()

	pgc, err := tcpostgres.Run(ctx, "postgres:16-alpine",
		tcpostgres.WithDatabase("tenant_registry_test"),
		tcpostgres.WithUsername("test"),
		tcpostgres.WithPassword("test"),
		tcpostgres.BasicWaitStrategies(),
	)
	if err != nil {
		// Run may return a partially started container alongside the error
		// (e.g. when the wait strategy times out); terminate it so it does
		// not outlive the skipped test.
		if pgc != nil {
			_ = pgc.Terminate(context.Background())
		}
		t.Skipf("skipping: docker unreachable (%v)", err)
	}

	dsn, err := pgc.ConnectionString(ctx, "sslmode=disable")
	if err != nil {
		_ = pgc.Terminate(context.Background())
		t.Fatalf("dsn: %v", err)
	}

	cleanup := func() {
		c, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		_ = pgc.Terminate(c)
	}
	return dsn, cleanup
}

// newMigrator builds a *migrate.Migrate that applies the migrations in the
// package's FS to the database at dsn. The migrator and its *sql.DB are closed
// through t.Cleanup, so callers need no teardown of their own.
func newMigrator(t *testing.T, dsn string) *migrate.Migrate {
	t.Helper()

	src, err := iofs.New(FS, ".")
	if err != nil {
		t.Fatalf("iofs source: %v", err)
	}

	db, err := sql.Open("pgx", dsn)
	if err != nil {
		t.Fatalf("open db: %v", err)
	}
	// Registered before m's cleanup so it runs after it (t.Cleanup is LIFO),
	// and so the pool is closed even if a later constructor fails.
	t.Cleanup(func() { _ = db.Close() })

	driver, err := postgres.WithInstance(db, &postgres.Config{})
	if err != nil {
		t.Fatalf("postgres driver: %v", err)
	}

	m, err := migrate.NewWithInstance("iofs", src, "postgres", driver)
	if err != nil {
		_ = driver.Close()
		t.Fatalf("new migrator: %v", err)
	}
	t.Cleanup(func() { _, _ = m.Close() })
	return m
}

// queryCount runs a query that yields a single integer (e.g. COUNT(*)) and
// returns it, failing the test if the query or scan errors. Surfacing the
// error matters: a swallowed failure would leave the count at 0 and make
// "want 0" assertions pass vacuously.
func queryCount(t *testing.T, db *sql.DB, query string, args ...any) int {
	t.Helper()
	var n int
	if err := db.QueryRow(query, args...).Scan(&n); err != nil {
		t.Fatalf("count query %q: %v", query, err)
	}
	return n
}

// queryEnumLabels returns the labels of the Postgres enum type named typname,
// ordered by pg_enum.enumsortorder. A type that does not exist yields nil.
func queryEnumLabels(t *testing.T, db *sql.DB, typname string) []string {
	t.Helper()
	rows, err := db.Query(
		"SELECT e.enumlabel FROM pg_type t JOIN pg_enum e ON t.oid = e.enumtypid WHERE t.typname = $1 ORDER BY e.enumsortorder",
		typname,
	)
	if err != nil {
		t.Fatalf("query enum %s: %v", typname, err)
	}
	// Deferred so rows are released even when t.Fatalf below exits early.
	defer func() { _ = rows.Close() }()

	var labels []string
	for rows.Next() {
		var v string
		if err := rows.Scan(&v); err != nil {
			t.Fatalf("scan enum %s: %v", typname, err)
		}
		labels = append(labels, v)
	}
	if err := rows.Err(); err != nil {
		t.Fatalf("iterate enum %s: %v", typname, err)
	}
	return labels
}

// TestMigrate_upDownRoundTrip applies all migrations, checks that the expected
// tables and enum types exist, then migrates fully down and back up. After
// the down step none of those tables or enum types may remain.
func TestMigrate_upDownRoundTrip(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test under -short")
	}
	dsn, stop := startPostgres(t)
	defer stop()

	m := newMigrator(t, dsn)
	if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		t.Fatalf("up: %v", err)
	}

	// Schema assertions — every table the spec requires must exist.
	db, err := sql.Open("pgx", dsn)
	if err != nil {
		t.Fatal(err)
	}
	defer func() { _ = db.Close() }()

	wantTables := []string{
		"tenants",
		"tenant_projects",
		"tenant_products",
		"tenant_idp_config",
		"api_keys",
		"audit_log",
	}
	for _, table := range wantTables {
		var exists bool
		err := db.QueryRow(
			"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema='public' AND table_name=$1)",
			table,
		).Scan(&exists)
		if err != nil {
			t.Fatalf("query for table %s: %v", table, err)
		}
		if !exists {
			t.Errorf("table %s missing after migrate up", table)
		}
	}

	// Enum assertions: labels must match exactly, in declaration order.
	wantEnums := map[string][]string{
		"tenant_status":         {"demo", "trial", "active", "frozen", "archived"},
		"tenant_kind":           {"customer", "demo"},
		"idp_kind":              {"oidc", "saml"},
		"tenant_project_status": {"active", "archived"},
	}
	for enum, values := range wantEnums {
		got := queryEnumLabels(t, db, enum)
		// %q quotes every element, so the comparison is unambiguous; plain
		// fmt.Sprint would render ["a b"] and ["a" "b"] identically.
		if fmt.Sprintf("%q", got) != fmt.Sprintf("%q", values) {
			t.Errorf("enum %s = %q, want %q", enum, got, values)
		}
	}

	// Round-trip: down all, then up again — must succeed without leftover state.
	if err := m.Down(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		t.Fatalf("down: %v", err)
	}
	afterDown := queryCount(t, db,
		"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema='public' AND table_name = ANY($1)",
		wantTables,
	)
	if afterDown != 0 {
		t.Errorf("after down: %d tables still present, want 0", afterDown)
	}

	enumNames := make([]string, 0, len(wantEnums))
	for name := range wantEnums {
		enumNames = append(enumNames, name)
	}
	enumsAfterDown := queryCount(t, db,
		"SELECT COUNT(*) FROM pg_type WHERE typtype = 'e' AND typname = ANY($1)",
		enumNames,
	)
	if enumsAfterDown != 0 {
		t.Errorf("after down: %d enum types still present, want 0", enumsAfterDown)
	}

	if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		t.Fatalf("up after down: %v", err)
	}
}

// TestSeed_canInsertAndQuery is the lightweight happy-path: insert a tenant,
// give it a project + a product + an api_key + an audit record, query back.
// Catches schema-level mistakes (NOT NULL, FK direction, enum cast) that
// table-existence checks miss.
func TestSeed_canInsertAndQuery(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test under -short")
	}
	dsn, stop := startPostgres(t)
	defer stop()

	m := newMigrator(t, dsn)
	if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		t.Fatalf("up: %v", err)
	}

	db, err := sql.Open("pgx", dsn)
	if err != nil {
		t.Fatal(err)
	}
	defer func() { _ = db.Close() }()

	ctx := context.Background()

	var tid string
	err = db.QueryRowContext(ctx,
		`INSERT INTO tenants (slug, name, plan, status, kind) VALUES ($1, $2, 'professional', 'active', 'customer') RETURNING id`,
		"acme", "Acme Inc.").Scan(&tid)
	if err != nil {
		t.Fatalf("insert tenant: %v", err)
	}

	if _, err := db.ExecContext(ctx,
		`INSERT INTO tenant_projects (tenant_id, name, slug) VALUES ($1, $2, $3)`,
		tid, "Production", "prod"); err != nil {
		t.Fatalf("insert project: %v", err)
	}
	if _, err := db.ExecContext(ctx,
		`INSERT INTO tenant_products (tenant_id, product, config) VALUES ($1, 'certifai', '{"max_seats":10}'::jsonb)`,
		tid); err != nil {
		t.Fatalf("insert product: %v", err)
	}
	if _, err := db.ExecContext(ctx,
		`INSERT INTO api_keys (tenant_id, name, hash, prefix, scopes) VALUES ($1, 'ci-bot', 'argon2-hash', 'bp_12345', ARRAY['certifai:read'])`,
		tid); err != nil {
		t.Fatalf("insert api_key: %v", err)
	}
	if _, err := db.ExecContext(ctx,
		`INSERT INTO audit_log (tenant_id, action, actor_id, actor_name, metadata) VALUES ($1, 'tenant.created', 'sys', 'system', '{"source":"test"}'::jsonb)`,
		tid); err != nil {
		t.Fatalf("insert audit: %v", err)
	}

	// Round-trip read.
	var slug, status string
	err = db.QueryRowContext(ctx, `SELECT slug, status::text FROM tenants WHERE id = $1`, tid).Scan(&slug, &status)
	if err != nil {
		t.Fatal(err)
	}
	if slug != "acme" || status != "active" {
		t.Errorf("tenant readback: slug=%q status=%q", slug, status)
	}

	// FK cascade — deleting the tenant must remove its projects, products and keys.
	if _, err := db.ExecContext(ctx, `DELETE FROM tenants WHERE id = $1`, tid); err != nil {
		t.Fatalf("delete tenant: %v", err)
	}
	nProjects := queryCount(t, db, `SELECT COUNT(*) FROM tenant_projects WHERE tenant_id = $1`, tid)
	nProducts := queryCount(t, db, `SELECT COUNT(*) FROM tenant_products WHERE tenant_id = $1`, tid)
	nKeys := queryCount(t, db, `SELECT COUNT(*) FROM api_keys WHERE tenant_id = $1`, tid)
	if nProjects != 0 || nProducts != 0 || nKeys != 0 {
		t.Errorf("FK cascade incomplete: projects=%d products=%d keys=%d", nProjects, nProducts, nKeys)
	}

	// audit_log uses ON DELETE SET NULL — tenant_id becomes NULL but row stays.
	nAudit := queryCount(t, db, `SELECT COUNT(*) FROM audit_log`)
	nAuditNullTenant := queryCount(t, db, `SELECT COUNT(*) FROM audit_log WHERE tenant_id IS NULL`)
	if nAudit != 1 || nAuditNullTenant != 1 {
		t.Errorf("audit_log SET NULL: total=%d null=%d, want 1/1", nAudit, nAuditNullTenant)
	}
}

// TestSlugConstraint checks which tenant slugs the schema accepts. Lowercase
// alphanumerics with inner dashes are accepted. Single-character slugs,
// leading or trailing dashes, uppercase letters and underscores are rejected.
func TestSlugConstraint(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test under -short")
	}
	dsn, stop := startPostgres(t)
	defer stop()

	m := newMigrator(t, dsn)
	if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		t.Fatalf("up: %v", err)
	}

	db, err := sql.Open("pgx", dsn)
	if err != nil {
		t.Fatal(err)
	}
	defer func() { _ = db.Close() }()

	cases := []struct {
		slug    string
		wantErr bool
	}{
		{"acme", false},
		{"a-c-m-e", false},
		{"a1b2c3", false},
		{"a", true},     // too short
		{"-acme", true}, // leading dash
		{"acme-", true}, // trailing dash
		{"AcMe", true},  // uppercase
		{"a_b", true},   // underscore
	}
	for _, c := range cases {
		// Each insert runs in its own rolled-back transaction so cases stay
		// independent: a slug accepted by one case is never present when a
		// later case runs, so a rejection cannot come from a conflict with it.
		tx, err := db.Begin()
		if err != nil {
			t.Fatalf("begin: %v", err)
		}
		_, err = tx.Exec(`INSERT INTO tenants (slug, name) VALUES ($1, 'X')`, c.slug)
		_ = tx.Rollback() // discarding the row is the intent; rollback errors are irrelevant here

		if gotErr := err != nil; gotErr != c.wantErr {
			t.Errorf("slug %q: gotErr=%v wantErr=%v (err=%v)", c.slug, gotErr, c.wantErr, err)
		}
	}
}