package usecase import ( "context" "encoding/json" "fmt" "time" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgxpool" ) // Store handles database operations for use-case audits. type Store struct { pool *pgxpool.Pool } // NewStore creates a new Store. func NewStore(pool *pgxpool.Pool) *Store { return &Store{pool: pool} } // ── Audit CRUD ───────────────────────────────────────────────────── // CreateAudit inserts a new audit. func (s *Store) CreateAudit(a *Audit) error { ctx := context.Background() a.ID = uuid.New() a.CreatedAt = time.Now() a.UpdatedAt = time.Now() a.Status = StatusDraft questionsJSON, err := json.Marshal(a.Questions) if err != nil { return fmt.Errorf("marshal questions: %w", err) } _, err = s.pool.Exec(ctx, ` INSERT INTO compliance.usecase_audits (id, tenant_id, template_id, name, target_name, status, total_questions, answered_questions, compliance_score, questions, created_at, updated_at) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12)`, a.ID, a.TenantID, a.TemplateID, a.Name, a.TargetName, a.Status, a.TotalQuestions, a.AnsweredQuestions, a.ComplianceScore, questionsJSON, a.CreatedAt, a.UpdatedAt, ) return err } // GetAudit loads an audit by ID. func (s *Store) GetAudit(id uuid.UUID) (*Audit, error) { ctx := context.Background() a := &Audit{} var questionsJSON []byte err := s.pool.QueryRow(ctx, ` SELECT id, tenant_id, template_id, name, target_name, status, total_questions, answered_questions, compliance_score, questions, created_at, updated_at, completed_at FROM compliance.usecase_audits WHERE id = $1`, id, ).Scan( &a.ID, &a.TenantID, &a.TemplateID, &a.Name, &a.TargetName, &a.Status, &a.TotalQuestions, &a.AnsweredQuestions, &a.ComplianceScore, &questionsJSON, &a.CreatedAt, &a.UpdatedAt, &a.CompletedAt, ) if err != nil { return nil, err } if len(questionsJSON) > 0 { json.Unmarshal(questionsJSON, &a.Questions) } return a, nil } // ListAudits returns all audits for a tenant. func (s *Store) ListAudits(tenantID uuid.UUID) ([]Audit, error) { ctx := context.Background() rows, err := s.pool.Query(ctx, ` SELECT id, template_id, name, target_name, status, total_questions, answered_questions, compliance_score, created_at, updated_at, completed_at FROM compliance.usecase_audits WHERE tenant_id = $1 ORDER BY created_at DESC`, tenantID) if err != nil { return nil, err } defer rows.Close() var audits []Audit for rows.Next() { var a Audit a.TenantID = tenantID if err := rows.Scan( &a.ID, &a.TemplateID, &a.Name, &a.TargetName, &a.Status, &a.TotalQuestions, &a.AnsweredQuestions, &a.ComplianceScore, &a.CreatedAt, &a.UpdatedAt, &a.CompletedAt, ); err != nil { return nil, err } audits = append(audits, a) } return audits, nil } // UpdateAuditScore updates the score and status of an audit. func (s *Store) UpdateAuditScore(id uuid.UUID, answered int, score float64, status AuditStatus) error { ctx := context.Background() now := time.Now() query := ` UPDATE compliance.usecase_audits SET answered_questions = $2, compliance_score = $3, status = $4, updated_at = $5` args := []interface{}{id, answered, score, status, now} if status == StatusCompleted { query += `, completed_at = $6 WHERE id = $1` args = append(args, now) } else { query += ` WHERE id = $1` } _, err := s.pool.Exec(ctx, query, args...) return err } // ── Answer CRUD ──────────────────────────────────────────────────── // SaveAnswer upserts an answer (INSERT ... ON CONFLICT UPDATE). func (s *Store) SaveAnswer(a *Answer) error { ctx := context.Background() a.ID = uuid.New() a.AnsweredAt = time.Now() answerJSON, err := json.Marshal(map[string]interface{}{ "value": a.Value, "comment": a.Comment, }) if err != nil { return fmt.Errorf("marshal answer: %w", err) } evidenceJSON, _ := json.Marshal(a.EvidenceIDs) _, err = s.pool.Exec(ctx, ` INSERT INTO compliance.usecase_answers (id, audit_id, question_id, mc_id, answer, evidence_ids, status, answered_at) VALUES ($1,$2,$3,$4,$5,$6,$7,$8) ON CONFLICT (audit_id, question_id) DO UPDATE SET answer = $5, evidence_ids = $6, status = $7, answered_at = $8`, a.ID, a.AuditID, a.QuestionID, a.MCID, answerJSON, evidenceJSON, a.Status, a.AnsweredAt, ) return err } // ListAnswers returns all answers for an audit. func (s *Store) ListAnswers(auditID uuid.UUID) ([]Answer, error) { ctx := context.Background() rows, err := s.pool.Query(ctx, ` SELECT id, audit_id, question_id, mc_id, answer, evidence_ids, status, answered_at FROM compliance.usecase_answers WHERE audit_id = $1 ORDER BY answered_at`, auditID) if err != nil { return nil, err } defer rows.Close() var answers []Answer for rows.Next() { var a Answer var answerJSON, evidenceJSON []byte if err := rows.Scan( &a.ID, &a.AuditID, &a.QuestionID, &a.MCID, &answerJSON, &evidenceJSON, &a.Status, &a.AnsweredAt, ); err != nil { return nil, err } var payload map[string]interface{} if json.Unmarshal(answerJSON, &payload) == nil { a.Value = payload["value"] if c, ok := payload["comment"].(string); ok { a.Comment = c } } json.Unmarshal(evidenceJSON, &a.EvidenceIDs) answers = append(answers, a) } return answers, nil } // ── MC Queries ───────────────────────────────────────────────────── // MCInfo holds minimal data about a Master Control for compilation. type MCInfo struct { MasterControlID string `json:"master_control_id"` CanonicalName string `json:"canonical_name"` TotalControls int `json:"total_controls"` RegSource string `json:"regulation_source"` } // FetchMCsByFilters returns MCs whose canonical_name matches any filter pattern. func (s *Store) FetchMCsByFilters(filters []string) ([]MCInfo, error) { if len(filters) == 0 { return nil, nil } ctx := context.Background() // Build LIKE conditions from filter patterns (support trailing *) conditions := make([]string, len(filters)) args := make([]interface{}, len(filters)) for i, f := range filters { // Convert "third_party_management_*" → "third_party_management_%" pattern := f if len(pattern) > 0 && pattern[len(pattern)-1] == '*' { pattern = pattern[:len(pattern)-1] + "%" } conditions[i] = fmt.Sprintf("mc.canonical_name LIKE $%d", i+1) args[i] = pattern } query := fmt.Sprintf(` SELECT DISTINCT mc.master_control_id, mc.canonical_name, mc.total_controls, COALESCE( (SELECT pc.source_citation::jsonb->>'source' FROM compliance.master_control_members mcm2 JOIN compliance.canonical_controls cc2 ON cc2.id = mcm2.control_uuid LEFT JOIN compliance.canonical_controls pc ON pc.id = cc2.parent_control_uuid WHERE mcm2.master_control_uuid = mc.id AND pc.source_citation IS NOT NULL LIMIT 1), '' ) as regulation_source FROM compliance.master_controls mc WHERE %s ORDER BY mc.total_controls DESC LIMIT 200`, joinOr(conditions)) rows, err := s.pool.Query(ctx, query, args...) if err != nil { return nil, fmt.Errorf("fetch MCs: %w", err) } defer rows.Close() var mcs []MCInfo for rows.Next() { var m MCInfo if err := rows.Scan(&m.MasterControlID, &m.CanonicalName, &m.TotalControls, &m.RegSource); err != nil { return nil, err } mcs = append(mcs, m) } return mcs, nil } // FetchCheckQuestions loads existing doc_check_controls for MCs. func (s *Store) FetchCheckQuestions(mcIDs []string) (map[string][]CheckQuestion, error) { if len(mcIDs) == 0 { return nil, nil } ctx := context.Background() rows, err := s.pool.Query(ctx, ` SELECT control_id, check_question, pass_criteria, fail_criteria, severity FROM compliance.doc_check_controls WHERE control_id = ANY($1)`, mcIDs) if err != nil { return nil, err } defer rows.Close() result := make(map[string][]CheckQuestion) for rows.Next() { var cq CheckQuestion if err := rows.Scan(&cq.ControlID, &cq.Question, &cq.PassCriteria, &cq.FailCriteria, &cq.Severity); err != nil { return nil, err } result[cq.ControlID] = append(result[cq.ControlID], cq) } return result, nil } // CheckQuestion holds an existing doc_check_control question. type CheckQuestion struct { ControlID string `json:"control_id"` Question string `json:"check_question"` PassCriteria string `json:"pass_criteria"` FailCriteria string `json:"fail_criteria"` Severity string `json:"severity"` } // CountMCSourceCitations counts controls with source_citation per MC. func (s *Store) CountMCSourceCitations(mcIDs []string) (map[string]int, error) { if len(mcIDs) == 0 { return nil, nil } ctx := context.Background() rows, err := s.pool.Query(ctx, ` SELECT mc.master_control_id, COUNT(CASE WHEN cc.source_citation IS NOT NULL AND cc.source_citation != '' THEN 1 END) FROM compliance.master_controls mc JOIN compliance.master_control_members mcm ON mcm.master_control_uuid = mc.id JOIN compliance.canonical_controls cc ON cc.id = mcm.control_uuid WHERE mc.master_control_id = ANY($1) GROUP BY mc.master_control_id`, mcIDs) if err != nil { return nil, err } defer rows.Close() result := make(map[string]int) for rows.Next() { var id string var count int if err := rows.Scan(&id, &count); err != nil { return nil, err } result[id] = count } return result, nil } func joinOr(conditions []string) string { if len(conditions) == 1 { return conditions[0] } result := "(" for i, c := range conditions { if i > 0 { result += " OR " } result += c } return result + ")" }