package audit import ( "context" "encoding/json" "fmt" "time" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgxpool" ) // Store provides database operations for audit logs type Store struct { pool *pgxpool.Pool } // NewStore creates a new audit store func NewStore(pool *pgxpool.Pool) *Store { return &Store{pool: pool} } // LLMAuditEntry represents an LLM audit log entry type LLMAuditEntry struct { ID uuid.UUID `json:"id" db:"id"` TenantID uuid.UUID `json:"tenant_id" db:"tenant_id"` NamespaceID *uuid.UUID `json:"namespace_id,omitempty" db:"namespace_id"` UserID uuid.UUID `json:"user_id" db:"user_id"` SessionID string `json:"session_id,omitempty" db:"session_id"` Operation string `json:"operation" db:"operation"` ModelUsed string `json:"model_used" db:"model_used"` Provider string `json:"provider" db:"provider"` PromptHash string `json:"prompt_hash" db:"prompt_hash"` PromptLength int `json:"prompt_length" db:"prompt_length"` ResponseLength int `json:"response_length,omitempty" db:"response_length"` TokensUsed int `json:"tokens_used" db:"tokens_used"` DurationMS int `json:"duration_ms" db:"duration_ms"` PIIDetected bool `json:"pii_detected" db:"pii_detected"` PIITypesDetected []string `json:"pii_types_detected,omitempty" db:"pii_types_detected"` PIIRedacted bool `json:"pii_redacted" db:"pii_redacted"` PolicyID *uuid.UUID `json:"policy_id,omitempty" db:"policy_id"` PolicyViolations []string `json:"policy_violations,omitempty" db:"policy_violations"` DataCategoriesAccessed []string `json:"data_categories_accessed,omitempty" db:"data_categories_accessed"` ErrorMessage string `json:"error_message,omitempty" db:"error_message"` RequestMetadata map[string]any `json:"request_metadata,omitempty" db:"request_metadata"` CreatedAt time.Time `json:"created_at" db:"created_at"` } // GeneralAuditEntry represents a general audit trail entry type GeneralAuditEntry struct { ID uuid.UUID `json:"id" db:"id"` TenantID uuid.UUID `json:"tenant_id" db:"tenant_id"` NamespaceID *uuid.UUID `json:"namespace_id,omitempty" db:"namespace_id"` UserID uuid.UUID `json:"user_id" db:"user_id"` Action string `json:"action" db:"action"` ResourceType string `json:"resource_type" db:"resource_type"` ResourceID *uuid.UUID `json:"resource_id,omitempty" db:"resource_id"` OldValues map[string]any `json:"old_values,omitempty" db:"old_values"` NewValues map[string]any `json:"new_values,omitempty" db:"new_values"` IPAddress string `json:"ip_address,omitempty" db:"ip_address"` UserAgent string `json:"user_agent,omitempty" db:"user_agent"` Reason string `json:"reason,omitempty" db:"reason"` CreatedAt time.Time `json:"created_at" db:"created_at"` } // CreateLLMAuditEntry creates a new LLM audit log entry func (s *Store) CreateLLMAuditEntry(ctx context.Context, entry *LLMAuditEntry) error { if entry.ID == uuid.Nil { entry.ID = uuid.New() } if entry.CreatedAt.IsZero() { entry.CreatedAt = time.Now().UTC() } metadataJSON, _ := json.Marshal(entry.RequestMetadata) _, err := s.pool.Exec(ctx, ` INSERT INTO compliance_llm_audit_log ( id, tenant_id, namespace_id, user_id, session_id, operation, model_used, provider, prompt_hash, prompt_length, response_length, tokens_used, duration_ms, pii_detected, pii_types_detected, pii_redacted, policy_id, policy_violations, data_categories_accessed, error_message, request_metadata, created_at ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22) `, entry.ID, entry.TenantID, entry.NamespaceID, entry.UserID, entry.SessionID, entry.Operation, entry.ModelUsed, entry.Provider, entry.PromptHash, entry.PromptLength, entry.ResponseLength, entry.TokensUsed, entry.DurationMS, entry.PIIDetected, entry.PIITypesDetected, entry.PIIRedacted, entry.PolicyID, entry.PolicyViolations, entry.DataCategoriesAccessed, entry.ErrorMessage, metadataJSON, entry.CreatedAt, ) return err } // CreateGeneralAuditEntry creates a new general audit entry func (s *Store) CreateGeneralAuditEntry(ctx context.Context, entry *GeneralAuditEntry) error { if entry.ID == uuid.Nil { entry.ID = uuid.New() } if entry.CreatedAt.IsZero() { entry.CreatedAt = time.Now().UTC() } oldValuesJSON, _ := json.Marshal(entry.OldValues) newValuesJSON, _ := json.Marshal(entry.NewValues) _, err := s.pool.Exec(ctx, ` INSERT INTO compliance_audit_trail ( id, tenant_id, namespace_id, user_id, action, resource_type, resource_id, old_values, new_values, ip_address, user_agent, reason, created_at ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) `, entry.ID, entry.TenantID, entry.NamespaceID, entry.UserID, entry.Action, entry.ResourceType, entry.ResourceID, oldValuesJSON, newValuesJSON, entry.IPAddress, entry.UserAgent, entry.Reason, entry.CreatedAt, ) return err } // LLMAuditFilter defines filters for LLM audit queries type LLMAuditFilter struct { TenantID uuid.UUID NamespaceID *uuid.UUID UserID *uuid.UUID Operation string Model string PIIDetected *bool HasViolations *bool StartDate *time.Time EndDate *time.Time Limit int Offset int } // QueryLLMAuditEntries queries LLM audit entries with filters func (s *Store) QueryLLMAuditEntries(ctx context.Context, filter *LLMAuditFilter) ([]*LLMAuditEntry, int, error) { query := ` SELECT id, tenant_id, namespace_id, user_id, session_id, operation, model_used, provider, prompt_hash, prompt_length, response_length, tokens_used, duration_ms, pii_detected, pii_types_detected, pii_redacted, policy_id, policy_violations, data_categories_accessed, error_message, request_metadata, created_at FROM compliance_llm_audit_log WHERE tenant_id = $1 ` countQuery := `SELECT COUNT(*) FROM compliance_llm_audit_log WHERE tenant_id = $1` args := []any{filter.TenantID} argIndex := 2 if filter.NamespaceID != nil { query += fmt.Sprintf(" AND namespace_id = $%d", argIndex) countQuery += fmt.Sprintf(" AND namespace_id = $%d", argIndex) args = append(args, *filter.NamespaceID) argIndex++ } if filter.UserID != nil { query += fmt.Sprintf(" AND user_id = $%d", argIndex) countQuery += fmt.Sprintf(" AND user_id = $%d", argIndex) args = append(args, *filter.UserID) argIndex++ } if filter.Operation != "" { query += fmt.Sprintf(" AND operation = $%d", argIndex) countQuery += fmt.Sprintf(" AND operation = $%d", argIndex) args = append(args, filter.Operation) argIndex++ } if filter.Model != "" { query += fmt.Sprintf(" AND model_used = $%d", argIndex) countQuery += fmt.Sprintf(" AND model_used = $%d", argIndex) args = append(args, filter.Model) argIndex++ } if filter.PIIDetected != nil { query += fmt.Sprintf(" AND pii_detected = $%d", argIndex) countQuery += fmt.Sprintf(" AND pii_detected = $%d", argIndex) args = append(args, *filter.PIIDetected) argIndex++ } if filter.HasViolations != nil && *filter.HasViolations { query += " AND array_length(policy_violations, 1) > 0" countQuery += " AND array_length(policy_violations, 1) > 0" } if filter.StartDate != nil { query += fmt.Sprintf(" AND created_at >= $%d", argIndex) countQuery += fmt.Sprintf(" AND created_at >= $%d", argIndex) args = append(args, *filter.StartDate) argIndex++ } if filter.EndDate != nil { query += fmt.Sprintf(" AND created_at <= $%d", argIndex) countQuery += fmt.Sprintf(" AND created_at <= $%d", argIndex) args = append(args, *filter.EndDate) argIndex++ } // Get total count var totalCount int if err := s.pool.QueryRow(ctx, countQuery, args...).Scan(&totalCount); err != nil { return nil, 0, err } // Add ordering and pagination query += " ORDER BY created_at DESC" if filter.Limit > 0 { query += fmt.Sprintf(" LIMIT $%d", argIndex) args = append(args, filter.Limit) argIndex++ } if filter.Offset > 0 { query += fmt.Sprintf(" OFFSET $%d", argIndex) args = append(args, filter.Offset) } rows, err := s.pool.Query(ctx, query, args...) if err != nil { return nil, 0, err } defer rows.Close() var entries []*LLMAuditEntry for rows.Next() { var entry LLMAuditEntry var metadataJSON []byte err := rows.Scan( &entry.ID, &entry.TenantID, &entry.NamespaceID, &entry.UserID, &entry.SessionID, &entry.Operation, &entry.ModelUsed, &entry.Provider, &entry.PromptHash, &entry.PromptLength, &entry.ResponseLength, &entry.TokensUsed, &entry.DurationMS, &entry.PIIDetected, &entry.PIITypesDetected, &entry.PIIRedacted, &entry.PolicyID, &entry.PolicyViolations, &entry.DataCategoriesAccessed, &entry.ErrorMessage, &metadataJSON, &entry.CreatedAt, ) if err != nil { continue } if metadataJSON != nil { json.Unmarshal(metadataJSON, &entry.RequestMetadata) } entries = append(entries, &entry) } return entries, totalCount, nil } // GeneralAuditFilter defines filters for general audit queries type GeneralAuditFilter struct { TenantID uuid.UUID NamespaceID *uuid.UUID UserID *uuid.UUID Action string ResourceType string ResourceID *uuid.UUID StartDate *time.Time EndDate *time.Time Limit int Offset int } // QueryGeneralAuditEntries queries general audit entries with filters func (s *Store) QueryGeneralAuditEntries(ctx context.Context, filter *GeneralAuditFilter) ([]*GeneralAuditEntry, int, error) { query := ` SELECT id, tenant_id, namespace_id, user_id, action, resource_type, resource_id, old_values, new_values, ip_address, user_agent, reason, created_at FROM compliance_audit_trail WHERE tenant_id = $1 ` countQuery := `SELECT COUNT(*) FROM compliance_audit_trail WHERE tenant_id = $1` args := []any{filter.TenantID} argIndex := 2 if filter.NamespaceID != nil { query += fmt.Sprintf(" AND namespace_id = $%d", argIndex) countQuery += fmt.Sprintf(" AND namespace_id = $%d", argIndex) args = append(args, *filter.NamespaceID) argIndex++ } if filter.UserID != nil { query += fmt.Sprintf(" AND user_id = $%d", argIndex) countQuery += fmt.Sprintf(" AND user_id = $%d", argIndex) args = append(args, *filter.UserID) argIndex++ } if filter.Action != "" { query += fmt.Sprintf(" AND action = $%d", argIndex) countQuery += fmt.Sprintf(" AND action = $%d", argIndex) args = append(args, filter.Action) argIndex++ } if filter.ResourceType != "" { query += fmt.Sprintf(" AND resource_type = $%d", argIndex) countQuery += fmt.Sprintf(" AND resource_type = $%d", argIndex) args = append(args, filter.ResourceType) argIndex++ } if filter.ResourceID != nil { query += fmt.Sprintf(" AND resource_id = $%d", argIndex) countQuery += fmt.Sprintf(" AND resource_id = $%d", argIndex) args = append(args, *filter.ResourceID) argIndex++ } if filter.StartDate != nil { query += fmt.Sprintf(" AND created_at >= $%d", argIndex) countQuery += fmt.Sprintf(" AND created_at >= $%d", argIndex) args = append(args, *filter.StartDate) argIndex++ } if filter.EndDate != nil { query += fmt.Sprintf(" AND created_at <= $%d", argIndex) countQuery += fmt.Sprintf(" AND created_at <= $%d", argIndex) args = append(args, *filter.EndDate) argIndex++ } // Get total count var totalCount int if err := s.pool.QueryRow(ctx, countQuery, args...).Scan(&totalCount); err != nil { return nil, 0, err } // Add ordering and pagination query += " ORDER BY created_at DESC" if filter.Limit > 0 { query += fmt.Sprintf(" LIMIT $%d", argIndex) args = append(args, filter.Limit) argIndex++ } if filter.Offset > 0 { query += fmt.Sprintf(" OFFSET $%d", argIndex) args = append(args, filter.Offset) } rows, err := s.pool.Query(ctx, query, args...) if err != nil { return nil, 0, err } defer rows.Close() var entries []*GeneralAuditEntry for rows.Next() { var entry GeneralAuditEntry var oldValuesJSON, newValuesJSON []byte err := rows.Scan( &entry.ID, &entry.TenantID, &entry.NamespaceID, &entry.UserID, &entry.Action, &entry.ResourceType, &entry.ResourceID, &oldValuesJSON, &newValuesJSON, &entry.IPAddress, &entry.UserAgent, &entry.Reason, &entry.CreatedAt, ) if err != nil { continue } if oldValuesJSON != nil { json.Unmarshal(oldValuesJSON, &entry.OldValues) } if newValuesJSON != nil { json.Unmarshal(newValuesJSON, &entry.NewValues) } entries = append(entries, &entry) } return entries, totalCount, nil } // GetLLMUsageStats retrieves aggregated LLM usage statistics func (s *Store) GetLLMUsageStats(ctx context.Context, tenantID uuid.UUID, startDate, endDate time.Time) (*LLMUsageStats, error) { var stats LLMUsageStats err := s.pool.QueryRow(ctx, ` SELECT COUNT(*) as total_requests, COALESCE(SUM(tokens_used), 0) as total_tokens, COALESCE(SUM(duration_ms), 0) as total_duration_ms, COUNT(*) FILTER (WHERE pii_detected = TRUE) as requests_with_pii, COUNT(*) FILTER (WHERE array_length(policy_violations, 1) > 0) as policy_violations FROM compliance_llm_audit_log WHERE tenant_id = $1 AND created_at >= $2 AND created_at <= $3 `, tenantID, startDate, endDate).Scan( &stats.TotalRequests, &stats.TotalTokens, &stats.TotalDurationMS, &stats.RequestsWithPII, &stats.PolicyViolations, ) if err != nil { return nil, err } // Get model usage breakdown rows, err := s.pool.Query(ctx, ` SELECT model_used, COUNT(*) as count FROM compliance_llm_audit_log WHERE tenant_id = $1 AND created_at >= $2 AND created_at <= $3 GROUP BY model_used `, tenantID, startDate, endDate) if err != nil { return nil, err } defer rows.Close() stats.ModelsUsed = make(map[string]int) for rows.Next() { var model string var count int if err := rows.Scan(&model, &count); err == nil { stats.ModelsUsed[model] = count } } return &stats, nil } // LLMUsageStats represents aggregated LLM usage statistics type LLMUsageStats struct { TotalRequests int `json:"total_requests"` TotalTokens int `json:"total_tokens"` TotalDurationMS int64 `json:"total_duration_ms"` RequestsWithPII int `json:"requests_with_pii"` PolicyViolations int `json:"policy_violations"` ModelsUsed map[string]int `json:"models_used"` } // CleanupOldEntries removes audit entries older than the retention period func (s *Store) CleanupOldEntries(ctx context.Context, retentionDays int) (int, int, error) { cutoff := time.Now().UTC().AddDate(0, 0, -retentionDays) // Cleanup LLM audit log llmResult, err := s.pool.Exec(ctx, ` DELETE FROM compliance_llm_audit_log WHERE created_at < $1 `, cutoff) if err != nil { return 0, 0, err } // Cleanup general audit trail generalResult, err := s.pool.Exec(ctx, ` DELETE FROM compliance_audit_trail WHERE created_at < $1 `, cutoff) if err != nil { return int(llmResult.RowsAffected()), 0, err } return int(llmResult.RowsAffected()), int(generalResult.RowsAffected()), nil }