Files
Benjamin Boenisch 4435e7ea0a Initial commit: breakpilot-compliance - Compliance SDK Platform
Services: Admin-Compliance, Backend-Compliance,
AI-Compliance-SDK, Consent-SDK, Developer-Portal,
PCA-Platform, DSMS

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 23:47:28 +01:00

473 lines
15 KiB
Go

package audit
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
)
// Store provides database operations for audit logs.
//
// It holds a pgx connection pool; the Store methods in this file issue their
// SQL through that pool.
type Store struct {
// pool is the connection pool handed to NewStore.
pool *pgxpool.Pool
}
// NewStore returns an audit Store whose queries run against pool.
//
// pool is stored as given; NewStore does not validate or ping it.
func NewStore(pool *pgxpool.Pool) *Store {
	store := new(Store)
	store.pool = pool
	return store
}
// LLMAuditEntry represents an LLM audit log entry.
//
// It is the row shape that CreateLLMAuditEntry writes to, and
// QueryLLMAuditEntries reads from, the compliance_llm_audit_log table.
// Fields tagged omitempty are left out of the JSON encoding when empty/nil.
type LLMAuditEntry struct {
// ID is the primary identifier; CreateLLMAuditEntry generates one when it is uuid.Nil.
ID uuid.UUID `json:"id" db:"id"`
// TenantID scopes the entry; every query in this file filters on it.
TenantID uuid.UUID `json:"tenant_id" db:"tenant_id"`
// NamespaceID is optional (nil when unset).
NamespaceID *uuid.UUID `json:"namespace_id,omitempty" db:"namespace_id"`
UserID uuid.UUID `json:"user_id" db:"user_id"`
SessionID string `json:"session_id,omitempty" db:"session_id"`
Operation string `json:"operation" db:"operation"`
ModelUsed string `json:"model_used" db:"model_used"`
Provider string `json:"provider" db:"provider"`
// NOTE(review): presumably a digest of the prompt rather than the prompt
// text itself; the hashing is not done in this file - confirm with the caller.
PromptHash string `json:"prompt_hash" db:"prompt_hash"`
PromptLength int `json:"prompt_length" db:"prompt_length"`
ResponseLength int `json:"response_length,omitempty" db:"response_length"`
TokensUsed int `json:"tokens_used" db:"tokens_used"`
// NOTE(review): the name suggests milliseconds - confirm the unit with the producer.
DurationMS int `json:"duration_ms" db:"duration_ms"`
PIIDetected bool `json:"pii_detected" db:"pii_detected"`
PIITypesDetected []string `json:"pii_types_detected,omitempty" db:"pii_types_detected"`
PIIRedacted bool `json:"pii_redacted" db:"pii_redacted"`
// PolicyID is optional (nil when unset).
PolicyID *uuid.UUID `json:"policy_id,omitempty" db:"policy_id"`
// PolicyViolations: queries and stats in this file treat an entry as having
// violations when this array has at least one element.
PolicyViolations []string `json:"policy_violations,omitempty" db:"policy_violations"`
DataCategoriesAccessed []string `json:"data_categories_accessed,omitempty" db:"data_categories_accessed"`
ErrorMessage string `json:"error_message,omitempty" db:"error_message"`
// RequestMetadata is stored as JSON: marshalled on insert and unmarshalled on read.
RequestMetadata map[string]any `json:"request_metadata,omitempty" db:"request_metadata"`
// CreatedAt is set to the current UTC time by CreateLLMAuditEntry when zero.
CreatedAt time.Time `json:"created_at" db:"created_at"`
}
// GeneralAuditEntry represents a general audit trail entry.
//
// It is the row shape that CreateGeneralAuditEntry writes to, and
// QueryGeneralAuditEntries reads from, the compliance_audit_trail table.
// Fields tagged omitempty are left out of the JSON encoding when empty/nil.
type GeneralAuditEntry struct {
// ID is the primary identifier; CreateGeneralAuditEntry generates one when it is uuid.Nil.
ID uuid.UUID `json:"id" db:"id"`
// TenantID scopes the entry; every query in this file filters on it.
TenantID uuid.UUID `json:"tenant_id" db:"tenant_id"`
// NamespaceID is optional (nil when unset).
NamespaceID *uuid.UUID `json:"namespace_id,omitempty" db:"namespace_id"`
UserID uuid.UUID `json:"user_id" db:"user_id"`
Action string `json:"action" db:"action"`
ResourceType string `json:"resource_type" db:"resource_type"`
// ResourceID is optional (nil when unset).
ResourceID *uuid.UUID `json:"resource_id,omitempty" db:"resource_id"`
// OldValues and NewValues are stored as JSON: marshalled on insert and
// unmarshalled on read.
OldValues map[string]any `json:"old_values,omitempty" db:"old_values"`
NewValues map[string]any `json:"new_values,omitempty" db:"new_values"`
IPAddress string `json:"ip_address,omitempty" db:"ip_address"`
UserAgent string `json:"user_agent,omitempty" db:"user_agent"`
Reason string `json:"reason,omitempty" db:"reason"`
// CreatedAt is set to the current UTC time by CreateGeneralAuditEntry when zero.
CreatedAt time.Time `json:"created_at" db:"created_at"`
}
// CreateLLMAuditEntry inserts entry into compliance_llm_audit_log.
//
// Defaults are filled in on entry itself, so the caller can read them back
// after the call: a uuid.Nil ID is replaced with a freshly generated UUID and
// a zero CreatedAt is replaced with the current UTC time.
//
// It returns an error when entry is nil, when RequestMetadata cannot be
// encoded as JSON, or when the INSERT fails. Nothing is written if the
// metadata cannot be encoded.
func (s *Store) CreateLLMAuditEntry(ctx context.Context, entry *LLMAuditEntry) error {
	if entry == nil {
		return fmt.Errorf("creating LLM audit entry: entry is nil")
	}
	if entry.ID == uuid.Nil {
		entry.ID = uuid.New()
	}
	if entry.CreatedAt.IsZero() {
		entry.CreatedAt = time.Now().UTC()
	}

	// Surface encoding failures instead of silently persisting an entry
	// whose metadata was dropped.
	metadataJSON, err := json.Marshal(entry.RequestMetadata)
	if err != nil {
		return fmt.Errorf("encoding request metadata for LLM audit entry %s: %w", entry.ID, err)
	}

	_, err = s.pool.Exec(ctx, `
		INSERT INTO compliance_llm_audit_log (
			id, tenant_id, namespace_id, user_id, session_id,
			operation, model_used, provider, prompt_hash, prompt_length, response_length,
			tokens_used, duration_ms, pii_detected, pii_types_detected, pii_redacted,
			policy_id, policy_violations, data_categories_accessed, error_message,
			request_metadata, created_at
		) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22)
	`,
		entry.ID, entry.TenantID, entry.NamespaceID, entry.UserID, entry.SessionID,
		entry.Operation, entry.ModelUsed, entry.Provider, entry.PromptHash, entry.PromptLength, entry.ResponseLength,
		entry.TokensUsed, entry.DurationMS, entry.PIIDetected, entry.PIITypesDetected, entry.PIIRedacted,
		entry.PolicyID, entry.PolicyViolations, entry.DataCategoriesAccessed, entry.ErrorMessage,
		metadataJSON, entry.CreatedAt,
	)
	if err != nil {
		return fmt.Errorf("inserting LLM audit entry %s: %w", entry.ID, err)
	}
	return nil
}
// CreateGeneralAuditEntry inserts entry into compliance_audit_trail.
//
// Defaults are filled in on entry itself, so the caller can read them back
// after the call: a uuid.Nil ID is replaced with a freshly generated UUID and
// a zero CreatedAt is replaced with the current UTC time.
//
// It returns an error when entry is nil, when OldValues or NewValues cannot
// be encoded as JSON, or when the INSERT fails. Nothing is written if either
// value map cannot be encoded.
func (s *Store) CreateGeneralAuditEntry(ctx context.Context, entry *GeneralAuditEntry) error {
	if entry == nil {
		return fmt.Errorf("creating general audit entry: entry is nil")
	}
	if entry.ID == uuid.Nil {
		entry.ID = uuid.New()
	}
	if entry.CreatedAt.IsZero() {
		entry.CreatedAt = time.Now().UTC()
	}

	// Surface encoding failures instead of silently persisting an audit
	// record whose before/after values were dropped.
	oldValuesJSON, err := json.Marshal(entry.OldValues)
	if err != nil {
		return fmt.Errorf("encoding old values for audit entry %s: %w", entry.ID, err)
	}
	newValuesJSON, err := json.Marshal(entry.NewValues)
	if err != nil {
		return fmt.Errorf("encoding new values for audit entry %s: %w", entry.ID, err)
	}

	_, err = s.pool.Exec(ctx, `
		INSERT INTO compliance_audit_trail (
			id, tenant_id, namespace_id, user_id, action, resource_type, resource_id,
			old_values, new_values, ip_address, user_agent, reason, created_at
		) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
	`,
		entry.ID, entry.TenantID, entry.NamespaceID, entry.UserID,
		entry.Action, entry.ResourceType, entry.ResourceID,
		oldValuesJSON, newValuesJSON, entry.IPAddress, entry.UserAgent,
		entry.Reason, entry.CreatedAt,
	)
	if err != nil {
		return fmt.Errorf("inserting general audit entry %s: %w", entry.ID, err)
	}
	return nil
}
// LLMAuditFilter defines filters for LLM audit queries.
//
// As consumed by QueryLLMAuditEntries: TenantID is always applied; every
// other criterion is optional and is skipped when left at its zero value
// (nil pointer, empty string, or non-positive Limit/Offset).
type LLMAuditFilter struct {
// TenantID is mandatory; it is always part of the WHERE clause.
TenantID uuid.UUID
NamespaceID *uuid.UUID
UserID *uuid.UUID
// Operation and Model are exact-match filters; "" disables them.
Operation string
Model string
PIIDetected *bool
// HasViolations only has an effect when it points to true (restricts to
// entries with at least one policy violation); nil and false both mean
// "no restriction".
HasViolations *bool
// StartDate and EndDate bound created_at inclusively (>= and <=).
StartDate *time.Time
EndDate *time.Time
// Limit and Offset are applied only when greater than zero.
Limit int
Offset int
}
// QueryLLMAuditEntries returns the LLM audit entries of filter.TenantID that
// match the optional criteria in filter, newest first.
//
// Returns:
//   - the requested page of entries (filter.Limit / filter.Offset are applied
//     only when greater than zero; the slice is nil when nothing matches),
//   - the total number of matching rows, ignoring Limit/Offset,
//   - an error if filter is nil, if either query fails, if a row cannot be
//     scanned or its request_metadata cannot be decoded, or if row iteration
//     ends with an error.
//
// A failing row aborts the whole call rather than being skipped: dropping
// rows silently would make the page disagree with the total count and hide
// audit records from the reader.
//
// filter.HasViolations only restricts the result when it points to true.
func (s *Store) QueryLLMAuditEntries(ctx context.Context, filter *LLMAuditFilter) ([]*LLMAuditEntry, int, error) {
	if filter == nil {
		return nil, 0, fmt.Errorf("querying LLM audit entries: filter is nil")
	}

	// conds collects the optional predicates once so the count query and the
	// page query are guaranteed to use the same WHERE clause.
	args := []any{filter.TenantID}
	var conds string
	// and appends " AND <column> <op> $n". column and op are always literals
	// from this function, never caller input; values go through placeholders.
	and := func(column, op string, value any) {
		args = append(args, value)
		conds += fmt.Sprintf(" AND %s %s $%d", column, op, len(args))
	}

	if filter.NamespaceID != nil {
		and("namespace_id", "=", *filter.NamespaceID)
	}
	if filter.UserID != nil {
		and("user_id", "=", *filter.UserID)
	}
	if filter.Operation != "" {
		and("operation", "=", filter.Operation)
	}
	if filter.Model != "" {
		and("model_used", "=", filter.Model)
	}
	if filter.PIIDetected != nil {
		and("pii_detected", "=", *filter.PIIDetected)
	}
	if filter.HasViolations != nil && *filter.HasViolations {
		conds += " AND array_length(policy_violations, 1) > 0"
	}
	if filter.StartDate != nil {
		and("created_at", ">=", *filter.StartDate)
	}
	if filter.EndDate != nil {
		and("created_at", "<=", *filter.EndDate)
	}

	// Total count: must run before the pagination arguments are appended,
	// because the count query has no LIMIT/OFFSET placeholders.
	countQuery := `SELECT COUNT(*) FROM compliance_llm_audit_log WHERE tenant_id = $1` + conds
	var totalCount int
	if err := s.pool.QueryRow(ctx, countQuery, args...).Scan(&totalCount); err != nil {
		return nil, 0, fmt.Errorf("counting LLM audit entries: %w", err)
	}

	query := `
		SELECT id, tenant_id, namespace_id, user_id, session_id,
			operation, model_used, provider, prompt_hash, prompt_length, response_length,
			tokens_used, duration_ms, pii_detected, pii_types_detected, pii_redacted,
			policy_id, policy_violations, data_categories_accessed, error_message,
			request_metadata, created_at
		FROM compliance_llm_audit_log
		WHERE tenant_id = $1` + conds + " ORDER BY created_at DESC"
	if filter.Limit > 0 {
		args = append(args, filter.Limit)
		query += fmt.Sprintf(" LIMIT $%d", len(args))
	}
	if filter.Offset > 0 {
		args = append(args, filter.Offset)
		query += fmt.Sprintf(" OFFSET $%d", len(args))
	}

	rows, err := s.pool.Query(ctx, query, args...)
	if err != nil {
		return nil, 0, fmt.Errorf("querying LLM audit entries: %w", err)
	}
	defer rows.Close()

	var entries []*LLMAuditEntry
	for rows.Next() {
		var (
			entry        LLMAuditEntry
			metadataJSON []byte
		)
		if err := rows.Scan(
			&entry.ID, &entry.TenantID, &entry.NamespaceID, &entry.UserID, &entry.SessionID,
			&entry.Operation, &entry.ModelUsed, &entry.Provider, &entry.PromptHash, &entry.PromptLength, &entry.ResponseLength,
			&entry.TokensUsed, &entry.DurationMS, &entry.PIIDetected, &entry.PIITypesDetected, &entry.PIIRedacted,
			&entry.PolicyID, &entry.PolicyViolations, &entry.DataCategoriesAccessed, &entry.ErrorMessage,
			&metadataJSON, &entry.CreatedAt,
		); err != nil {
			return nil, 0, fmt.Errorf("scanning LLM audit entry: %w", err)
		}
		// A NULL column scans as an empty slice; only decode real payloads.
		if len(metadataJSON) > 0 {
			if err := json.Unmarshal(metadataJSON, &entry.RequestMetadata); err != nil {
				return nil, 0, fmt.Errorf("decoding request metadata of LLM audit entry %s: %w", entry.ID, err)
			}
		}
		entries = append(entries, &entry)
	}
	// Next returning false can mean "done" or "failed"; Err tells them apart.
	if err := rows.Err(); err != nil {
		return nil, 0, fmt.Errorf("iterating LLM audit entries: %w", err)
	}
	return entries, totalCount, nil
}
// GeneralAuditFilter defines filters for general audit queries.
//
// As consumed by QueryGeneralAuditEntries: TenantID is always applied; every
// other criterion is optional and is skipped when left at its zero value
// (nil pointer, empty string, or non-positive Limit/Offset).
type GeneralAuditFilter struct {
// TenantID is mandatory; it is always part of the WHERE clause.
TenantID uuid.UUID
NamespaceID *uuid.UUID
UserID *uuid.UUID
// Action and ResourceType are exact-match filters; "" disables them.
Action string
ResourceType string
ResourceID *uuid.UUID
// StartDate and EndDate bound created_at inclusively (>= and <=).
StartDate *time.Time
EndDate *time.Time
// Limit and Offset are applied only when greater than zero.
Limit int
Offset int
}
// QueryGeneralAuditEntries returns the general audit trail entries of
// filter.TenantID that match the optional criteria in filter, newest first.
//
// Returns:
//   - the requested page of entries (filter.Limit / filter.Offset are applied
//     only when greater than zero; the slice is nil when nothing matches),
//   - the total number of matching rows, ignoring Limit/Offset,
//   - an error if filter is nil, if either query fails, if a row cannot be
//     scanned or its old_values/new_values cannot be decoded, or if row
//     iteration ends with an error.
//
// A failing row aborts the whole call rather than being skipped: dropping
// rows silently would make the page disagree with the total count and hide
// audit records from the reader.
func (s *Store) QueryGeneralAuditEntries(ctx context.Context, filter *GeneralAuditFilter) ([]*GeneralAuditEntry, int, error) {
	if filter == nil {
		return nil, 0, fmt.Errorf("querying general audit entries: filter is nil")
	}

	// conds collects the optional predicates once so the count query and the
	// page query are guaranteed to use the same WHERE clause.
	args := []any{filter.TenantID}
	var conds string
	// and appends " AND <column> <op> $n". column and op are always literals
	// from this function, never caller input; values go through placeholders.
	and := func(column, op string, value any) {
		args = append(args, value)
		conds += fmt.Sprintf(" AND %s %s $%d", column, op, len(args))
	}

	if filter.NamespaceID != nil {
		and("namespace_id", "=", *filter.NamespaceID)
	}
	if filter.UserID != nil {
		and("user_id", "=", *filter.UserID)
	}
	if filter.Action != "" {
		and("action", "=", filter.Action)
	}
	if filter.ResourceType != "" {
		and("resource_type", "=", filter.ResourceType)
	}
	if filter.ResourceID != nil {
		and("resource_id", "=", *filter.ResourceID)
	}
	if filter.StartDate != nil {
		and("created_at", ">=", *filter.StartDate)
	}
	if filter.EndDate != nil {
		and("created_at", "<=", *filter.EndDate)
	}

	// Total count: must run before the pagination arguments are appended,
	// because the count query has no LIMIT/OFFSET placeholders.
	countQuery := `SELECT COUNT(*) FROM compliance_audit_trail WHERE tenant_id = $1` + conds
	var totalCount int
	if err := s.pool.QueryRow(ctx, countQuery, args...).Scan(&totalCount); err != nil {
		return nil, 0, fmt.Errorf("counting general audit entries: %w", err)
	}

	query := `
		SELECT id, tenant_id, namespace_id, user_id, action, resource_type, resource_id,
			old_values, new_values, ip_address, user_agent, reason, created_at
		FROM compliance_audit_trail
		WHERE tenant_id = $1` + conds + " ORDER BY created_at DESC"
	if filter.Limit > 0 {
		args = append(args, filter.Limit)
		query += fmt.Sprintf(" LIMIT $%d", len(args))
	}
	if filter.Offset > 0 {
		args = append(args, filter.Offset)
		query += fmt.Sprintf(" OFFSET $%d", len(args))
	}

	rows, err := s.pool.Query(ctx, query, args...)
	if err != nil {
		return nil, 0, fmt.Errorf("querying general audit entries: %w", err)
	}
	defer rows.Close()

	var entries []*GeneralAuditEntry
	for rows.Next() {
		var (
			entry         GeneralAuditEntry
			oldValuesJSON []byte
			newValuesJSON []byte
		)
		if err := rows.Scan(
			&entry.ID, &entry.TenantID, &entry.NamespaceID, &entry.UserID,
			&entry.Action, &entry.ResourceType, &entry.ResourceID,
			&oldValuesJSON, &newValuesJSON, &entry.IPAddress, &entry.UserAgent,
			&entry.Reason, &entry.CreatedAt,
		); err != nil {
			return nil, 0, fmt.Errorf("scanning general audit entry: %w", err)
		}
		// A NULL column scans as an empty slice; only decode real payloads.
		if len(oldValuesJSON) > 0 {
			if err := json.Unmarshal(oldValuesJSON, &entry.OldValues); err != nil {
				return nil, 0, fmt.Errorf("decoding old values of audit entry %s: %w", entry.ID, err)
			}
		}
		if len(newValuesJSON) > 0 {
			if err := json.Unmarshal(newValuesJSON, &entry.NewValues); err != nil {
				return nil, 0, fmt.Errorf("decoding new values of audit entry %s: %w", entry.ID, err)
			}
		}
		entries = append(entries, &entry)
	}
	// Next returning false can mean "done" or "failed"; Err tells them apart.
	if err := rows.Err(); err != nil {
		return nil, 0, fmt.Errorf("iterating general audit entries: %w", err)
	}
	return entries, totalCount, nil
}
// GetLLMUsageStats retrieves aggregated LLM usage statistics for tenantID
// over the inclusive interval [startDate, endDate] on created_at.
//
// It runs two queries against compliance_llm_audit_log: one for the totals
// (request count, summed tokens and duration, requests with PII, requests
// with at least one policy violation) and one for the per-model request
// counts. The two queries are not run in a transaction, so rows inserted
// between them can make the breakdown and the totals differ slightly.
//
// It returns an error if either query fails, if a breakdown row cannot be
// scanned, or if row iteration ends with an error; no partial stats are
// returned in that case. ModelsUsed is always non-nil on success.
func (s *Store) GetLLMUsageStats(ctx context.Context, tenantID uuid.UUID, startDate, endDate time.Time) (*LLMUsageStats, error) {
	stats := LLMUsageStats{ModelsUsed: make(map[string]int)}

	err := s.pool.QueryRow(ctx, `
		SELECT
			COUNT(*) as total_requests,
			COALESCE(SUM(tokens_used), 0) as total_tokens,
			COALESCE(SUM(duration_ms), 0) as total_duration_ms,
			COUNT(*) FILTER (WHERE pii_detected = TRUE) as requests_with_pii,
			COUNT(*) FILTER (WHERE array_length(policy_violations, 1) > 0) as policy_violations
		FROM compliance_llm_audit_log
		WHERE tenant_id = $1 AND created_at >= $2 AND created_at <= $3
	`, tenantID, startDate, endDate).Scan(
		&stats.TotalRequests,
		&stats.TotalTokens,
		&stats.TotalDurationMS,
		&stats.RequestsWithPII,
		&stats.PolicyViolations,
	)
	if err != nil {
		return nil, fmt.Errorf("aggregating LLM usage totals: %w", err)
	}

	// Per-model request counts.
	rows, err := s.pool.Query(ctx, `
		SELECT model_used, COUNT(*) as count
		FROM compliance_llm_audit_log
		WHERE tenant_id = $1 AND created_at >= $2 AND created_at <= $3
		GROUP BY model_used
	`, tenantID, startDate, endDate)
	if err != nil {
		return nil, fmt.Errorf("querying LLM model usage: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		var (
			model string
			count int
		)
		// A skipped row would silently under-report usage, so fail instead.
		if err := rows.Scan(&model, &count); err != nil {
			return nil, fmt.Errorf("scanning LLM model usage row: %w", err)
		}
		stats.ModelsUsed[model] = count
	}
	// Next returning false can mean "done" or "failed"; Err tells them apart.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterating LLM model usage rows: %w", err)
	}
	return &stats, nil
}
// LLMUsageStats represents aggregated LLM usage statistics.
//
// It is produced by GetLLMUsageStats from compliance_llm_audit_log rows of
// one tenant within a created_at interval.
type LLMUsageStats struct {
// TotalRequests is the number of matching rows.
TotalRequests int `json:"total_requests"`
// TotalTokens is the sum of tokens_used (0 when there are no rows).
TotalTokens int `json:"total_tokens"`
// TotalDurationMS is the sum of duration_ms (0 when there are no rows).
TotalDurationMS int64 `json:"total_duration_ms"`
// RequestsWithPII counts rows with pii_detected = TRUE.
RequestsWithPII int `json:"requests_with_pii"`
// PolicyViolations counts rows that have at least one policy violation
// (not the total number of individual violations).
PolicyViolations int `json:"policy_violations"`
// ModelsUsed maps each model_used value to its number of rows.
ModelsUsed map[string]int `json:"models_used"`
}
// CleanupOldEntries removes audit entries older than the retention period.
//
// Rows whose created_at is earlier than (now UTC - retentionDays days) are
// deleted, first from compliance_llm_audit_log and then from
// compliance_audit_trail. The two deletes are separate statements, not one
// transaction.
//
// retentionDays must not be negative: a negative value would move the cutoff
// into the future and delete every audit row, so it is rejected before any
// statement runs. A value of 0 is accepted and deletes everything created
// before the moment of the call.
//
// Returns the number of LLM audit rows deleted, the number of general audit
// rows deleted, and an error. If the second delete fails, the first count is
// still returned alongside the error because those rows are already gone.
func (s *Store) CleanupOldEntries(ctx context.Context, retentionDays int) (int, int, error) {
	if retentionDays < 0 {
		return 0, 0, fmt.Errorf("cleaning up audit entries: negative retention of %d days", retentionDays)
	}
	cutoff := time.Now().UTC().AddDate(0, 0, -retentionDays)

	// Cleanup LLM audit log.
	llmResult, err := s.pool.Exec(ctx, `
		DELETE FROM compliance_llm_audit_log WHERE created_at < $1
	`, cutoff)
	if err != nil {
		return 0, 0, fmt.Errorf("deleting old LLM audit entries: %w", err)
	}
	llmDeleted := int(llmResult.RowsAffected())

	// Cleanup general audit trail.
	generalResult, err := s.pool.Exec(ctx, `
		DELETE FROM compliance_audit_trail WHERE created_at < $1
	`, cutoff)
	if err != nil {
		return llmDeleted, 0, fmt.Errorf("deleting old general audit entries: %w", err)
	}
	return llmDeleted, int(generalResult.RowsAffected()), nil
}