package training import ( "context" "encoding/json" "fmt" "strings" "time" "github.com/google/uuid" "github.com/jackc/pgx/v5" ) // ============================================================================ // Block Config CRUD // ============================================================================ // CreateBlockConfig creates a new training block configuration func (s *Store) CreateBlockConfig(ctx context.Context, config *TrainingBlockConfig) error { config.ID = uuid.New() config.CreatedAt = time.Now().UTC() config.UpdatedAt = config.CreatedAt if !config.IsActive { config.IsActive = true } _, err := s.pool.Exec(ctx, ` INSERT INTO training_block_configs ( id, tenant_id, name, description, domain_filter, category_filter, severity_filter, target_audience_filter, regulation_area, module_code_prefix, frequency_type, duration_minutes, pass_threshold, max_controls_per_module, is_active, created_at, updated_at ) VALUES ( $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17 ) `, config.ID, config.TenantID, config.Name, config.Description, nilIfEmpty(config.DomainFilter), nilIfEmpty(config.CategoryFilter), nilIfEmpty(config.SeverityFilter), nilIfEmpty(config.TargetAudienceFilter), string(config.RegulationArea), config.ModuleCodePrefix, string(config.FrequencyType), config.DurationMinutes, config.PassThreshold, config.MaxControlsPerModule, config.IsActive, config.CreatedAt, config.UpdatedAt, ) return err } // GetBlockConfig retrieves a block config by ID func (s *Store) GetBlockConfig(ctx context.Context, id uuid.UUID) (*TrainingBlockConfig, error) { var config TrainingBlockConfig var regulationArea, frequencyType string var domainFilter, categoryFilter, severityFilter, targetAudienceFilter *string err := s.pool.QueryRow(ctx, ` SELECT id, tenant_id, name, description, domain_filter, category_filter, severity_filter, target_audience_filter, regulation_area, module_code_prefix, frequency_type, duration_minutes, pass_threshold, max_controls_per_module, is_active, last_generated_at, created_at, updated_at FROM training_block_configs WHERE id = $1 `, id).Scan( &config.ID, &config.TenantID, &config.Name, &config.Description, &domainFilter, &categoryFilter, &severityFilter, &targetAudienceFilter, ®ulationArea, &config.ModuleCodePrefix, &frequencyType, &config.DurationMinutes, &config.PassThreshold, &config.MaxControlsPerModule, &config.IsActive, &config.LastGeneratedAt, &config.CreatedAt, &config.UpdatedAt, ) if err == pgx.ErrNoRows { return nil, nil } if err != nil { return nil, err } config.RegulationArea = RegulationArea(regulationArea) config.FrequencyType = FrequencyType(frequencyType) if domainFilter != nil { config.DomainFilter = *domainFilter } if categoryFilter != nil { config.CategoryFilter = *categoryFilter } if severityFilter != nil { config.SeverityFilter = *severityFilter } if targetAudienceFilter != nil { config.TargetAudienceFilter = *targetAudienceFilter } return &config, nil } // ListBlockConfigs returns all block configs for a tenant func (s *Store) ListBlockConfigs(ctx context.Context, tenantID uuid.UUID) ([]TrainingBlockConfig, error) { rows, err := s.pool.Query(ctx, ` SELECT id, tenant_id, name, description, domain_filter, category_filter, severity_filter, target_audience_filter, regulation_area, module_code_prefix, frequency_type, duration_minutes, pass_threshold, max_controls_per_module, is_active, last_generated_at, created_at, updated_at FROM training_block_configs WHERE tenant_id = $1 ORDER BY created_at DESC `, tenantID) if err != nil { return nil, err } defer rows.Close() var configs []TrainingBlockConfig for rows.Next() { var config TrainingBlockConfig var regulationArea, frequencyType string var domainFilter, categoryFilter, severityFilter, targetAudienceFilter *string if err := rows.Scan( &config.ID, &config.TenantID, &config.Name, &config.Description, &domainFilter, &categoryFilter, &severityFilter, &targetAudienceFilter, ®ulationArea, &config.ModuleCodePrefix, &frequencyType, &config.DurationMinutes, &config.PassThreshold, &config.MaxControlsPerModule, &config.IsActive, &config.LastGeneratedAt, &config.CreatedAt, &config.UpdatedAt, ); err != nil { return nil, err } config.RegulationArea = RegulationArea(regulationArea) config.FrequencyType = FrequencyType(frequencyType) if domainFilter != nil { config.DomainFilter = *domainFilter } if categoryFilter != nil { config.CategoryFilter = *categoryFilter } if severityFilter != nil { config.SeverityFilter = *severityFilter } if targetAudienceFilter != nil { config.TargetAudienceFilter = *targetAudienceFilter } configs = append(configs, config) } if configs == nil { configs = []TrainingBlockConfig{} } return configs, nil } // UpdateBlockConfig updates a block config func (s *Store) UpdateBlockConfig(ctx context.Context, config *TrainingBlockConfig) error { config.UpdatedAt = time.Now().UTC() _, err := s.pool.Exec(ctx, ` UPDATE training_block_configs SET name = $2, description = $3, domain_filter = $4, category_filter = $5, severity_filter = $6, target_audience_filter = $7, max_controls_per_module = $8, duration_minutes = $9, pass_threshold = $10, is_active = $11, updated_at = $12 WHERE id = $1 `, config.ID, config.Name, config.Description, nilIfEmpty(config.DomainFilter), nilIfEmpty(config.CategoryFilter), nilIfEmpty(config.SeverityFilter), nilIfEmpty(config.TargetAudienceFilter), config.MaxControlsPerModule, config.DurationMinutes, config.PassThreshold, config.IsActive, config.UpdatedAt, ) return err } // DeleteBlockConfig deletes a block config (cascades to control links) func (s *Store) DeleteBlockConfig(ctx context.Context, id uuid.UUID) error { _, err := s.pool.Exec(ctx, `DELETE FROM training_block_configs WHERE id = $1`, id) return err } // UpdateBlockConfigLastGenerated updates the last_generated_at timestamp func (s *Store) UpdateBlockConfigLastGenerated(ctx context.Context, id uuid.UUID) error { now := time.Now().UTC() _, err := s.pool.Exec(ctx, ` UPDATE training_block_configs SET last_generated_at = $2, updated_at = $2 WHERE id = $1 `, id, now) return err } // ============================================================================ // Block Control Links // ============================================================================ // CreateBlockControlLink creates a link between a block config, a module, and a control func (s *Store) CreateBlockControlLink(ctx context.Context, link *TrainingBlockControlLink) error { link.ID = uuid.New() link.CreatedAt = time.Now().UTC() requirements, _ := json.Marshal(link.ControlRequirements) _, err := s.pool.Exec(ctx, ` INSERT INTO training_block_control_links ( id, block_config_id, module_id, control_id, control_title, control_objective, control_requirements, sort_order, created_at ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) `, link.ID, link.BlockConfigID, link.ModuleID, link.ControlID, link.ControlTitle, link.ControlObjective, requirements, link.SortOrder, link.CreatedAt, ) return err } // GetControlLinksForBlock returns all control links for a block config func (s *Store) GetControlLinksForBlock(ctx context.Context, blockConfigID uuid.UUID) ([]TrainingBlockControlLink, error) { rows, err := s.pool.Query(ctx, ` SELECT id, block_config_id, module_id, control_id, control_title, control_objective, control_requirements, sort_order, created_at FROM training_block_control_links WHERE block_config_id = $1 ORDER BY sort_order `, blockConfigID) if err != nil { return nil, err } defer rows.Close() var links []TrainingBlockControlLink for rows.Next() { var link TrainingBlockControlLink var requirements []byte if err := rows.Scan( &link.ID, &link.BlockConfigID, &link.ModuleID, &link.ControlID, &link.ControlTitle, &link.ControlObjective, &requirements, &link.SortOrder, &link.CreatedAt, ); err != nil { return nil, err } json.Unmarshal(requirements, &link.ControlRequirements) if link.ControlRequirements == nil { link.ControlRequirements = []string{} } links = append(links, link) } if links == nil { links = []TrainingBlockControlLink{} } return links, nil } // GetControlLinksForModule returns all control links for a specific module func (s *Store) GetControlLinksForModule(ctx context.Context, moduleID uuid.UUID) ([]TrainingBlockControlLink, error) { rows, err := s.pool.Query(ctx, ` SELECT id, block_config_id, module_id, control_id, control_title, control_objective, control_requirements, sort_order, created_at FROM training_block_control_links WHERE module_id = $1 ORDER BY sort_order `, moduleID) if err != nil { return nil, err } defer rows.Close() var links []TrainingBlockControlLink for rows.Next() { var link TrainingBlockControlLink var requirements []byte if err := rows.Scan( &link.ID, &link.BlockConfigID, &link.ModuleID, &link.ControlID, &link.ControlTitle, &link.ControlObjective, &requirements, &link.SortOrder, &link.CreatedAt, ); err != nil { return nil, err } json.Unmarshal(requirements, &link.ControlRequirements) if link.ControlRequirements == nil { link.ControlRequirements = []string{} } links = append(links, link) } if links == nil { links = []TrainingBlockControlLink{} } return links, nil } // ============================================================================ // Canonical Controls Query (reads from shared DB table) // ============================================================================ // QueryCanonicalControls queries canonical_controls with dynamic filters. // Domain is derived from the control_id prefix (e.g. "AUTH" from "AUTH-042"). func (s *Store) QueryCanonicalControls(ctx context.Context, domain, category, severity, targetAudience string, ) ([]CanonicalControlSummary, error) { query := `SELECT control_id, title, objective, rationale, requirements, severity, COALESCE(category, ''), COALESCE(target_audience, ''), COALESCE(tags, '[]') FROM canonical_controls WHERE release_state NOT IN ('deprecated', 'draft') AND customer_visible = true` args := []interface{}{} argIdx := 1 if domain != "" { query += fmt.Sprintf(` AND LEFT(control_id, %d) = $%d`, len(domain), argIdx) args = append(args, domain) argIdx++ } if category != "" { query += fmt.Sprintf(` AND category = $%d`, argIdx) args = append(args, category) argIdx++ } if severity != "" { query += fmt.Sprintf(` AND severity = $%d`, argIdx) args = append(args, severity) argIdx++ } if targetAudience != "" { query += fmt.Sprintf(` AND (target_audience = $%d OR target_audience = 'all')`, argIdx) args = append(args, targetAudience) argIdx++ } query += ` ORDER BY control_id` rows, err := s.pool.Query(ctx, query, args...) if err != nil { return nil, fmt.Errorf("query canonical controls: %w", err) } defer rows.Close() var controls []CanonicalControlSummary for rows.Next() { var c CanonicalControlSummary var requirementsJSON, tagsJSON []byte if err := rows.Scan( &c.ControlID, &c.Title, &c.Objective, &c.Rationale, &requirementsJSON, &c.Severity, &c.Category, &c.TargetAudience, &tagsJSON, ); err != nil { return nil, err } json.Unmarshal(requirementsJSON, &c.Requirements) if c.Requirements == nil { c.Requirements = []string{} } json.Unmarshal(tagsJSON, &c.Tags) if c.Tags == nil { c.Tags = []string{} } controls = append(controls, c) } if controls == nil { controls = []CanonicalControlSummary{} } return controls, nil } // GetCanonicalControlMeta returns aggregated metadata about canonical controls func (s *Store) GetCanonicalControlMeta(ctx context.Context) (*CanonicalControlMeta, error) { meta := &CanonicalControlMeta{} // Total count err := s.pool.QueryRow(ctx, ` SELECT COUNT(*) FROM canonical_controls WHERE release_state NOT IN ('deprecated', 'draft') AND customer_visible = true `).Scan(&meta.Total) if err != nil { return nil, fmt.Errorf("count canonical controls: %w", err) } // Domains (derived from control_id prefix) rows, err := s.pool.Query(ctx, ` SELECT LEFT(control_id, POSITION('-' IN control_id) - 1) AS domain, COUNT(*) AS cnt FROM canonical_controls WHERE release_state NOT IN ('deprecated', 'draft') AND customer_visible = true AND POSITION('-' IN control_id) > 0 GROUP BY domain ORDER BY cnt DESC `) if err != nil { return nil, err } defer rows.Close() for rows.Next() { var d DomainCount if err := rows.Scan(&d.Domain, &d.Count); err != nil { return nil, err } meta.Domains = append(meta.Domains, d) } if meta.Domains == nil { meta.Domains = []DomainCount{} } // Categories catRows, err := s.pool.Query(ctx, ` SELECT COALESCE(category, 'uncategorized') AS cat, COUNT(*) AS cnt FROM canonical_controls WHERE release_state NOT IN ('deprecated', 'draft') AND customer_visible = true GROUP BY cat ORDER BY cnt DESC `) if err != nil { return nil, err } defer catRows.Close() for catRows.Next() { var c CategoryCount if err := catRows.Scan(&c.Category, &c.Count); err != nil { return nil, err } meta.Categories = append(meta.Categories, c) } if meta.Categories == nil { meta.Categories = []CategoryCount{} } // Target audiences audRows, err := s.pool.Query(ctx, ` SELECT COALESCE(target_audience, 'unset') AS aud, COUNT(*) AS cnt FROM canonical_controls WHERE release_state NOT IN ('deprecated', 'draft') AND customer_visible = true GROUP BY aud ORDER BY cnt DESC `) if err != nil { return nil, err } defer audRows.Close() for audRows.Next() { var a AudienceCount if err := audRows.Scan(&a.Audience, &a.Count); err != nil { return nil, err } meta.Audiences = append(meta.Audiences, a) } if meta.Audiences == nil { meta.Audiences = []AudienceCount{} } return meta, nil } // ============================================================================ // Helpers // ============================================================================ // CountModulesWithPrefix counts existing modules with a given code prefix for auto-numbering func (s *Store) CountModulesWithPrefix(ctx context.Context, tenantID uuid.UUID, prefix string) (int, error) { var count int err := s.pool.QueryRow(ctx, ` SELECT COUNT(*) FROM training_modules WHERE tenant_id = $1 AND module_code LIKE $2 `, tenantID, prefix+"-%").Scan(&count) return count, err } func nilIfEmpty(s string) *string { s = strings.TrimSpace(s) if s == "" { return nil } return &s }