package policy import ( "context" "fmt" "time" "github.com/google/uuid" "github.com/jackc/pgx/v5" ) // ============================================================================= // ALLOWED SOURCES // ============================================================================= // CreateSource creates a new allowed source. func (s *Store) CreateSource(ctx context.Context, req *CreateAllowedSourceRequest) (*AllowedSource, error) { trustBoost := 0.5 if req.TrustBoost != nil { trustBoost = *req.TrustBoost } source := &AllowedSource{ ID: uuid.New(), PolicyID: req.PolicyID, Domain: req.Domain, Name: req.Name, Description: req.Description, License: req.License, LegalBasis: req.LegalBasis, CitationTemplate: req.CitationTemplate, TrustBoost: trustBoost, IsActive: true, CreatedAt: time.Now(), UpdatedAt: time.Now(), } query := ` INSERT INTO allowed_sources (id, policy_id, domain, name, description, license, legal_basis, citation_template, trust_boost, is_active, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) RETURNING id` err := s.pool.QueryRow(ctx, query, source.ID, source.PolicyID, source.Domain, source.Name, source.Description, source.License, source.LegalBasis, source.CitationTemplate, source.TrustBoost, source.IsActive, source.CreatedAt, source.UpdatedAt, ).Scan(&source.ID) if err != nil { return nil, fmt.Errorf("failed to create source: %w", err) } // Create default operation permissions err = s.createDefaultOperations(ctx, source.ID) if err != nil { return nil, fmt.Errorf("failed to create default operations: %w", err) } return source, nil } // createDefaultOperations creates default operation permissions for a source. 
func (s *Store) createDefaultOperations(ctx context.Context, sourceID uuid.UUID) error { defaults := []struct { op Operation allowed bool citation bool }{ {OperationLookup, true, true}, {OperationRAG, true, true}, {OperationTraining, false, false}, // VERBOTEN by default {OperationExport, true, true}, } for _, d := range defaults { query := ` INSERT INTO operation_permissions (id, source_id, operation, is_allowed, requires_citation, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7)` _, err := s.pool.Exec(ctx, query, uuid.New(), sourceID, d.op, d.allowed, d.citation, time.Now(), time.Now()) if err != nil { return err } } return nil } // GetSource retrieves a source by ID. func (s *Store) GetSource(ctx context.Context, id uuid.UUID) (*AllowedSource, error) { query := ` SELECT als.id, als.policy_id, als.domain, als.name, als.description, als.license, als.legal_basis, als.citation_template, als.trust_boost, als.is_active, als.created_at, als.updated_at, sp.name as policy_name FROM allowed_sources als JOIN source_policies sp ON als.policy_id = sp.id WHERE als.id = $1` source := &AllowedSource{} err := s.pool.QueryRow(ctx, query, id).Scan( &source.ID, &source.PolicyID, &source.Domain, &source.Name, &source.Description, &source.License, &source.LegalBasis, &source.CitationTemplate, &source.TrustBoost, &source.IsActive, &source.CreatedAt, &source.UpdatedAt, &source.PolicyName, ) if err == pgx.ErrNoRows { return nil, nil } if err != nil { return nil, fmt.Errorf("failed to get source: %w", err) } // Load operations ops, err := s.GetOperationsBySourceID(ctx, source.ID) if err != nil { return nil, err } source.Operations = ops return source, nil } // GetSourceByDomain retrieves a source by domain with optional bundesland filter. 
//
// Only active sources belonging to active policies are considered. A source
// matches when its domain equals the given domain or is a parent of it
// ("example.de" matches "www.example.de"). Policies without a bundesland
// always qualify. Policies with a bundesland qualify only when it equals
// *bundesland, so a nil bundesland restricts the lookup to policies without
// one. When several sources match, the most specific one is returned (see the
// ORDER BY below). It returns (nil, nil) when nothing matches. The result
// includes the policy name and the operation permissions.
func (s *Store) GetSourceByDomain(ctx context.Context, domain string, bundesland *Bundesland) (*AllowedSource, error) {
	// Without ORDER BY, LIMIT 1 would pick an arbitrary row whenever more than one
	// source matches. Prefer the longest (most specific) domain, which makes an
	// exact match win. Then prefer a bundesland-specific policy over a global one
	// (false sorts before true). Finally order by id as a stable tiebreaker.
	query := `
		SELECT als.id, als.policy_id, als.domain, als.name, als.description, als.license,
		       als.legal_basis, als.citation_template, als.trust_boost, als.is_active,
		       als.created_at, als.updated_at, sp.name AS policy_name
		FROM allowed_sources als
		JOIN source_policies sp ON als.policy_id = sp.id
		WHERE als.is_active = true
		  AND sp.is_active = true
		  AND (als.domain = $1 OR $1 LIKE ('%.' || als.domain))
		  AND (sp.bundesland IS NULL OR sp.bundesland = $2)
		ORDER BY length(als.domain) DESC, (sp.bundesland IS NULL), als.id
		LIMIT 1`

	source := &AllowedSource{}
	err := s.pool.QueryRow(ctx, query, domain, bundesland).Scan(
		&source.ID, &source.PolicyID, &source.Domain, &source.Name,
		&source.Description, &source.License, &source.LegalBasis,
		&source.CitationTemplate, &source.TrustBoost, &source.IsActive,
		&source.CreatedAt, &source.UpdatedAt, &source.PolicyName,
	)
	if err == pgx.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get source by domain: %w", err)
	}

	ops, err := s.GetOperationsBySourceID(ctx, source.ID)
	if err != nil {
		return nil, err
	}
	source.Operations = ops

	return source, nil
}

// ListSources retrieves sources with optional filters.
func (s *Store) ListSources(ctx context.Context, filter *SourceListFilter) ([]AllowedSource, int, error) {
	// A nil filter means "no filters, no pagination" instead of a nil dereference.
	if filter == nil {
		filter = &SourceListFilter{}
	}

	baseQuery := `FROM allowed_sources als JOIN source_policies sp ON als.policy_id = sp.id WHERE 1=1`
	args := []interface{}{}
	argCount := 0

	if filter.PolicyID != nil {
		argCount++
		baseQuery += fmt.Sprintf(" AND als.policy_id = $%d", argCount)
		args = append(args, *filter.PolicyID)
	}
	if filter.Domain != nil {
		// Substring match, case-insensitive.
		argCount++
		baseQuery += fmt.Sprintf(" AND als.domain ILIKE $%d", argCount)
		args = append(args, "%"+*filter.Domain+"%")
	}
	if filter.License != nil {
		argCount++
		baseQuery += fmt.Sprintf(" AND als.license = $%d", argCount)
		args = append(args, *filter.License)
	}
	if filter.IsActive != nil {
		argCount++
		baseQuery += fmt.Sprintf(" AND als.is_active = $%d", argCount)
		args = append(args, *filter.IsActive)
	}

	// The total ignores Limit/Offset so callers can paginate.
	var total int
	countQuery := "SELECT COUNT(*) " + baseQuery
	if err := s.pool.QueryRow(ctx, countQuery, args...).Scan(&total); err != nil {
		return nil, 0, fmt.Errorf("failed to count sources: %w", err)
	}

	// als.id breaks ties between equal created_at values so that pages do not
	// overlap or skip rows.
	dataQuery := `SELECT als.id, als.policy_id, als.domain, als.name, als.description, als.license,
		als.legal_basis, als.citation_template, als.trust_boost, als.is_active,
		als.created_at, als.updated_at, sp.name as policy_name ` +
		baseQuery + ` ORDER BY als.created_at DESC, als.id`

	if filter.Limit > 0 {
		argCount++
		dataQuery += fmt.Sprintf(" LIMIT $%d", argCount)
		args = append(args, filter.Limit)
	}
	if filter.Offset > 0 {
		argCount++
		dataQuery += fmt.Sprintf(" OFFSET $%d", argCount)
		args = append(args, filter.Offset)
	}

	rows, err := s.pool.Query(ctx, dataQuery, args...)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to list sources: %w", err)
	}
	defer rows.Close()

	// Non-nil so an empty result is distinguishable from an error (e.g. "[]" in JSON).
	sources := []AllowedSource{}
	for rows.Next() {
		var src AllowedSource
		err := rows.Scan(
			&src.ID, &src.PolicyID, &src.Domain, &src.Name,
			&src.Description, &src.License, &src.LegalBasis,
			&src.CitationTemplate, &src.TrustBoost, &src.IsActive,
			&src.CreatedAt, &src.UpdatedAt, &src.PolicyName,
		)
		if err != nil {
			return nil, 0, fmt.Errorf("failed to scan source: %w", err)
		}
		sources = append(sources, src)
	}
	// rows.Next returns false both at the end and on failure. Without this check
	// a mid-stream error would be reported as a short, successful page.
	if err := rows.Err(); err != nil {
		return nil, 0, fmt.Errorf("failed to list sources: %w", err)
	}

	return sources, total, nil
}

// UpdateSource applies every non-nil field of req to the source with the given
// ID and returns the updated source. The returned source includes the policy
// name and the operation permissions, which are loaded before the update.
// It returns an error if the source does not exist.
//
// NOTE(review): this is a read-modify-write without locking, so concurrent
// updates to the same source may overwrite each other's fields.
func (s *Store) UpdateSource(ctx context.Context, id uuid.UUID, req *UpdateAllowedSourceRequest) (*AllowedSource, error) {
	source, err := s.GetSource(ctx, id)
	if err != nil {
		return nil, err
	}
	if source == nil {
		return nil, fmt.Errorf("source not found")
	}

	if req.Domain != nil {
		source.Domain = *req.Domain
	}
	if req.Name != nil {
		source.Name = *req.Name
	}
	if req.Description != nil {
		source.Description = req.Description
	}
	if req.License != nil {
		source.License = *req.License
	}
	if req.LegalBasis != nil {
		source.LegalBasis = req.LegalBasis
	}
	if req.CitationTemplate != nil {
		source.CitationTemplate = req.CitationTemplate
	}
	if req.TrustBoost != nil {
		source.TrustBoost = *req.TrustBoost
	}
	if req.IsActive != nil {
		source.IsActive = *req.IsActive
	}
	source.UpdatedAt = time.Now()

	query := `
		UPDATE allowed_sources
		SET domain = $2, name = $3, description = $4, license = $5,
		    legal_basis = $6, citation_template = $7, trust_boost = $8,
		    is_active = $9, updated_at = $10
		WHERE id = $1`

	tag, err := s.pool.Exec(ctx, query,
		id, source.Domain, source.Name, source.Description, source.License,
		source.LegalBasis, source.CitationTemplate, source.TrustBoost,
		source.IsActive, source.UpdatedAt,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to update source: %w", err)
	}
	// The row can be deleted between the read above and this update. Do not
	// report a write that never happened as a success.
	if tag.RowsAffected() == 0 {
		return nil, fmt.Errorf("source not found")
	}

	return source, nil
}

// DeleteSource deletes a source by ID.
func (s *Store) DeleteSource(ctx context.Context, id uuid.UUID) error { query := `DELETE FROM allowed_sources WHERE id = $1` _, err := s.pool.Exec(ctx, query, id) if err != nil { return fmt.Errorf("failed to delete source: %w", err) } return nil } // ============================================================================= // OPERATION PERMISSIONS // ============================================================================= // GetOperationsBySourceID retrieves all operation permissions for a source. func (s *Store) GetOperationsBySourceID(ctx context.Context, sourceID uuid.UUID) ([]OperationPermission, error) { query := ` SELECT id, source_id, operation, is_allowed, requires_citation, notes, created_at, updated_at FROM operation_permissions WHERE source_id = $1 ORDER BY operation` rows, err := s.pool.Query(ctx, query, sourceID) if err != nil { return nil, fmt.Errorf("failed to get operations: %w", err) } defer rows.Close() ops := []OperationPermission{} for rows.Next() { var op OperationPermission err := rows.Scan( &op.ID, &op.SourceID, &op.Operation, &op.IsAllowed, &op.RequiresCitation, &op.Notes, &op.CreatedAt, &op.UpdatedAt, ) if err != nil { return nil, fmt.Errorf("failed to scan operation: %w", err) } ops = append(ops, op) } return ops, nil } // UpdateOperationPermission updates an operation permission. 
func (s *Store) UpdateOperationPermission(ctx context.Context, id uuid.UUID, req *UpdateOperationPermissionRequest) (*OperationPermission, error) {
	// The partial update runs in a single statement. Each COALESCE keeps the
	// stored value when the corresponding request field is nil (pgx sends a nil
	// pointer as NULL). RETURNING yields the resulting row. This avoids a separate
	// read, and a concurrent update to another field cannot be overwritten with a
	// stale value.
	query := `
		UPDATE operation_permissions
		SET is_allowed        = COALESCE($2, is_allowed),
		    requires_citation = COALESCE($3, requires_citation),
		    notes             = COALESCE($4, notes),
		    updated_at        = $5
		WHERE id = $1
		RETURNING id, source_id, operation, is_allowed, requires_citation, notes, created_at, updated_at`

	op := &OperationPermission{}
	err := s.pool.QueryRow(ctx, query,
		id, req.IsAllowed, req.RequiresCitation, req.Notes, time.Now(),
	).Scan(
		&op.ID, &op.SourceID, &op.Operation, &op.IsAllowed,
		&op.RequiresCitation, &op.Notes, &op.CreatedAt, &op.UpdatedAt,
	)
	if err == pgx.ErrNoRows {
		return nil, fmt.Errorf("operation permission not found")
	}
	if err != nil {
		return nil, fmt.Errorf("failed to update operation: %w", err)
	}

	return op, nil
}

// GetOperationsMatrix retrieves all operation permissions grouped by source.
//
// Only active sources of active policies are included. They are ordered by
// bundesland (policies without one first) and then by source name. Each
// returned source carries only ID, Domain, Name, License, IsActive,
// PolicyName and Operations.
func (s *Store) GetOperationsMatrix(ctx context.Context) ([]AllowedSource, error) {
	query := `
		SELECT als.id, als.domain, als.name, als.license, als.is_active,
		       sp.name as policy_name, sp.bundesland
		FROM allowed_sources als
		JOIN source_policies sp ON als.policy_id = sp.id
		WHERE als.is_active = true AND sp.is_active = true
		ORDER BY sp.bundesland NULLS FIRST, als.name`

	rows, err := s.pool.Query(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("failed to get operations matrix: %w", err)
	}
	defer rows.Close()

	sources := []AllowedSource{}
	for rows.Next() {
		var src AllowedSource
		// NOTE(review): sp.bundesland is scanned but not stored on the source;
		// presumably AllowedSource has no field for it. Confirm.
		var bundesland *Bundesland
		err := rows.Scan(
			&src.ID, &src.Domain, &src.Name, &src.License,
			&src.IsActive, &src.PolicyName, &bundesland,
		)
		if err != nil {
			return nil, fmt.Errorf("failed to scan source: %w", err)
		}
		sources = append(sources, src)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to get operations matrix: %w", err)
	}
	// Release the connection held by rows before running the per-source queries.
	// Querying inside the rows loop needed a second connection per iteration and
	// could exhaust or deadlock the pool. Close is idempotent, so the deferred
	// call is harmless.
	rows.Close()

	// NOTE(review): this is still one query per source; a single
	// `WHERE source_id = ANY($1)` query would remove the N+1 pattern.
	for i := range sources {
		ops, err := s.GetOperationsBySourceID(ctx, sources[i].ID)
		if err != nil {
			return nil, err
		}
		sources[i].Operations = ops
	}

	return sources, nil
}