feat: edu-search-service migriert, voice-service/geo-service entfernt
All checks were successful
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 28s
CI / test-go-edu-search (push) Successful in 27s
CI / test-python-klausur (push) Successful in 1m45s
CI / test-python-agent-core (push) Successful in 16s
CI / test-nodejs-website (push) Successful in 21s

- edu-search-service von breakpilot-pwa nach breakpilot-lehrer kopiert (ohne vendor)
- opensearch + edu-search-service in docker-compose.yml hinzugefuegt
- voice-service aus docker-compose.yml entfernt (jetzt in breakpilot-core)
- geo-service aus docker-compose.yml entfernt (nicht mehr benoetigt)
- CI/CD: edu-search-service zu Gitea Actions und Woodpecker hinzugefuegt
  (Go lint, test mit go mod download, build, SBOM)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Boenisch
2026-02-15 18:36:38 +01:00
parent d4e1d6bab6
commit 414e0f5ec0
73 changed files with 23938 additions and 92 deletions

View File

@@ -0,0 +1,406 @@
package handlers
import (
"encoding/json"
"net/http"
"os"
"path/filepath"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// SeedURL represents a single crawl seed: a trusted start URL plus
// metadata that influences ranking (TrustBoost) and is exposed via the
// admin CRUD API.
type SeedURL struct {
	ID          string  `json:"id"`       // UUID string, assigned by the store on create
	URL         string  `json:"url"`      // start URL for the crawler
	Category    string  `json:"category"` // e.g. "federal", "states", "science", "portals" (see defaults)
	Name        string  `json:"name"`
	Description string  `json:"description"`
	TrustBoost  float64 `json:"trustBoost"` // ranking boost; defaults range 0.20–0.50
	Enabled     bool    `json:"enabled"`    // disabled seeds are excluded from TotalSeeds in GetStats
	// LastCrawled is nil until the seed has been crawled at least once.
	LastCrawled   *string   `json:"lastCrawled,omitempty"`
	DocumentCount int       `json:"documentCount,omitempty"`
	CreatedAt     time.Time `json:"createdAt"`
	UpdatedAt     time.Time `json:"updatedAt"`
}
// CrawlStats contains crawl statistics as returned by GET /admin/stats.
type CrawlStats struct {
	TotalDocuments int     `json:"totalDocuments"`
	TotalSeeds     int     `json:"totalSeeds"`               // counts enabled seeds only (see GetStats)
	LastCrawlTime  *string `json:"lastCrawlTime,omitempty"`  // RFC3339; nil until a crawl has completed
	CrawlStatus    string  `json:"crawlStatus"`              // "idle" or "running"
	DocumentsPerCategory map[string]int `json:"documentsPerCategory"`
	DocumentsPerDocType  map[string]int `json:"documentsPerDocType"`
	AvgTrustScore        float64        `json:"avgTrustScore"`
}
// SeedStore manages seed URLs in memory with JSON-file persistence.
// All access must go through its methods, which guard the map with mu.
type SeedStore struct {
	seeds    map[string]SeedURL // keyed by SeedURL.ID
	mu       sync.RWMutex       // guards seeds
	filePath string             // seeds.json location, set once in InitSeedStore
}
// seedStore is the process-wide seed store, set by InitSeedStore and
// nil-checked by every handler.
var seedStore *SeedStore

// crawlStatus and lastCrawlTime track the most recent crawl run.
// NOTE(review): both are read by GetStats and written by the goroutine
// spawned in StartCrawl without any synchronization — this is a data
// race under `go test -race`; guard them with a mutex or atomics.
var crawlStatus = "idle"
var lastCrawlTime *string
// InitSeedStore initializes the seed store
func InitSeedStore(seedsDir string) error {
seedStore = &SeedStore{
seeds: make(map[string]SeedURL),
filePath: filepath.Join(seedsDir, "seeds.json"),
}
// Try to load existing seeds from JSON file
if err := seedStore.loadFromFile(); err != nil {
// If file doesn't exist, load from txt files
return seedStore.loadFromTxtFiles(seedsDir)
}
return nil
}
// loadFromFile reads the persisted seeds.json and merges its entries
// into the in-memory map, keyed by seed ID.
func (s *SeedStore) loadFromFile() error {
	raw, err := os.ReadFile(s.filePath)
	if err != nil {
		return err
	}
	var stored []SeedURL
	if err = json.Unmarshal(raw, &stored); err != nil {
		return err
	}
	s.mu.Lock()
	for _, entry := range stored {
		s.seeds[entry.ID] = entry
	}
	s.mu.Unlock()
	return nil
}
// loadFromTxtFiles populates the store with the built-in default seed
// list and persists it to seeds.json.
//
// NOTE(review): despite the name (and the seedsDir parameter, which is
// unused), this does not read any .txt files — the defaults are
// hard-coded below. Consider renaming to loadDefaults or actually
// reading category files from seedsDir.
func (s *SeedStore) loadFromTxtFiles(seedsDir string) error {
	// Default seeds from category files
	defaultSeeds := []SeedURL{
		// Federal-level institutions (highest trust).
		{ID: uuid.New().String(), URL: "https://www.kmk.org", Category: "federal", Name: "Kultusministerkonferenz", Description: "Beschlüsse und Bildungsstandards", TrustBoost: 0.50, Enabled: true},
		{ID: uuid.New().String(), URL: "https://www.bildungsserver.de", Category: "federal", Name: "Deutscher Bildungsserver", Description: "Zentrale Bildungsinformationen", TrustBoost: 0.50, Enabled: true},
		{ID: uuid.New().String(), URL: "https://www.bpb.de", Category: "federal", Name: "Bundeszentrale politische Bildung", Description: "Politische Bildung", TrustBoost: 0.45, Enabled: true},
		{ID: uuid.New().String(), URL: "https://www.bmbf.de", Category: "federal", Name: "BMBF", Description: "Bundesbildungsministerium", TrustBoost: 0.50, Enabled: true},
		{ID: uuid.New().String(), URL: "https://www.iqb.hu-berlin.de", Category: "federal", Name: "IQB", Description: "Institut Qualitätsentwicklung", TrustBoost: 0.50, Enabled: true},
		// Science
		{ID: uuid.New().String(), URL: "https://www.bertelsmann-stiftung.de/de/themen/bildung", Category: "science", Name: "Bertelsmann Stiftung", Description: "Bildungsstudien und Ländermonitor", TrustBoost: 0.40, Enabled: true},
		{ID: uuid.New().String(), URL: "https://www.oecd.org/pisa", Category: "science", Name: "PISA-Studien", Description: "Internationale Schulleistungsstudie", TrustBoost: 0.45, Enabled: true},
		{ID: uuid.New().String(), URL: "https://www.iea.nl/studies/iea/pirls", Category: "science", Name: "IGLU/PIRLS", Description: "Internationale Grundschul-Lese-Untersuchung", TrustBoost: 0.45, Enabled: true},
		{ID: uuid.New().String(), URL: "https://www.iea.nl/studies/iea/timss", Category: "science", Name: "TIMSS", Description: "Trends in International Mathematics and Science Study", TrustBoost: 0.45, Enabled: true},
		// Bundesländer
		{ID: uuid.New().String(), URL: "https://www.km.bayern.de", Category: "states", Name: "Bayern Kultusministerium", Description: "Lehrpläne Bayern", TrustBoost: 0.45, Enabled: true},
		{ID: uuid.New().String(), URL: "https://www.schulministerium.nrw", Category: "states", Name: "NRW Schulministerium", Description: "Lehrpläne NRW", TrustBoost: 0.45, Enabled: true},
		{ID: uuid.New().String(), URL: "https://www.berlin.de/sen/bildung", Category: "states", Name: "Berlin Bildung", Description: "Rahmenlehrpläne Berlin", TrustBoost: 0.45, Enabled: true},
		{ID: uuid.New().String(), URL: "https://kultusministerium.hessen.de", Category: "states", Name: "Hessen Kultusministerium", Description: "Kerncurricula Hessen", TrustBoost: 0.45, Enabled: true},
		// Portale (community portals: lower trust)
		{ID: uuid.New().String(), URL: "https://www.lehrer-online.de", Category: "portals", Name: "Lehrer-Online", Description: "Unterrichtsmaterialien", TrustBoost: 0.20, Enabled: true},
		{ID: uuid.New().String(), URL: "https://www.4teachers.de", Category: "portals", Name: "4teachers", Description: "Lehrercommunity", TrustBoost: 0.20, Enabled: true},
		{ID: uuid.New().String(), URL: "https://www.zum.de", Category: "portals", Name: "ZUM", Description: "Zentrale für Unterrichtsmedien", TrustBoost: 0.25, Enabled: true},
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	// Stamp all defaults with the same creation time and persist once.
	now := time.Now()
	for _, seed := range defaultSeeds {
		seed.CreatedAt = now
		seed.UpdatedAt = now
		s.seeds[seed.ID] = seed
	}
	return s.saveToFile()
}
// saveToFile serializes all seeds to s.filePath as indented JSON.
// It does NOT lock s.mu itself: every current caller already holds the
// write lock, and this must stay that way (sync.RWMutex is not
// reentrant).
// NOTE(review): the slice order comes from map iteration, so the
// file's entry order changes between saves; sort by ID if stable
// diffs are wanted.
func (s *SeedStore) saveToFile() error {
	seeds := make([]SeedURL, 0, len(s.seeds))
	for _, seed := range s.seeds {
		seeds = append(seeds, seed)
	}
	data, err := json.MarshalIndent(seeds, "", " ")
	if err != nil {
		return err
	}
	return os.WriteFile(s.filePath, data, 0644)
}
// GetAllSeeds returns a snapshot slice of every stored seed.
// Ordering is unspecified (map iteration order).
func (s *SeedStore) GetAllSeeds() []SeedURL {
	s.mu.RLock()
	defer s.mu.RUnlock()
	out := make([]SeedURL, 0, len(s.seeds))
	for _, v := range s.seeds {
		out = append(out, v)
	}
	return out
}
// GetSeed looks up a single seed by ID; the second result reports
// whether it exists.
func (s *SeedStore) GetSeed(id string) (SeedURL, bool) {
	s.mu.RLock()
	entry, found := s.seeds[id]
	s.mu.RUnlock()
	return entry, found
}
// CreateSeed stores a new seed under a freshly generated UUID and
// persists the store. On a save failure the in-memory entry is removed
// again so memory and file stay consistent, and the error is returned.
func (s *SeedStore) CreateSeed(seed SeedURL) (SeedURL, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	seed.ID = uuid.New().String()
	seed.CreatedAt = time.Now()
	seed.UpdatedAt = time.Now()
	s.seeds[seed.ID] = seed
	err := s.saveToFile()
	if err != nil {
		// Roll back the insert so the map matches the persisted file.
		delete(s.seeds, seed.ID)
		return SeedURL{}, err
	}
	return seed, nil
}
// UpdateSeed applies non-empty string fields from updates to the seed
// with the given id and persists the store.
//
// Returns the updated seed, whether the id was found, and any
// persistence error. On a save failure the in-memory entry is rolled
// back to its previous value so memory and file stay consistent
// (CreateSeed already did this; previously UpdateSeed did not, leaving
// the map and seeds.json divergent after a failed write).
//
// Note: TrustBoost and Enabled are always overwritten — including with
// zero values — so callers must send the full desired state.
func (s *SeedStore) UpdateSeed(id string, updates SeedURL) (SeedURL, bool, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	seed, ok := s.seeds[id]
	if !ok {
		return SeedURL{}, false, nil
	}
	prev := seed // keep a copy for rollback on save failure
	// Only overwrite string fields that were actually provided.
	if updates.URL != "" {
		seed.URL = updates.URL
	}
	if updates.Name != "" {
		seed.Name = updates.Name
	}
	if updates.Category != "" {
		seed.Category = updates.Category
	}
	if updates.Description != "" {
		seed.Description = updates.Description
	}
	seed.TrustBoost = updates.TrustBoost
	seed.Enabled = updates.Enabled
	seed.UpdatedAt = time.Now()
	s.seeds[id] = seed
	if err := s.saveToFile(); err != nil {
		// Restore the previous state so the map matches the file.
		s.seeds[id] = prev
		return SeedURL{}, true, err
	}
	return seed, true, nil
}
// DeleteSeed removes a seed by id and persists the change, reporting
// whether the seed was deleted.
//
// Previously the saveToFile error was silently discarded, so a failed
// write left the seed gone from memory but still present on disk (and
// resurrected on restart). Now a failed save restores the in-memory
// entry and returns false. The bool-only signature cannot distinguish
// "not found" from "save failed"; it is kept for backward
// compatibility with existing callers — consider returning an error in
// a follow-up.
func (s *SeedStore) DeleteSeed(id string) bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	seed, ok := s.seeds[id]
	if !ok {
		return false
	}
	delete(s.seeds, id)
	if err := s.saveToFile(); err != nil {
		// Keep memory and file consistent by undoing the delete.
		s.seeds[id] = seed
		return false
	}
	return true
}
// Admin Handlers
// GetSeeds returns all seed URLs as a JSON array.
func (h *Handler) GetSeeds(c *gin.Context) {
	if seedStore == nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Seed store not initialized"})
		return
	}
	c.JSON(http.StatusOK, seedStore.GetAllSeeds())
}
// CreateSeed validates the request body and adds a new seed URL.
// Responds 201 with the created seed, 400 on invalid/missing input,
// 500 if the store is uninitialized or persistence fails.
func (h *Handler) CreateSeed(c *gin.Context) {
	if seedStore == nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Seed store not initialized"})
		return
	}
	var payload SeedURL
	if err := c.ShouldBindJSON(&payload); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
		return
	}
	if payload.URL == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "URL is required"})
		return
	}
	created, err := seedStore.CreateSeed(payload)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create seed", "details": err.Error()})
		return
	}
	c.JSON(http.StatusCreated, created)
}
// UpdateSeed updates an existing seed URL identified by the :id path
// parameter. Responds 200 with the updated seed, 400 on bad input,
// 404 when the id is unknown, 500 on persistence errors.
func (h *Handler) UpdateSeed(c *gin.Context) {
	if seedStore == nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Seed store not initialized"})
		return
	}
	seedID := c.Param("id")
	if seedID == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Seed ID required"})
		return
	}
	var payload SeedURL
	if err := c.ShouldBindJSON(&payload); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
		return
	}
	updated, found, err := seedStore.UpdateSeed(seedID, payload)
	switch {
	case err != nil:
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update seed", "details": err.Error()})
	case !found:
		c.JSON(http.StatusNotFound, gin.H{"error": "Seed not found"})
	default:
		c.JSON(http.StatusOK, updated)
	}
}
// DeleteSeed removes the seed URL identified by the :id path
// parameter. Responds 200 on success, 404 when it cannot be deleted.
func (h *Handler) DeleteSeed(c *gin.Context) {
	if seedStore == nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Seed store not initialized"})
		return
	}
	seedID := c.Param("id")
	if seedID == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Seed ID required"})
		return
	}
	if deleted := seedStore.DeleteSeed(seedID); !deleted {
		c.JSON(http.StatusNotFound, gin.H{"error": "Seed not found"})
		return
	}
	c.JSON(http.StatusOK, gin.H{"deleted": true, "id": seedID})
}
// GetStats returns crawl statistics.
//
// Document counts are currently placeholders (all zero) until the
// OpenSearch integration is wired up; only the enabled-seed count and
// the crawl status fields are live.
//
// NOTE(review): crawlStatus and lastCrawlTime are written by the
// goroutine started in StartCrawl without synchronization — reading
// them here is a data race; guard with a mutex or atomics.
func (h *Handler) GetStats(c *gin.Context) {
	// Get document count from OpenSearch
	totalDocs := 0
	// TODO: Get real count from OpenSearch
	seeds := []SeedURL{}
	if seedStore != nil {
		seeds = seedStore.GetAllSeeds()
	}
	// Only enabled seeds count toward TotalSeeds.
	enabledSeeds := 0
	for _, seed := range seeds {
		if seed.Enabled {
			enabledSeeds++
		}
	}
	stats := CrawlStats{
		TotalDocuments: totalDocs,
		TotalSeeds:     enabledSeeds,
		LastCrawlTime:  lastCrawlTime,
		CrawlStatus:    crawlStatus,
		// Placeholder breakdowns; keys mirror the seed categories and
		// document types the crawler will eventually report.
		DocumentsPerCategory: map[string]int{
			"federal":      0,
			"states":       0,
			"science":      0,
			"universities": 0,
			"portals":      0,
		},
		DocumentsPerDocType: map[string]int{
			"Lehrplan":           0,
			"Arbeitsblatt":       0,
			"Unterrichtsentwurf": 0,
			"Erlass_Verordnung":  0,
			"Pruefung_Abitur":    0,
			"Studie_Bericht":     0,
			"Sonstiges":          0,
		},
		AvgTrustScore: 0.0,
	}
	c.JSON(http.StatusOK, stats)
}
// StartCrawl initiates a crawl run and responds 202 immediately;
// 409 if a crawl is already marked as running.
//
// NOTE(review): the check-then-set on crawlStatus below is not
// synchronized, so two concurrent requests can both pass the check;
// the background goroutine also writes crawlStatus/lastCrawlTime
// without a lock (data race with GetStats). Guard these with a mutex
// when the real crawl is implemented.
func (h *Handler) StartCrawl(c *gin.Context) {
	if crawlStatus == "running" {
		c.JSON(http.StatusConflict, gin.H{"error": "Crawl already running"})
		return
	}
	crawlStatus = "running"
	// TODO: Start actual crawl in background goroutine
	go func() {
		time.Sleep(5 * time.Second) // Simulate crawl
		now := time.Now().Format(time.RFC3339)
		lastCrawlTime = &now
		crawlStatus = "idle"
	}()
	c.JSON(http.StatusAccepted, gin.H{
		"status":  "started",
		"message": "Crawl initiated",
	})
}
// SetupAdminRoutes mounts the admin API under /admin on the given
// router group: seed CRUD, crawl stats, and crawl control.
func SetupAdminRoutes(r *gin.RouterGroup, h *Handler) {
	admin := r.Group("/admin")
	// Seed CRUD endpoints.
	admin.GET("/seeds", h.GetSeeds)
	admin.POST("/seeds", h.CreateSeed)
	admin.PUT("/seeds/:id", h.UpdateSeed)
	admin.DELETE("/seeds/:id", h.DeleteSeed)
	// Crawl statistics.
	admin.GET("/stats", h.GetStats)
	// Crawl control.
	admin.POST("/crawl/start", h.StartCrawl)
}

View File

@@ -0,0 +1,554 @@
package handlers
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/breakpilot/edu-search-service/internal/database"
)
// AIExtractionHandlers handles AI-based profile extraction endpoints.
// These endpoints are designed for vast.ai or similar AI services to:
// 1. Get profile URLs that need extraction (GetPendingProfiles,
//    GetInstitutePages)
// 2. Submit extracted data back (SubmitExtractedData,
//    SubmitBatchExtractedData, SubmitInstituteHierarchy)
type AIExtractionHandlers struct {
	repo *database.Repository // backing store for staff/department records
}
// NewAIExtractionHandlers creates new AI extraction handlers backed by
// the given repository.
func NewAIExtractionHandlers(repo *database.Repository) *AIExtractionHandlers {
	handlers := &AIExtractionHandlers{repo: repo}
	return handlers
}
// ProfileExtractionTask represents a profile URL to be processed by an
// AI extractor. CurrentData carries the values already on record so
// the extractor can skip fields that are known.
type ProfileExtractionTask struct {
	StaffID       uuid.UUID `json:"staff_id"`
	ProfileURL    string    `json:"profile_url"`
	UniversityID  uuid.UUID `json:"university_id"`
	UniversityURL string    `json:"university_url,omitempty"` // NOTE(review): never populated by GetPendingProfiles — confirm if still needed
	FullName      string    `json:"full_name,omitempty"`
	// CurrentData mirrors what the database already holds for this person.
	CurrentData struct {
		Email      string `json:"email,omitempty"`
		Phone      string `json:"phone,omitempty"`
		Office     string `json:"office,omitempty"`
		Position   string `json:"position,omitempty"`
		Department string `json:"department,omitempty"`
	} `json:"current_data"`
}
// GetPendingProfiles returns staff profiles that need AI extraction.
// GET /api/v1/ai/extraction/pending?limit=10&university_id=...
//
// A profile "needs extraction" when it has a profile URL but no email
// on record. Responds {"tasks": [...], "total": n} with at most
// `limit` tasks (default 10, capped at 100). An unparsable
// university_id is silently ignored rather than rejected.
func (h *AIExtractionHandlers) GetPendingProfiles(c *gin.Context) {
	// parseIntDefault is a package-level helper defined elsewhere in
	// this package.
	limit := parseIntDefault(c.Query("limit"), 10)
	if limit > 100 {
		limit = 100
	}
	var universityID *uuid.UUID
	if uniIDStr := c.Query("university_id"); uniIDStr != "" {
		id, err := uuid.Parse(uniIDStr)
		if err == nil {
			universityID = &id
		}
	}
	// Get staff that have profile URLs but missing key data.
	// Fetch 2x the requested amount because rows without a profile URL
	// or with an email already set are filtered out below.
	// NOTE(review): 2x is a heuristic — when most staff already have
	// emails this can return fewer than `limit` pending tasks even
	// though more exist.
	params := database.StaffSearchParams{
		UniversityID: universityID,
		Limit:        limit * 2, // Get more to filter
	}
	result, err := h.repo.SearchStaff(c.Request.Context(), params)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	// Filter to only include profiles that need extraction.
	var tasks []ProfileExtractionTask
	for _, staff := range result.Staff {
		// Skip if no profile URL
		if staff.ProfileURL == nil || *staff.ProfileURL == "" {
			continue
		}
		// Include if missing email or other important data
		needsExtraction := staff.Email == nil || *staff.Email == ""
		if needsExtraction {
			task := ProfileExtractionTask{
				StaffID:      staff.ID,
				ProfileURL:   *staff.ProfileURL,
				UniversityID: staff.UniversityID,
			}
			// Copy already-known values so the extractor can skip them.
			if staff.FullName != nil {
				task.FullName = *staff.FullName
			}
			if staff.Email != nil {
				task.CurrentData.Email = *staff.Email
			}
			if staff.Phone != nil {
				task.CurrentData.Phone = *staff.Phone
			}
			if staff.Office != nil {
				task.CurrentData.Office = *staff.Office
			}
			if staff.Position != nil {
				task.CurrentData.Position = *staff.Position
			}
			if staff.DepartmentName != nil {
				task.CurrentData.Department = *staff.DepartmentName
			}
			tasks = append(tasks, task)
			if len(tasks) >= limit {
				break
			}
		}
	}
	c.JSON(http.StatusOK, gin.H{
		"tasks": tasks,
		"total": len(tasks),
	})
}
// ExtractedProfileData represents data extracted by AI from a profile
// page, submitted back via POST /api/v1/ai/extraction/submit.
// All fields except StaffID are optional; empty values are ignored by
// the submit handlers (fill-only-if-missing policy).
//
// NOTE(review): DepartmentName, TeachingTopics, InstituteURL,
// InstituteName and Confidence are accepted here but never applied by
// SubmitExtractedData — confirm whether they should be persisted.
type ExtractedProfileData struct {
	StaffID uuid.UUID `json:"staff_id" binding:"required"`
	// Contact info
	Email  string `json:"email,omitempty"`
	Phone  string `json:"phone,omitempty"`
	Office string `json:"office,omitempty"`
	// Professional info
	Position      string `json:"position,omitempty"`
	PositionType  string `json:"position_type,omitempty"` // professor, researcher, phd_student, staff
	AcademicTitle string `json:"academic_title,omitempty"`
	IsProfessor   *bool  `json:"is_professor,omitempty"` // pointer so "absent" differs from false
	DepartmentName string `json:"department_name,omitempty"`
	// Hierarchy
	SupervisorName string `json:"supervisor_name,omitempty"`
	TeamRole       string `json:"team_role,omitempty"` // leitung, mitarbeiter, sekretariat, hiwi, doktorand
	// Research
	ResearchInterests []string `json:"research_interests,omitempty"`
	ResearchSummary   string   `json:"research_summary,omitempty"`
	// Teaching (Lehrveranstaltungen)
	TeachingTopics []string `json:"teaching_topics,omitempty"`
	// External profiles
	ORCID           string `json:"orcid,omitempty"`
	GoogleScholarID string `json:"google_scholar_id,omitempty"`
	ResearchgateURL string `json:"researchgate_url,omitempty"`
	LinkedInURL     string `json:"linkedin_url,omitempty"`
	PersonalWebsite string `json:"personal_website,omitempty"`
	PhotoURL        string `json:"photo_url,omitempty"`
	// Institute/Department links discovered
	InstituteURL  string `json:"institute_url,omitempty"`
	InstituteName string `json:"institute_name,omitempty"`
	// Confidence score (0-1) reported by the extractor
	Confidence float64 `json:"confidence,omitempty"`
}
// SubmitExtractedData saves AI-extracted profile data.
// POST /api/v1/ai/extraction/submit
//
// Fill-only policy: each field is applied only when the extractor
// provided a non-empty value AND the stored record does not already
// have one — existing data is never overwritten. Exception:
// IsProfessor overwrites unconditionally whenever present.
//
// NOTE(review): LastVerified is set on the in-memory struct but only
// persisted when some other field changed (updated == true) — confirm
// whether a verification-only visit should also be saved.
// NOTE(review): persistence goes through CreateStaff — presumably an
// upsert keyed on ID; verify against the repository implementation.
func (h *AIExtractionHandlers) SubmitExtractedData(c *gin.Context) {
	var data ExtractedProfileData
	if err := c.ShouldBindJSON(&data); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request: " + err.Error()})
		return
	}
	// Get existing staff record
	staff, err := h.repo.GetStaff(c.Request.Context(), data.StaffID)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "Staff not found"})
		return
	}
	// Update fields if provided and not empty
	updated := false
	if data.Email != "" && (staff.Email == nil || *staff.Email == "") {
		staff.Email = &data.Email
		updated = true
	}
	if data.Phone != "" && (staff.Phone == nil || *staff.Phone == "") {
		staff.Phone = &data.Phone
		updated = true
	}
	if data.Office != "" && (staff.Office == nil || *staff.Office == "") {
		staff.Office = &data.Office
		updated = true
	}
	if data.Position != "" && (staff.Position == nil || *staff.Position == "") {
		staff.Position = &data.Position
		updated = true
	}
	if data.PositionType != "" && (staff.PositionType == nil || *staff.PositionType == "") {
		staff.PositionType = &data.PositionType
		updated = true
	}
	if data.AcademicTitle != "" && (staff.AcademicTitle == nil || *staff.AcademicTitle == "") {
		staff.AcademicTitle = &data.AcademicTitle
		updated = true
	}
	// IsProfessor is a pointer so "absent" differs from false; when
	// present it always overwrites the stored flag.
	if data.IsProfessor != nil {
		staff.IsProfessor = *data.IsProfessor
		updated = true
	}
	if data.TeamRole != "" && (staff.TeamRole == nil || *staff.TeamRole == "") {
		staff.TeamRole = &data.TeamRole
		updated = true
	}
	if len(data.ResearchInterests) > 0 && len(staff.ResearchInterests) == 0 {
		staff.ResearchInterests = data.ResearchInterests
		updated = true
	}
	if data.ResearchSummary != "" && (staff.ResearchSummary == nil || *staff.ResearchSummary == "") {
		staff.ResearchSummary = &data.ResearchSummary
		updated = true
	}
	if data.ORCID != "" && (staff.ORCID == nil || *staff.ORCID == "") {
		staff.ORCID = &data.ORCID
		updated = true
	}
	if data.GoogleScholarID != "" && (staff.GoogleScholarID == nil || *staff.GoogleScholarID == "") {
		staff.GoogleScholarID = &data.GoogleScholarID
		updated = true
	}
	if data.ResearchgateURL != "" && (staff.ResearchgateURL == nil || *staff.ResearchgateURL == "") {
		staff.ResearchgateURL = &data.ResearchgateURL
		updated = true
	}
	if data.LinkedInURL != "" && (staff.LinkedInURL == nil || *staff.LinkedInURL == "") {
		staff.LinkedInURL = &data.LinkedInURL
		updated = true
	}
	if data.PersonalWebsite != "" && (staff.PersonalWebsite == nil || *staff.PersonalWebsite == "") {
		staff.PersonalWebsite = &data.PersonalWebsite
		updated = true
	}
	if data.PhotoURL != "" && (staff.PhotoURL == nil || *staff.PhotoURL == "") {
		staff.PhotoURL = &data.PhotoURL
		updated = true
	}
	// Try to resolve supervisor by name.
	// NOTE(review): takes the first fuzzy match in the same university;
	// people with similar names can be mis-linked.
	if data.SupervisorName != "" && staff.SupervisorID == nil {
		// Search for supervisor in same university
		supervisorParams := database.StaffSearchParams{
			Query:        data.SupervisorName,
			UniversityID: &staff.UniversityID,
			Limit:        1,
		}
		result, err := h.repo.SearchStaff(c.Request.Context(), supervisorParams)
		if err == nil && len(result.Staff) > 0 {
			staff.SupervisorID = &result.Staff[0].ID
			updated = true
		}
	}
	// Update last verified timestamp
	now := time.Now()
	staff.LastVerified = &now
	if updated {
		err = h.repo.CreateStaff(c.Request.Context(), staff)
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update: " + err.Error()})
			return
		}
	}
	c.JSON(http.StatusOK, gin.H{
		"status":   "success",
		"updated":  updated,
		"staff_id": staff.ID,
	})
}
// SubmitBatchExtractedData saves multiple AI-extracted profile data
// items in one request, returning a per-item result list plus
// success/error counts. Items that fail (unknown staff, save error)
// do not abort the batch.
// POST /api/v1/ai/extraction/submit-batch
//
// NOTE(review): this applies only a SUBSET of the fields that the
// single-item SubmitExtractedData handles (no AcademicTitle,
// IsProfessor, ResearchSummary, external-profile URLs, or supervisor
// resolution) — confirm whether that divergence is intentional.
func (h *AIExtractionHandlers) SubmitBatchExtractedData(c *gin.Context) {
	var batch struct {
		Items []ExtractedProfileData `json:"items" binding:"required"`
	}
	if err := c.ShouldBindJSON(&batch); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request: " + err.Error()})
		return
	}
	results := make([]gin.H, 0, len(batch.Items))
	successCount := 0
	errorCount := 0
	for _, item := range batch.Items {
		// Get existing staff record
		staff, err := h.repo.GetStaff(c.Request.Context(), item.StaffID)
		if err != nil {
			results = append(results, gin.H{
				"staff_id": item.StaffID,
				"status":   "error",
				"error":    "Staff not found",
			})
			errorCount++
			continue
		}
		// Apply updates — fill-only-if-missing, same policy as the
		// single submit (but fewer fields; see NOTE above).
		updated := false
		if item.Email != "" && (staff.Email == nil || *staff.Email == "") {
			staff.Email = &item.Email
			updated = true
		}
		if item.Phone != "" && (staff.Phone == nil || *staff.Phone == "") {
			staff.Phone = &item.Phone
			updated = true
		}
		if item.Office != "" && (staff.Office == nil || *staff.Office == "") {
			staff.Office = &item.Office
			updated = true
		}
		if item.Position != "" && (staff.Position == nil || *staff.Position == "") {
			staff.Position = &item.Position
			updated = true
		}
		if item.PositionType != "" && (staff.PositionType == nil || *staff.PositionType == "") {
			staff.PositionType = &item.PositionType
			updated = true
		}
		if item.TeamRole != "" && (staff.TeamRole == nil || *staff.TeamRole == "") {
			staff.TeamRole = &item.TeamRole
			updated = true
		}
		if len(item.ResearchInterests) > 0 && len(staff.ResearchInterests) == 0 {
			staff.ResearchInterests = item.ResearchInterests
			updated = true
		}
		if item.ORCID != "" && (staff.ORCID == nil || *staff.ORCID == "") {
			staff.ORCID = &item.ORCID
			updated = true
		}
		// Update last verified
		// NOTE(review): as in SubmitExtractedData, LastVerified is only
		// persisted when another field changed.
		now := time.Now()
		staff.LastVerified = &now
		if updated {
			err = h.repo.CreateStaff(c.Request.Context(), staff)
			if err != nil {
				results = append(results, gin.H{
					"staff_id": item.StaffID,
					"status":   "error",
					"error":    err.Error(),
				})
				errorCount++
				continue
			}
		}
		results = append(results, gin.H{
			"staff_id": item.StaffID,
			"status":   "success",
			"updated":  updated,
		})
		successCount++
	}
	c.JSON(http.StatusOK, gin.H{
		"results":       results,
		"success_count": successCount,
		"error_count":   errorCount,
		"total":         len(batch.Items),
	})
}
// InstituteHierarchyTask represents an institute/department page an AI
// worker should crawl to extract the staff hierarchy.
type InstituteHierarchyTask struct {
	InstituteURL  string    `json:"institute_url"`
	InstituteName string    `json:"institute_name,omitempty"` // NOTE(review): never populated by GetInstitutePages
	UniversityID  uuid.UUID `json:"university_id"`
}
// GetInstitutePages returns institute pages that need hierarchy
// crawling, deduplicated from the source URLs of up to 1000 staff
// records (these are typically department pages).
// GET /api/v1/ai/extraction/institutes?university_id=...
func (h *AIExtractionHandlers) GetInstitutePages(c *gin.Context) {
	// Optional university filter; an unparsable ID is ignored.
	var uniID *uuid.UUID
	if raw := c.Query("university_id"); raw != "" {
		if parsed, err := uuid.Parse(raw); err == nil {
			uniID = &parsed
		}
	}
	result, err := h.repo.SearchStaff(c.Request.Context(), database.StaffSearchParams{
		UniversityID: uniID,
		Limit:        1000,
	})
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	// Collect unique source URLs.
	seen := make(map[string]bool)
	var tasks []InstituteHierarchyTask
	for _, member := range result.Staff {
		if member.SourceURL == nil || *member.SourceURL == "" {
			continue
		}
		pageURL := *member.SourceURL
		if seen[pageURL] {
			continue
		}
		seen[pageURL] = true
		tasks = append(tasks, InstituteHierarchyTask{
			InstituteURL: pageURL,
			UniversityID: member.UniversityID,
		})
	}
	c.JSON(http.StatusOK, gin.H{
		"institutes": tasks,
		"total":      len(tasks),
	})
}
// InstituteHierarchyData represents hierarchy data extracted from an
// institute page: the department itself, its leader, role-grouped
// members, and taught courses.
type InstituteHierarchyData struct {
	InstituteURL  string    `json:"institute_url" binding:"required"`
	UniversityID  uuid.UUID `json:"university_id" binding:"required"`
	InstituteName string    `json:"institute_name,omitempty"`
	// Leadership
	LeaderName  string `json:"leader_name,omitempty"`
	LeaderTitle string `json:"leader_title,omitempty"` // e.g., "Professor", "Lehrstuhlinhaber"
	// Staff organization
	StaffGroups []struct {
		Role    string   `json:"role"`    // e.g., "Leitung", "Wissenschaftliche Mitarbeiter", "Sekretariat"
		Members []string `json:"members"` // Names of people in this group
	} `json:"staff_groups,omitempty"`
	// Teaching info (Lehrveranstaltungen)
	// NOTE(review): TeachingCourses is accepted but not yet persisted
	// by SubmitInstituteHierarchy.
	TeachingCourses []struct {
		Title   string `json:"title"`
		Teacher string `json:"teacher,omitempty"`
	} `json:"teaching_courses,omitempty"`
}
// SubmitInstituteHierarchy saves hierarchy data from an institute page.
// POST /api/v1/ai/extraction/institutes/submit
//
// It creates a department record, marks the named leader (if resolvable)
// as its head with role "leitung", assigns each listed group member to
// the department with the group's role, and links members to the leader
// as supervisor.
//
// Fix: CreateStaff errors were previously discarded, silently losing
// updates while still counting them. Failed saves are now counted and
// reported as "save_errors" in the response, and only successful member
// saves increment "members_updated" (the added response key is
// backward-compatible).
//
// NOTE(review): leader/member resolution takes the first fuzzy search
// match, which can mis-attribute people with similar names.
func (h *AIExtractionHandlers) SubmitInstituteHierarchy(c *gin.Context) {
	var data InstituteHierarchyData
	if err := c.ShouldBindJSON(&data); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request: " + err.Error()})
		return
	}
	// Find or create the department for this institute.
	dept := &database.Department{
		UniversityID: data.UniversityID,
		Name:         data.InstituteName,
	}
	if data.InstituteURL != "" {
		dept.URL = &data.InstituteURL
	}
	if err := h.repo.CreateDepartment(c.Request.Context(), dept); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create department: " + err.Error()})
		return
	}
	saveErrors := 0
	// Resolve the leader by name and mark them as department head.
	var leaderID *uuid.UUID
	if data.LeaderName != "" {
		leaderParams := database.StaffSearchParams{
			Query:        data.LeaderName,
			UniversityID: &data.UniversityID,
			Limit:        1,
		}
		result, err := h.repo.SearchStaff(c.Request.Context(), leaderParams)
		if err == nil && len(result.Staff) > 0 {
			leaderID = &result.Staff[0].ID
			// Update leader with department and role.
			leader := &result.Staff[0]
			leader.DepartmentID = &dept.ID
			roleLeitung := "leitung"
			leader.TeamRole = &roleLeitung
			leader.IsProfessor = true
			if data.LeaderTitle != "" {
				leader.AcademicTitle = &data.LeaderTitle
			}
			if err := h.repo.CreateStaff(c.Request.Context(), leader); err != nil {
				saveErrors++
			}
		}
	}
	// Assign each group member to the department with the group's role.
	updatedCount := 0
	for _, group := range data.StaffGroups {
		// Copy the role so the stored pointer does not alias the loop
		// variable across iterations.
		role := group.Role
		for _, memberName := range group.Members {
			memberParams := database.StaffSearchParams{
				Query:        memberName,
				UniversityID: &data.UniversityID,
				Limit:        1,
			}
			result, err := h.repo.SearchStaff(c.Request.Context(), memberParams)
			if err != nil || len(result.Staff) == 0 {
				continue
			}
			member := &result.Staff[0]
			member.DepartmentID = &dept.ID
			member.TeamRole = &role
			// Set supervisor if a leader was found and this is not the leader.
			if leaderID != nil && member.ID != *leaderID {
				member.SupervisorID = leaderID
			}
			if err := h.repo.CreateStaff(c.Request.Context(), member); err != nil {
				saveErrors++
				continue
			}
			updatedCount++
		}
	}
	c.JSON(http.StatusOK, gin.H{
		"status":          "success",
		"department_id":   dept.ID,
		"leader_id":       leaderID,
		"members_updated": updatedCount,
		"save_errors":     saveErrors,
	})
}
// RegisterRoutes mounts the AI extraction endpoints under
// /ai/extraction on the given router group.
func (h *AIExtractionHandlers) RegisterRoutes(r *gin.RouterGroup) {
	grp := r.Group("/ai/extraction")
	// Profile extraction.
	grp.GET("/pending", h.GetPendingProfiles)
	grp.POST("/submit", h.SubmitExtractedData)
	grp.POST("/submit-batch", h.SubmitBatchExtractedData)
	// Institute hierarchy.
	grp.GET("/institutes", h.GetInstitutePages)
	grp.POST("/institutes/submit", h.SubmitInstituteHierarchy)
}

View File

@@ -0,0 +1,314 @@
package handlers
import (
"net/http"
"strconv"
"github.com/breakpilot/edu-search-service/internal/orchestrator"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// AudienceHandler handles audience-related HTTP requests (CRUD,
// member listing, count refresh, filter preview).
type AudienceHandler struct {
	repo orchestrator.AudienceRepository // backing store for audiences and members
}
// NewAudienceHandler creates a new audience handler backed by the
// given repository.
func NewAudienceHandler(repo orchestrator.AudienceRepository) *AudienceHandler {
	handler := &AudienceHandler{repo: repo}
	return handler
}
// CreateAudienceRequest represents a request to create an audience.
// Only Name is mandatory; new audiences are always created active.
type CreateAudienceRequest struct {
	Name        string                       `json:"name" binding:"required"`
	Description string                       `json:"description"`
	Filters     orchestrator.AudienceFilters `json:"filters"` // membership criteria
	CreatedBy   string                       `json:"created_by"`
}
// UpdateAudienceRequest represents a request to update an audience.
// This is a full replacement: all fields (including IsActive) are
// applied as sent, so omitting IsActive deactivates the audience.
type UpdateAudienceRequest struct {
	Name        string                       `json:"name" binding:"required"`
	Description string                       `json:"description"`
	Filters     orchestrator.AudienceFilters `json:"filters"`
	IsActive    bool                         `json:"is_active"`
}
// CreateExportRequest represents a request to create an export.
// NOTE(review): no handler in this file consumes it yet — presumably
// used by an export endpoint elsewhere; confirm before removing.
type CreateExportRequest struct {
	ExportType string `json:"export_type" binding:"required"` // csv, json, email_list
	Purpose    string `json:"purpose"`
	ExportedBy string `json:"exported_by"`
}
// ListAudiences returns all audiences. The optional query parameter
// active_only=true restricts the result to active audiences.
func (h *AudienceHandler) ListAudiences(c *gin.Context) {
	onlyActive := c.Query("active_only") == "true"
	audiences, err := h.repo.ListAudiences(c.Request.Context(), onlyActive)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list audiences", "details": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{"audiences": audiences, "count": len(audiences)})
}
// GetAudience returns the single audience identified by the :id path
// parameter; 400 for a malformed UUID, 404 when it does not exist.
func (h *AudienceHandler) GetAudience(c *gin.Context) {
	audienceID, err := uuid.Parse(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid audience ID"})
		return
	}
	audience, err := h.repo.GetAudience(c.Request.Context(), audienceID)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "Audience not found", "details": err.Error()})
		return
	}
	c.JSON(http.StatusOK, audience)
}
// CreateAudience creates a new (always active) audience and responds
// 201 with the stored record including its initial member count.
func (h *AudienceHandler) CreateAudience(c *gin.Context) {
	var req CreateAudienceRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
		return
	}
	audience := &orchestrator.Audience{
		Name:        req.Name,
		Description: req.Description,
		Filters:     req.Filters,
		CreatedBy:   req.CreatedBy,
		IsActive:    true,
	}
	if err := h.repo.CreateAudience(c.Request.Context(), audience); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create audience", "details": err.Error()})
		return
	}
	// Best-effort member count; a failure here leaves the count at
	// zero rather than failing the whole create.
	count, _ := h.repo.UpdateAudienceCount(c.Request.Context(), audience.ID)
	audience.MemberCount = count
	c.JSON(http.StatusCreated, audience)
}
// UpdateAudience replaces an existing audience's name, description,
// filters, and active flag, then refreshes its member count.
func (h *AudienceHandler) UpdateAudience(c *gin.Context) {
	audienceID, err := uuid.Parse(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid audience ID"})
		return
	}
	var req UpdateAudienceRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
		return
	}
	audience := &orchestrator.Audience{
		ID:          audienceID,
		Name:        req.Name,
		Description: req.Description,
		Filters:     req.Filters,
		IsActive:    req.IsActive,
	}
	if err := h.repo.UpdateAudience(c.Request.Context(), audience); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update audience", "details": err.Error()})
		return
	}
	// Best-effort member count refresh after the filter change.
	count, _ := h.repo.UpdateAudienceCount(c.Request.Context(), audience.ID)
	audience.MemberCount = count
	c.JSON(http.StatusOK, audience)
}
// DeleteAudience soft-deletes the audience identified by the :id path
// parameter.
func (h *AudienceHandler) DeleteAudience(c *gin.Context) {
	rawID := c.Param("id")
	audienceID, err := uuid.Parse(rawID)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid audience ID"})
		return
	}
	if err := h.repo.DeleteAudience(c.Request.Context(), audienceID); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete audience", "details": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{"deleted": true, "id": rawID})
}
// GetAudienceMembers returns members matching the audience filters.
//
// Pagination is controlled by the optional "limit" (1..500, default 50)
// and "offset" (>= 0, default 0) query parameters; values outside those
// ranges are silently ignored.
func (h *AudienceHandler) GetAudienceMembers(c *gin.Context) {
	audienceID, err := uuid.Parse(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid audience ID"})
		return
	}
	// Defaults; invalid or out-of-range query values leave them unchanged.
	limit, offset := 50, 0
	if v, convErr := strconv.Atoi(c.Query("limit")); convErr == nil && v > 0 && v <= 500 {
		limit = v
	}
	if v, convErr := strconv.Atoi(c.Query("offset")); convErr == nil && v >= 0 {
		offset = v
	}
	members, total, err := h.repo.GetAudienceMembers(c.Request.Context(), audienceID, limit, offset)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get members", "details": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"members":     members,
		"count":       len(members),
		"total_count": total,
		"limit":       limit,
		"offset":      offset,
	})
}
// RefreshAudienceCount recalculates and persists the audience member count.
//
// Responds 400 on a malformed ID, 500 on a repository failure,
// otherwise 200 with the audience ID and the fresh count.
func (h *AudienceHandler) RefreshAudienceCount(c *gin.Context) {
	rawID := c.Param("id")
	audienceID, parseErr := uuid.Parse(rawID)
	if parseErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid audience ID"})
		return
	}
	memberCount, err := h.repo.UpdateAudienceCount(c.Request.Context(), audienceID)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to refresh count", "details": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"audience_id":  rawID,
		"member_count": memberCount,
	})
}
// PreviewAudienceFilters previews the result of filters without saving.
//
// Currently a placeholder: it validates and echoes the submitted filters
// but does not evaluate them against the repository.
func (h *AudienceHandler) PreviewAudienceFilters(c *gin.Context) {
	var filters orchestrator.AudienceFilters
	if err := c.ShouldBindJSON(&filters); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
		return
	}
	// Return the filters for now - preview functionality can be expanded later
	c.JSON(http.StatusOK, gin.H{
		"filters": filters,
		"message": "Preview functionality requires direct repository access",
	})
}
// CreateExport records a new export of an audience's member list.
//
// The current total member count is captured into the export record.
// Responds 400 on a malformed ID or body, 500 on repository failures,
// otherwise 201 with the created export.
func (h *AudienceHandler) CreateExport(c *gin.Context) {
	audienceID, err := uuid.Parse(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid audience ID"})
		return
	}
	var payload CreateExportRequest
	if err := c.ShouldBindJSON(&payload); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
		return
	}
	// Fetch a single member purely to obtain the total matching count.
	_, total, err := h.repo.GetAudienceMembers(c.Request.Context(), audienceID, 1, 0)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get members", "details": err.Error()})
		return
	}
	export := &orchestrator.AudienceExport{
		AudienceID:  audienceID,
		ExportType:  payload.ExportType,
		RecordCount: total,
		ExportedBy:  payload.ExportedBy,
		Purpose:     payload.Purpose,
	}
	if err := h.repo.CreateExport(c.Request.Context(), export); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create export", "details": err.Error()})
		return
	}
	c.JSON(http.StatusCreated, export)
}
// ListExports lists all recorded exports for an audience.
//
// Responds 400 on a malformed ID, 500 on a repository failure,
// otherwise 200 with the exports and their count.
func (h *AudienceHandler) ListExports(c *gin.Context) {
	audienceID, err := uuid.Parse(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid audience ID"})
		return
	}
	exports, err := h.repo.ListExports(c.Request.Context(), audienceID)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list exports", "details": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"exports": exports,
		"count":   len(exports),
	})
}
// SetupAudienceRoutes configures audience API routes under /audiences
// on the given router group.
func SetupAudienceRoutes(r *gin.RouterGroup, h *AudienceHandler) {
	audiences := r.Group("/audiences")
	{
		// Audience CRUD
		audiences.GET("", h.ListAudiences)
		audiences.GET("/:id", h.GetAudience)
		audiences.POST("", h.CreateAudience)
		audiences.PUT("/:id", h.UpdateAudience)
		audiences.DELETE("/:id", h.DeleteAudience)
		// Members (paginated listing + count refresh)
		audiences.GET("/:id/members", h.GetAudienceMembers)
		audiences.POST("/:id/refresh", h.RefreshAudienceCount)
		// Exports
		audiences.GET("/:id/exports", h.ListExports)
		audiences.POST("/:id/exports", h.CreateExport)
		// Preview (no audience required)
		audiences.POST("/preview", h.PreviewAudienceFilters)
	}
}

View File

@@ -0,0 +1,630 @@
package handlers
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/breakpilot/edu-search-service/internal/orchestrator"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// MockAudienceRepository implements orchestrator.AudienceRepository for testing.
// All state lives in in-memory slices; no persistence, no locking
// (tests run single-goroutine).
type MockAudienceRepository struct {
	audiences []orchestrator.Audience      // stored audiences, soft-deleted via IsActive=false
	exports   []orchestrator.AudienceExport // recorded exports across all audiences
	members   []orchestrator.AudienceMember // shared member fixture, lazily seeded by GetAudienceMembers
}
// NewMockAudienceRepository returns an empty in-memory repository.
func NewMockAudienceRepository() *MockAudienceRepository {
	return &MockAudienceRepository{
		audiences: make([]orchestrator.Audience, 0),
		exports:   make([]orchestrator.AudienceExport, 0),
		members:   make([]orchestrator.AudienceMember, 0),
	}
}
// CreateAudience assigns a fresh ID and timestamps to the audience,
// then stores a copy of it.
func (m *MockAudienceRepository) CreateAudience(ctx context.Context, audience *orchestrator.Audience) error {
	audience.ID = uuid.New()
	audience.CreatedAt = time.Now()
	audience.UpdatedAt = time.Now()
	m.audiences = append(m.audiences, *audience)
	return nil
}
// GetAudience returns a pointer into the stored slice for the matching ID.
// NOTE(review): "not found" is signalled with context.DeadlineExceeded as an
// arbitrary sentinel; the handler under test apparently only checks err != nil.
// Confirm before matching on specific errors in the handler.
func (m *MockAudienceRepository) GetAudience(ctx context.Context, id uuid.UUID) (*orchestrator.Audience, error) {
	for i := range m.audiences {
		if m.audiences[i].ID == id {
			return &m.audiences[i], nil
		}
	}
	return nil, context.DeadlineExceeded // simulate not found
}
// ListAudiences returns all stored audiences, or only those with
// IsActive=true when activeOnly is set.
func (m *MockAudienceRepository) ListAudiences(ctx context.Context, activeOnly bool) ([]orchestrator.Audience, error) {
	if activeOnly {
		var active []orchestrator.Audience
		for _, a := range m.audiences {
			if a.IsActive {
				active = append(active, a)
			}
		}
		return active, nil
	}
	return m.audiences, nil
}
// UpdateAudience copies the mutable fields onto the stored audience with the
// same ID and refreshes UpdatedAt (mirrored back onto the argument).
// A missing ID is silently ignored and still returns nil.
func (m *MockAudienceRepository) UpdateAudience(ctx context.Context, audience *orchestrator.Audience) error {
	for i := range m.audiences {
		if m.audiences[i].ID == audience.ID {
			m.audiences[i].Name = audience.Name
			m.audiences[i].Description = audience.Description
			m.audiences[i].Filters = audience.Filters
			m.audiences[i].IsActive = audience.IsActive
			m.audiences[i].UpdatedAt = time.Now()
			audience.UpdatedAt = m.audiences[i].UpdatedAt
			return nil
		}
	}
	return nil
}
// DeleteAudience soft-deletes by flipping IsActive to false.
// A missing ID is a no-op and still returns nil.
func (m *MockAudienceRepository) DeleteAudience(ctx context.Context, id uuid.UUID) error {
	for i := range m.audiences {
		if m.audiences[i].ID == id {
			m.audiences[i].IsActive = false
			return nil
		}
	}
	return nil
}
// GetAudienceMembers returns a slice of the fixture members honoring
// limit/offset, plus the total count.
//
// Side effect: if no members exist yet, two fixture members are lazily
// seeded; several tests (CreateExport, GetAudienceMembers) rely on this,
// while TestAudienceHandler_RefreshAudienceCount pre-seeds instead.
// The audience ID is ignored — all audiences share the same members.
func (m *MockAudienceRepository) GetAudienceMembers(ctx context.Context, id uuid.UUID, limit, offset int) ([]orchestrator.AudienceMember, int, error) {
	// Return mock members
	if len(m.members) == 0 {
		m.members = []orchestrator.AudienceMember{
			{
				ID:               uuid.New(),
				Name:             "Prof. Dr. Test Person",
				Email:            "test@university.de",
				Position:         "professor",
				University:       "Test Universität",
				Department:       "Informatik",
				SubjectArea:      "Informatik",
				PublicationCount: 42,
			},
			{
				ID:               uuid.New(),
				Name:             "Dr. Another Person",
				Email:            "another@university.de",
				Position:         "researcher",
				University:       "Test Universität",
				Department:       "Mathematik",
				SubjectArea:      "Mathematik",
				PublicationCount: 15,
			},
		}
	}
	total := len(m.members)
	if offset >= total {
		return []orchestrator.AudienceMember{}, total, nil
	}
	end := offset + limit
	if end > total {
		end = total
	}
	return m.members[offset:end], total, nil
}
// UpdateAudienceCount sets the matching audience's MemberCount to the
// current fixture member count and stamps LastCountUpdate.
// Returns the count even if the audience ID is unknown.
func (m *MockAudienceRepository) UpdateAudienceCount(ctx context.Context, id uuid.UUID) (int, error) {
	count := len(m.members)
	for i := range m.audiences {
		if m.audiences[i].ID == id {
			m.audiences[i].MemberCount = count
			now := time.Now()
			m.audiences[i].LastCountUpdate = &now
		}
	}
	return count, nil
}
// CreateExport assigns a fresh ID and creation time, then stores a copy.
func (m *MockAudienceRepository) CreateExport(ctx context.Context, export *orchestrator.AudienceExport) error {
	export.ID = uuid.New()
	export.CreatedAt = time.Now()
	m.exports = append(m.exports, *export)
	return nil
}
// ListExports returns only the exports recorded for the given audience ID.
func (m *MockAudienceRepository) ListExports(ctx context.Context, audienceID uuid.UUID) ([]orchestrator.AudienceExport, error) {
	var exports []orchestrator.AudienceExport
	for _, e := range m.exports {
		if e.AudienceID == audienceID {
			exports = append(exports, e)
		}
	}
	return exports, nil
}
// setupAudienceRouter wires the mock repository into a fresh test-mode
// gin engine with the audience routes mounted under /v1.
func setupAudienceRouter(repo *MockAudienceRepository) *gin.Engine {
	gin.SetMode(gin.TestMode)
	router := gin.New()
	handler := NewAudienceHandler(repo)
	v1 := router.Group("/v1")
	SetupAudienceRoutes(v1, handler)
	return router
}
// TestAudienceHandler_ListAudiences_Empty verifies GET /v1/audiences on an
// empty repository returns 200 with a zero count.
func TestAudienceHandler_ListAudiences_Empty(t *testing.T) {
	repo := NewMockAudienceRepository()
	router := setupAudienceRouter(repo)
	req := httptest.NewRequest(http.MethodGet, "/v1/audiences", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
	}
	var response struct {
		Audiences []orchestrator.Audience `json:"audiences"`
		Count     int                     `json:"count"`
	}
	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
		t.Fatalf("Failed to unmarshal response: %v", err)
	}
	if response.Count != 0 {
		t.Errorf("Expected 0 audiences, got %d", response.Count)
	}
}
// TestAudienceHandler_CreateAudience verifies POST /v1/audiences creates an
// active audience, echoes it with 201, and persists it in the repository.
func TestAudienceHandler_CreateAudience(t *testing.T) {
	repo := NewMockAudienceRepository()
	router := setupAudienceRouter(repo)
	body := CreateAudienceRequest{
		Name:        "Test Audience",
		Description: "A test audience for professors",
		Filters: orchestrator.AudienceFilters{
			PositionTypes: []string{"professor"},
			States:        []string{"BW", "BY"},
		},
		CreatedBy: "test-admin",
	}
	bodyJSON, _ := json.Marshal(body)
	req := httptest.NewRequest(http.MethodPost, "/v1/audiences", bytes.NewBuffer(bodyJSON))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusCreated {
		t.Errorf("Expected status %d, got %d: %s", http.StatusCreated, w.Code, w.Body.String())
	}
	var response orchestrator.Audience
	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
		t.Fatalf("Failed to unmarshal response: %v", err)
	}
	if response.Name != "Test Audience" {
		t.Errorf("Expected name 'Test Audience', got '%s'", response.Name)
	}
	if !response.IsActive {
		t.Errorf("Expected audience to be active")
	}
	if len(repo.audiences) != 1 {
		t.Errorf("Expected 1 audience in repo, got %d", len(repo.audiences))
	}
}
// TestAudienceHandler_CreateAudience_InvalidJSON verifies that a malformed
// request body is rejected with 400.
func TestAudienceHandler_CreateAudience_InvalidJSON(t *testing.T) {
	repo := NewMockAudienceRepository()
	router := setupAudienceRouter(repo)
	req := httptest.NewRequest(http.MethodPost, "/v1/audiences", bytes.NewBuffer([]byte("invalid json")))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusBadRequest {
		t.Errorf("Expected status %d, got %d", http.StatusBadRequest, w.Code)
	}
}
// TestAudienceHandler_CreateAudience_MissingName verifies that a body
// lacking the required name field is rejected with 400 (binding validation).
func TestAudienceHandler_CreateAudience_MissingName(t *testing.T) {
	repo := NewMockAudienceRepository()
	router := setupAudienceRouter(repo)
	body := map[string]interface{}{
		"description": "Missing name field",
	}
	bodyJSON, _ := json.Marshal(body)
	req := httptest.NewRequest(http.MethodPost, "/v1/audiences", bytes.NewBuffer(bodyJSON))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusBadRequest {
		t.Errorf("Expected status %d, got %d", http.StatusBadRequest, w.Code)
	}
}
// TestAudienceHandler_GetAudience verifies GET /v1/audiences/:id returns
// 200 and the stored audience for a known ID.
func TestAudienceHandler_GetAudience(t *testing.T) {
	repo := NewMockAudienceRepository()
	router := setupAudienceRouter(repo)
	// Create an audience first (seeded directly into the mock store)
	audience := orchestrator.Audience{
		ID:          uuid.New(),
		Name:        "Test Audience",
		Description: "Test description",
		IsActive:    true,
		CreatedAt:   time.Now(),
		UpdatedAt:   time.Now(),
	}
	repo.audiences = append(repo.audiences, audience)
	req := httptest.NewRequest(http.MethodGet, "/v1/audiences/"+audience.ID.String(), nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("Expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
	}
	var response orchestrator.Audience
	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
		t.Fatalf("Failed to unmarshal response: %v", err)
	}
	if response.Name != "Test Audience" {
		t.Errorf("Expected name 'Test Audience', got '%s'", response.Name)
	}
}
// TestAudienceHandler_GetAudience_InvalidID verifies a non-UUID path
// parameter yields 400.
func TestAudienceHandler_GetAudience_InvalidID(t *testing.T) {
	repo := NewMockAudienceRepository()
	router := setupAudienceRouter(repo)
	req := httptest.NewRequest(http.MethodGet, "/v1/audiences/invalid-uuid", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusBadRequest {
		t.Errorf("Expected status %d, got %d", http.StatusBadRequest, w.Code)
	}
}
// TestAudienceHandler_GetAudience_NotFound verifies an unknown (but valid)
// UUID yields 404.
func TestAudienceHandler_GetAudience_NotFound(t *testing.T) {
	repo := NewMockAudienceRepository()
	router := setupAudienceRouter(repo)
	req := httptest.NewRequest(http.MethodGet, "/v1/audiences/"+uuid.New().String(), nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusNotFound {
		t.Errorf("Expected status %d, got %d", http.StatusNotFound, w.Code)
	}
}
// TestAudienceHandler_UpdateAudience verifies PUT /v1/audiences/:id returns
// 200 and that the stored audience actually receives the new field values.
func TestAudienceHandler_UpdateAudience(t *testing.T) {
	repo := NewMockAudienceRepository()
	router := setupAudienceRouter(repo)
	// Create an audience first
	audience := orchestrator.Audience{
		ID:          uuid.New(),
		Name:        "Old Name",
		Description: "Old description",
		IsActive:    true,
		CreatedAt:   time.Now(),
		UpdatedAt:   time.Now(),
	}
	repo.audiences = append(repo.audiences, audience)
	body := UpdateAudienceRequest{
		Name:        "New Name",
		Description: "New description",
		IsActive:    true,
	}
	bodyJSON, _ := json.Marshal(body)
	req := httptest.NewRequest(http.MethodPut, "/v1/audiences/"+audience.ID.String(), bytes.NewBuffer(bodyJSON))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("Expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
	}
	// Verify the update reached the store
	if repo.audiences[0].Name != "New Name" {
		t.Errorf("Expected name 'New Name', got '%s'", repo.audiences[0].Name)
	}
}
// TestAudienceHandler_DeleteAudience verifies DELETE /v1/audiences/:id
// returns 200 and soft-deletes (IsActive=false) rather than removing.
func TestAudienceHandler_DeleteAudience(t *testing.T) {
	repo := NewMockAudienceRepository()
	router := setupAudienceRouter(repo)
	// Create an audience first
	audience := orchestrator.Audience{
		ID:        uuid.New(),
		Name:      "To Delete",
		IsActive:  true,
		CreatedAt: time.Now(),
		UpdatedAt: time.Now(),
	}
	repo.audiences = append(repo.audiences, audience)
	req := httptest.NewRequest(http.MethodDelete, "/v1/audiences/"+audience.ID.String(), nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
	}
	// Verify soft delete
	if repo.audiences[0].IsActive {
		t.Errorf("Expected audience to be inactive after delete")
	}
}
// TestAudienceHandler_GetAudienceMembers verifies the members endpoint
// returns 200 and the mock's lazily seeded total of 2 members.
func TestAudienceHandler_GetAudienceMembers(t *testing.T) {
	repo := NewMockAudienceRepository()
	router := setupAudienceRouter(repo)
	// Create an audience first
	audience := orchestrator.Audience{
		ID:        uuid.New(),
		Name:      "Test Audience",
		IsActive:  true,
		CreatedAt: time.Now(),
		UpdatedAt: time.Now(),
	}
	repo.audiences = append(repo.audiences, audience)
	req := httptest.NewRequest(http.MethodGet, "/v1/audiences/"+audience.ID.String()+"/members", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("Expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
	}
	var response struct {
		Members    []orchestrator.AudienceMember `json:"members"`
		Count      int                           `json:"count"`
		TotalCount int                           `json:"total_count"`
	}
	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
		t.Fatalf("Failed to unmarshal response: %v", err)
	}
	// The mock seeds exactly two fixture members on first access.
	if response.TotalCount != 2 {
		t.Errorf("Expected 2 total members, got %d", response.TotalCount)
	}
}
// TestAudienceHandler_GetAudienceMembers_WithPagination verifies that
// limit/offset query parameters are honored and echoed in the response.
func TestAudienceHandler_GetAudienceMembers_WithPagination(t *testing.T) {
	repo := NewMockAudienceRepository()
	router := setupAudienceRouter(repo)
	audience := orchestrator.Audience{
		ID:        uuid.New(),
		Name:      "Test Audience",
		IsActive:  true,
		CreatedAt: time.Now(),
		UpdatedAt: time.Now(),
	}
	repo.audiences = append(repo.audiences, audience)
	req := httptest.NewRequest(http.MethodGet, "/v1/audiences/"+audience.ID.String()+"/members?limit=1&offset=0", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
	}
	var response struct {
		Members []orchestrator.AudienceMember `json:"members"`
		Count   int                           `json:"count"`
		Limit   int                           `json:"limit"`
		Offset  int                           `json:"offset"`
	}
	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
		t.Fatalf("Failed to unmarshal response: %v", err)
	}
	if response.Count != 1 {
		t.Errorf("Expected 1 member in response, got %d", response.Count)
	}
	if response.Limit != 1 {
		t.Errorf("Expected limit 1, got %d", response.Limit)
	}
}
// TestAudienceHandler_RefreshAudienceCount verifies POST .../refresh returns
// the recalculated member count. Members are pre-seeded to avoid depending
// on the mock's lazy initialization in GetAudienceMembers.
func TestAudienceHandler_RefreshAudienceCount(t *testing.T) {
	repo := NewMockAudienceRepository()
	router := setupAudienceRouter(repo)
	audience := orchestrator.Audience{
		ID:          uuid.New(),
		Name:        "Test Audience",
		IsActive:    true,
		MemberCount: 0,
		CreatedAt:   time.Now(),
		UpdatedAt:   time.Now(),
	}
	repo.audiences = append(repo.audiences, audience)
	// Pre-initialize members so count works correctly
	repo.members = []orchestrator.AudienceMember{
		{ID: uuid.New(), Name: "Test Person 1"},
		{ID: uuid.New(), Name: "Test Person 2"},
	}
	req := httptest.NewRequest(http.MethodPost, "/v1/audiences/"+audience.ID.String()+"/refresh", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
	}
	var response struct {
		AudienceID  string `json:"audience_id"`
		MemberCount int    `json:"member_count"`
	}
	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
		t.Fatalf("Failed to unmarshal response: %v", err)
	}
	if response.MemberCount != 2 {
		t.Errorf("Expected member_count 2, got %d", response.MemberCount)
	}
}
// TestAudienceHandler_CreateExport verifies POST .../exports returns 201
// with the export type echoed and the record count taken from the mock's
// lazily seeded two members.
func TestAudienceHandler_CreateExport(t *testing.T) {
	repo := NewMockAudienceRepository()
	router := setupAudienceRouter(repo)
	audience := orchestrator.Audience{
		ID:        uuid.New(),
		Name:      "Test Audience",
		IsActive:  true,
		CreatedAt: time.Now(),
		UpdatedAt: time.Now(),
	}
	repo.audiences = append(repo.audiences, audience)
	body := CreateExportRequest{
		ExportType: "csv",
		Purpose:    "Newsletter December 2024",
		ExportedBy: "admin",
	}
	bodyJSON, _ := json.Marshal(body)
	req := httptest.NewRequest(http.MethodPost, "/v1/audiences/"+audience.ID.String()+"/exports", bytes.NewBuffer(bodyJSON))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusCreated {
		t.Errorf("Expected status %d, got %d: %s", http.StatusCreated, w.Code, w.Body.String())
	}
	var response orchestrator.AudienceExport
	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
		t.Fatalf("Failed to unmarshal response: %v", err)
	}
	if response.ExportType != "csv" {
		t.Errorf("Expected export_type 'csv', got '%s'", response.ExportType)
	}
	if response.RecordCount != 2 {
		t.Errorf("Expected record_count 2, got %d", response.RecordCount)
	}
}
// TestAudienceHandler_ListExports verifies GET .../exports returns 200 and
// only the exports recorded for the requested audience.
func TestAudienceHandler_ListExports(t *testing.T) {
	repo := NewMockAudienceRepository()
	router := setupAudienceRouter(repo)
	audience := orchestrator.Audience{
		ID:        uuid.New(),
		Name:      "Test Audience",
		IsActive:  true,
		CreatedAt: time.Now(),
		UpdatedAt: time.Now(),
	}
	repo.audiences = append(repo.audiences, audience)
	// Add an export directly to the mock store
	export := orchestrator.AudienceExport{
		ID:          uuid.New(),
		AudienceID:  audience.ID,
		ExportType:  "csv",
		RecordCount: 100,
		Purpose:     "Test export",
		CreatedAt:   time.Now(),
	}
	repo.exports = append(repo.exports, export)
	req := httptest.NewRequest(http.MethodGet, "/v1/audiences/"+audience.ID.String()+"/exports", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
	}
	var response struct {
		Exports []orchestrator.AudienceExport `json:"exports"`
		Count   int                           `json:"count"`
	}
	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
		t.Fatalf("Failed to unmarshal response: %v", err)
	}
	if response.Count != 1 {
		t.Errorf("Expected 1 export, got %d", response.Count)
	}
}
// TestAudienceHandler_ListAudiences_ActiveOnly verifies the active_only
// query flag filters out inactive audiences.
func TestAudienceHandler_ListAudiences_ActiveOnly(t *testing.T) {
	repo := NewMockAudienceRepository()
	router := setupAudienceRouter(repo)
	// Add active and inactive audiences
	repo.audiences = []orchestrator.Audience{
		{ID: uuid.New(), Name: "Active", IsActive: true, CreatedAt: time.Now(), UpdatedAt: time.Now()},
		{ID: uuid.New(), Name: "Inactive", IsActive: false, CreatedAt: time.Now(), UpdatedAt: time.Now()},
	}
	req := httptest.NewRequest(http.MethodGet, "/v1/audiences?active_only=true", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
	}
	var response struct {
		Audiences []orchestrator.Audience `json:"audiences"`
		Count     int                     `json:"count"`
	}
	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
		t.Fatalf("Failed to unmarshal response: %v", err)
	}
	if response.Count != 1 {
		t.Errorf("Expected 1 active audience, got %d", response.Count)
	}
	if response.Audiences[0].Name != "Active" {
		t.Errorf("Expected audience 'Active', got '%s'", response.Audiences[0].Name)
	}
}

View File

@@ -0,0 +1,146 @@
package handlers
import (
	"crypto/subtle"
	"net/http"
	"strings"

	"github.com/breakpilot/edu-search-service/internal/config"
	"github.com/breakpilot/edu-search-service/internal/indexer"
	"github.com/breakpilot/edu-search-service/internal/search"
	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
)
// Handler contains all HTTP handlers for the edu-search service.
type Handler struct {
	cfg           *config.Config  // service configuration
	searchService *search.Service // executes /v1/search queries
	indexClient   *indexer.Client // OpenSearch client, queried for health status
}
// NewHandler creates a new handler instance wired to the given
// configuration, search service, and index client.
func NewHandler(cfg *config.Config, searchService *search.Service, indexClient *indexer.Client) *Handler {
	h := &Handler{}
	h.cfg = cfg
	h.searchService = searchService
	h.indexClient = indexClient
	return h
}
// Health returns the service health status.
//
// Always responds 200; an unreachable OpenSearch backend is reported as
// "degraded" in the payload rather than as an HTTP failure.
func (h *Handler) Health(c *gin.Context) {
	overall := "ok"
	osStatus, err := h.indexClient.Health(c.Request.Context())
	if err != nil {
		overall = "degraded"
		osStatus = "unreachable"
	}
	c.JSON(http.StatusOK, gin.H{
		"status":     overall,
		"opensearch": osStatus,
		"service":    "edu-search-service",
		"version":    "0.1.0",
	})
}
// Search handles /v1/search requests.
//
// Defaults: an unset/non-positive limit becomes 10; a limit above the
// 100-result ceiling is clamped to 100 (previously it was silently reset
// to 10, which surprised callers asking for large pages); an empty mode
// defaults to "keyword" (MVP: BM25 only). Each query is tagged with a
// generated UUID returned as QueryID.
func (h *Handler) Search(c *gin.Context) {
	var req search.SearchRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
		return
	}
	// Set defaults / clamp to the supported range
	if req.Limit <= 0 {
		req.Limit = 10
	} else if req.Limit > 100 {
		req.Limit = 100
	}
	if req.Mode == "" {
		req.Mode = "keyword" // MVP: only BM25
	}
	// Generate query ID
	queryID := uuid.New().String()
	// Execute search
	result, err := h.searchService.Search(c.Request.Context(), &req)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Search failed", "details": err.Error()})
		return
	}
	result.QueryID = queryID
	c.JSON(http.StatusOK, result)
}
// GetDocument retrieves a single document by its doc_id query parameter.
//
// Responds 400 when doc_id is missing; retrieval itself is not yet
// implemented and always yields 501.
func (h *Handler) GetDocument(c *gin.Context) {
	if docID := c.Query("doc_id"); docID == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "doc_id parameter required"})
		return
	}
	// TODO: Implement document retrieval
	c.JSON(http.StatusNotImplemented, gin.H{"error": "Not implemented yet"})
}
// AuthMiddleware validates API keys supplied as "Authorization: Bearer <key>".
//
// Requests to /v1/health bypass authentication (defensive: the health route
// is also registered outside the authenticated group in SetupRoutes). When
// apiKey is empty, any well-formed Bearer token is accepted. The key check
// uses a constant-time comparison so an attacker cannot recover the key
// byte-by-byte via response timing.
func AuthMiddleware(apiKey string) gin.HandlerFunc {
	return func(c *gin.Context) {
		// Skip auth for health endpoint
		if c.Request.URL.Path == "/v1/health" {
			c.Next()
			return
		}
		authHeader := c.GetHeader("Authorization")
		if authHeader == "" {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Missing Authorization header"})
			return
		}
		// Require the "Bearer " scheme
		if !strings.HasPrefix(authHeader, "Bearer ") {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization format"})
			return
		}
		token := strings.TrimPrefix(authHeader, "Bearer ")
		// Constant-time compare prevents timing side channels on the key.
		if apiKey != "" && subtle.ConstantTimeCompare([]byte(token), []byte(apiKey)) != 1 {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"})
			return
		}
		c.Next()
	}
}
// RateLimitMiddleware implements basic rate limiting.
// Currently a pass-through no-op: every request proceeds unthrottled.
func RateLimitMiddleware() gin.HandlerFunc {
	// TODO: Implement proper rate limiting with Redis
	return func(c *gin.Context) {
		c.Next()
	}
}
// SetupRoutes configures all API routes: an unauthenticated health probe
// plus the authenticated /v1 group (search, document, admin).
func SetupRoutes(r *gin.Engine, h *Handler, apiKey string) {
	// Health endpoint (no auth)
	r.GET("/v1/health", h.Health)
	// API v1 group with auth + (currently no-op) rate limiting
	v1 := r.Group("/v1")
	v1.Use(AuthMiddleware(apiKey))
	v1.Use(RateLimitMiddleware())
	{
		v1.POST("/search", h.Search)
		v1.GET("/document", h.GetDocument)
		// Admin routes
		SetupAdminRoutes(v1, h)
	}
}

View File

@@ -0,0 +1,645 @@
package handlers
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
)
// init puts gin into test mode for every test in this package.
func init() {
	gin.SetMode(gin.TestMode)
}
// setupTestRouter creates a fresh gin engine with all service routes
// mounted and the given API key enforced by the auth middleware.
func setupTestRouter(h *Handler, apiKey string) *gin.Engine {
	router := gin.New()
	SetupRoutes(router, h, apiKey)
	return router
}
// setupTestSeedStore initializes the package-global seed store in a fresh
// temp directory (cleaned up automatically) and returns that directory.
func setupTestSeedStore(t *testing.T) string {
	t.Helper()
	dir := t.TempDir()
	// Initialize global seed store
	err := InitSeedStore(dir)
	if err != nil {
		t.Fatalf("Failed to initialize seed store: %v", err)
	}
	return dir
}
// TestHealthEndpoint is intentionally skipped: Health dereferences
// h.indexClient, which would need a mock OpenSearch client.
func TestHealthEndpoint(t *testing.T) {
	// Health endpoint requires indexClient for health check
	// This test verifies the route is set up correctly
	// A full integration test would need a mock OpenSearch client
	t.Skip("Skipping: requires mock indexer client for full test")
}
// TestAuthMiddleware_NoAuth verifies a request without an Authorization
// header is rejected with 401.
func TestAuthMiddleware_NoAuth(t *testing.T) {
	h := &Handler{}
	router := setupTestRouter(h, "test-api-key")
	// Request without auth header
	req, _ := http.NewRequest("POST", "/v1/search", bytes.NewBufferString(`{"q":"test"}`))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusUnauthorized {
		t.Errorf("Expected status 401, got %d", w.Code)
	}
}
// TestAuthMiddleware_InvalidFormat verifies a non-Bearer Authorization
// scheme is rejected with 401.
func TestAuthMiddleware_InvalidFormat(t *testing.T) {
	h := &Handler{}
	router := setupTestRouter(h, "test-api-key")
	// Request with wrong auth format
	req, _ := http.NewRequest("POST", "/v1/search", bytes.NewBufferString(`{"q":"test"}`))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Basic dGVzdDp0ZXN0") // Basic auth instead of Bearer
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusUnauthorized {
		t.Errorf("Expected status 401, got %d", w.Code)
	}
}
// TestAuthMiddleware_InvalidKey verifies a Bearer token that does not match
// the configured API key is rejected with 401.
func TestAuthMiddleware_InvalidKey(t *testing.T) {
	h := &Handler{}
	router := setupTestRouter(h, "test-api-key")
	// Request with wrong API key
	req, _ := http.NewRequest("POST", "/v1/search", bytes.NewBufferString(`{"q":"test"}`))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer wrong-key")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusUnauthorized {
		t.Errorf("Expected status 401, got %d", w.Code)
	}
}
// TestAuthMiddleware_ValidKey verifies a matching Bearer key passes auth;
// the /v1/document endpoint is used because it needs no backing services.
func TestAuthMiddleware_ValidKey(t *testing.T) {
	h := &Handler{}
	router := setupTestRouter(h, "test-api-key")
	// Request with correct API key (search will fail due to no search service, but auth should pass)
	req, _ := http.NewRequest("GET", "/v1/document?doc_id=test", nil)
	req.Header.Set("Authorization", "Bearer test-api-key")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	// Auth should pass, endpoint returns 501 (not implemented)
	if w.Code == http.StatusUnauthorized {
		t.Error("Expected auth to pass, got 401")
	}
}
// TestAuthMiddleware_HealthNoAuth is intentionally skipped: hitting the
// health route would call h.indexClient.Health on a nil client.
func TestAuthMiddleware_HealthNoAuth(t *testing.T) {
	// Health endpoint requires indexClient for health check
	// Skipping because route calls h.indexClient.Health() which panics with nil
	t.Skip("Skipping: requires mock indexer client for full test")
}
// TestGetDocument_MissingDocID verifies that omitting the doc_id query
// parameter yields 400.
func TestGetDocument_MissingDocID(t *testing.T) {
	h := &Handler{}
	router := setupTestRouter(h, "test-key")
	req, _ := http.NewRequest("GET", "/v1/document", nil)
	req.Header.Set("Authorization", "Bearer test-key")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusBadRequest {
		t.Errorf("Expected status 400, got %d", w.Code)
	}
}
// Admin Handler Tests

// TestSeedStore_InitAndLoad verifies first-time initialization creates
// seeds.json in the data directory and populates default seeds.
func TestSeedStore_InitAndLoad(t *testing.T) {
	dir := t.TempDir()
	// First initialization should create default seeds
	err := InitSeedStore(dir)
	if err != nil {
		t.Fatalf("InitSeedStore failed: %v", err)
	}
	// Check that seeds file was created
	seedsFile := filepath.Join(dir, "seeds.json")
	if _, err := os.Stat(seedsFile); os.IsNotExist(err) {
		t.Error("seeds.json was not created")
	}
	// Check that default seeds were loaded
	seeds := seedStore.GetAllSeeds()
	if len(seeds) == 0 {
		t.Error("Expected default seeds to be loaded")
	}
}
// TestSeedStore_CreateSeed verifies CreateSeed assigns an ID and
// CreatedAt and preserves the submitted fields.
func TestSeedStore_CreateSeed(t *testing.T) {
	setupTestSeedStore(t)
	newSeed := SeedURL{
		URL:         "https://test.example.com",
		Name:        "Test Seed",
		Category:    "test",
		Description: "A test seed",
		TrustBoost:  0.5,
		Enabled:     true,
	}
	created, err := seedStore.CreateSeed(newSeed)
	if err != nil {
		t.Fatalf("CreateSeed failed: %v", err)
	}
	if created.ID == "" {
		t.Error("Expected generated ID")
	}
	if created.URL != newSeed.URL {
		t.Errorf("Expected URL %q, got %q", newSeed.URL, created.URL)
	}
	if created.CreatedAt.IsZero() {
		t.Error("Expected CreatedAt to be set")
	}
}
// TestSeedStore_GetSeed verifies a freshly created seed can be retrieved
// by its generated ID.
func TestSeedStore_GetSeed(t *testing.T) {
	setupTestSeedStore(t)
	// Create a seed first
	newSeed := SeedURL{
		URL:      "https://get-test.example.com",
		Name:     "Get Test",
		Category: "test",
	}
	created, _ := seedStore.CreateSeed(newSeed)
	// Get the seed
	retrieved, found := seedStore.GetSeed(created.ID)
	if !found {
		t.Fatal("Seed not found")
	}
	if retrieved.URL != newSeed.URL {
		t.Errorf("Expected URL %q, got %q", newSeed.URL, retrieved.URL)
	}
}
// TestSeedStore_GetSeed_NotFound verifies lookup of an unknown ID reports
// found == false.
func TestSeedStore_GetSeed_NotFound(t *testing.T) {
	setupTestSeedStore(t)
	_, found := seedStore.GetSeed("nonexistent-id")
	if found {
		t.Error("Expected seed not to be found")
	}
}
// TestSeedStore_UpdateSeed verifies partial updates: supplied fields
// (Name, TrustBoost, Enabled) are applied while omitted fields such as
// URL keep their original values.
func TestSeedStore_UpdateSeed(t *testing.T) {
	setupTestSeedStore(t)
	// Create a seed first
	original := SeedURL{
		URL:      "https://update-test.example.com",
		Name:     "Original Name",
		Category: "test",
		Enabled:  true,
	}
	created, _ := seedStore.CreateSeed(original)
	// Update the seed
	updates := SeedURL{
		Name:       "Updated Name",
		TrustBoost: 0.75,
		Enabled:    false,
	}
	updated, found, err := seedStore.UpdateSeed(created.ID, updates)
	if err != nil {
		t.Fatalf("UpdateSeed failed: %v", err)
	}
	if !found {
		t.Fatal("Seed not found for update")
	}
	if updated.Name != "Updated Name" {
		t.Errorf("Expected name 'Updated Name', got %q", updated.Name)
	}
	if updated.TrustBoost != 0.75 {
		t.Errorf("Expected TrustBoost 0.75, got %f", updated.TrustBoost)
	}
	if updated.Enabled != false {
		t.Error("Expected Enabled to be false")
	}
	// URL should remain unchanged since we didn't provide it
	if updated.URL != original.URL {
		t.Errorf("URL should remain unchanged, expected %q, got %q", original.URL, updated.URL)
	}
}
// TestSeedStore_UpdateSeed_NotFound verifies updating an unknown ID
// reports found == false without returning an error.
func TestSeedStore_UpdateSeed_NotFound(t *testing.T) {
	setupTestSeedStore(t)
	updates := SeedURL{Name: "New Name"}
	_, found, err := seedStore.UpdateSeed("nonexistent-id", updates)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	if found {
		t.Error("Expected seed not to be found")
	}
}
// TestSeedStore_DeleteSeed verifies a created seed can be deleted and is
// no longer retrievable afterwards.
func TestSeedStore_DeleteSeed(t *testing.T) {
	setupTestSeedStore(t)
	// Create a seed first
	newSeed := SeedURL{
		URL:      "https://delete-test.example.com",
		Name:     "Delete Test",
		Category: "test",
	}
	created, _ := seedStore.CreateSeed(newSeed)
	// Delete the seed
	deleted := seedStore.DeleteSeed(created.ID)
	if !deleted {
		t.Error("Expected delete to succeed")
	}
	// Verify it's gone
	_, found := seedStore.GetSeed(created.ID)
	if found {
		t.Error("Seed should have been deleted")
	}
}
// TestSeedStore_DeleteSeed_NotFound verifies deleting an unknown ID
// returns false.
func TestSeedStore_DeleteSeed_NotFound(t *testing.T) {
	setupTestSeedStore(t)
	deleted := seedStore.DeleteSeed("nonexistent-id")
	if deleted {
		t.Error("Expected delete to return false for nonexistent seed")
	}
}
// TestSeedStore_Persistence verifies a created seed survives re-loading
// the store from the same directory (the global seedStore is nil'ed and
// re-initialized to force a read from disk).
func TestSeedStore_Persistence(t *testing.T) {
	dir := t.TempDir()
	// Create and populate seed store
	err := InitSeedStore(dir)
	if err != nil {
		t.Fatal(err)
	}
	newSeed := SeedURL{
		URL:      "https://persist-test.example.com",
		Name:     "Persistence Test",
		Category: "test",
	}
	created, err := seedStore.CreateSeed(newSeed)
	if err != nil {
		t.Fatal(err)
	}
	// Re-initialize from the same directory
	seedStore = nil
	err = InitSeedStore(dir)
	if err != nil {
		t.Fatal(err)
	}
	// Check if the seed persisted
	retrieved, found := seedStore.GetSeed(created.ID)
	if !found {
		t.Error("Seed should have persisted")
	}
	if retrieved.URL != newSeed.URL {
		t.Errorf("Persisted seed URL mismatch: expected %q, got %q", newSeed.URL, retrieved.URL)
	}
}
// TestAdminGetSeeds verifies that GET /v1/admin/seeds returns the default
// seed list as JSON for an authenticated caller.
func TestAdminGetSeeds(t *testing.T) {
	dir := setupTestSeedStore(t)
	h := &Handler{}
	router := gin.New()
	SetupRoutes(router, h, "test-key")
	// Initialize seed store for the test. Fix: the error was previously
	// ignored, which would surface as a confusing downstream failure.
	if err := InitSeedStore(dir); err != nil {
		t.Fatalf("InitSeedStore failed: %v", err)
	}
	req, _ := http.NewRequest("GET", "/v1/admin/seeds", nil)
	req.Header.Set("Authorization", "Bearer test-key")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d", w.Code)
	}
	var seeds []SeedURL
	if err := json.Unmarshal(w.Body.Bytes(), &seeds); err != nil {
		t.Fatalf("Failed to parse response: %v", err)
	}
	// Should have default seeds
	if len(seeds) == 0 {
		t.Error("Expected seeds to be returned")
	}
}
// TestAdminCreateSeed verifies that POST /v1/admin/seeds creates a seed,
// assigns it a generated ID, and echoes the stored entity with status 201.
func TestAdminCreateSeed(t *testing.T) {
	dir := setupTestSeedStore(t)
	h := &Handler{}
	router := gin.New()
	SetupRoutes(router, h, "test-key")
	// Fix: check the init error instead of discarding it.
	if err := InitSeedStore(dir); err != nil {
		t.Fatalf("InitSeedStore failed: %v", err)
	}
	newSeed := map[string]interface{}{
		"url":         "https://new-seed.example.com",
		"name":        "New Seed",
		"category":    "test",
		"description": "Test description",
		"trustBoost":  0.5,
		"enabled":     true,
	}
	body, _ := json.Marshal(newSeed)
	req, _ := http.NewRequest("POST", "/v1/admin/seeds", bytes.NewBuffer(body))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer test-key")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusCreated {
		t.Errorf("Expected status 201, got %d: %s", w.Code, w.Body.String())
	}
	var created SeedURL
	if err := json.Unmarshal(w.Body.Bytes(), &created); err != nil {
		t.Fatalf("Failed to parse response: %v", err)
	}
	if created.ID == "" {
		t.Error("Expected ID to be generated")
	}
	if created.URL != "https://new-seed.example.com" {
		t.Errorf("Expected URL to match, got %q", created.URL)
	}
}
// TestAdminCreateSeed_MissingURL verifies that creating a seed without a
// URL is rejected with 400 Bad Request.
func TestAdminCreateSeed_MissingURL(t *testing.T) {
	dir := setupTestSeedStore(t)
	h := &Handler{}
	router := gin.New()
	SetupRoutes(router, h, "test-key")
	// Fix: check the init error instead of discarding it.
	if err := InitSeedStore(dir); err != nil {
		t.Fatalf("InitSeedStore failed: %v", err)
	}
	newSeed := map[string]interface{}{
		"name":     "No URL Seed",
		"category": "test",
	}
	body, _ := json.Marshal(newSeed)
	req, _ := http.NewRequest("POST", "/v1/admin/seeds", bytes.NewBuffer(body))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer test-key")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusBadRequest {
		t.Errorf("Expected status 400 for missing URL, got %d", w.Code)
	}
}
// TestAdminUpdateSeed verifies that PUT /v1/admin/seeds/:id applies a
// partial update and returns the updated entity.
func TestAdminUpdateSeed(t *testing.T) {
	dir := setupTestSeedStore(t)
	h := &Handler{}
	router := gin.New()
	SetupRoutes(router, h, "test-key")
	// Fix: check the init error instead of discarding it.
	if err := InitSeedStore(dir); err != nil {
		t.Fatalf("InitSeedStore failed: %v", err)
	}
	// Create a seed first; fix: the create error was previously discarded.
	newSeed := SeedURL{
		URL:      "https://update-api-test.example.com",
		Name:     "API Update Test",
		Category: "test",
	}
	created, err := seedStore.CreateSeed(newSeed)
	if err != nil {
		t.Fatalf("CreateSeed failed: %v", err)
	}
	// Update via API
	updates := map[string]interface{}{
		"name":       "Updated via API",
		"trustBoost": 0.8,
	}
	body, _ := json.Marshal(updates)
	req, _ := http.NewRequest("PUT", "/v1/admin/seeds/"+created.ID, bytes.NewBuffer(body))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer test-key")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String())
	}
	var updated SeedURL
	if err := json.Unmarshal(w.Body.Bytes(), &updated); err != nil {
		t.Fatalf("Failed to parse response: %v", err)
	}
	if updated.Name != "Updated via API" {
		t.Errorf("Expected name 'Updated via API', got %q", updated.Name)
	}
}
// TestAdminDeleteSeed verifies that DELETE /v1/admin/seeds/:id removes the
// seed from the store.
func TestAdminDeleteSeed(t *testing.T) {
	dir := setupTestSeedStore(t)
	h := &Handler{}
	router := gin.New()
	SetupRoutes(router, h, "test-key")
	// Fix: check the init error instead of discarding it.
	if err := InitSeedStore(dir); err != nil {
		t.Fatalf("InitSeedStore failed: %v", err)
	}
	// Create a seed first; fix: the create error was previously discarded.
	newSeed := SeedURL{
		URL:      "https://delete-api-test.example.com",
		Name:     "API Delete Test",
		Category: "test",
	}
	created, err := seedStore.CreateSeed(newSeed)
	if err != nil {
		t.Fatalf("CreateSeed failed: %v", err)
	}
	// Delete via API
	req, _ := http.NewRequest("DELETE", "/v1/admin/seeds/"+created.ID, nil)
	req.Header.Set("Authorization", "Bearer test-key")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d", w.Code)
	}
	// Verify it's deleted
	if _, found := seedStore.GetSeed(created.ID); found {
		t.Error("Seed should have been deleted")
	}
}
// TestAdminDeleteSeed_NotFound verifies that deleting an unknown seed ID
// over the API returns 404.
func TestAdminDeleteSeed_NotFound(t *testing.T) {
	dir := setupTestSeedStore(t)
	h := &Handler{}
	router := gin.New()
	SetupRoutes(router, h, "test-key")
	// Fix: check the init error instead of discarding it.
	if err := InitSeedStore(dir); err != nil {
		t.Fatalf("InitSeedStore failed: %v", err)
	}
	req, _ := http.NewRequest("DELETE", "/v1/admin/seeds/nonexistent-id", nil)
	req.Header.Set("Authorization", "Bearer test-key")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusNotFound {
		t.Errorf("Expected status 404, got %d", w.Code)
	}
}
// TestAdminGetStats verifies that GET /v1/admin/stats returns a populated
// CrawlStats structure.
func TestAdminGetStats(t *testing.T) {
	dir := setupTestSeedStore(t)
	h := &Handler{}
	router := gin.New()
	SetupRoutes(router, h, "test-key")
	// Fix: check the init error instead of discarding it.
	if err := InitSeedStore(dir); err != nil {
		t.Fatalf("InitSeedStore failed: %v", err)
	}
	req, _ := http.NewRequest("GET", "/v1/admin/stats", nil)
	req.Header.Set("Authorization", "Bearer test-key")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d", w.Code)
	}
	var stats CrawlStats
	if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil {
		t.Fatalf("Failed to parse response: %v", err)
	}
	// Check that stats structure is populated
	if stats.CrawlStatus == "" {
		t.Error("Expected CrawlStatus to be set")
	}
	if stats.DocumentsPerCategory == nil {
		t.Error("Expected DocumentsPerCategory to be set")
	}
}
// TestAdminStartCrawl verifies that POST /v1/admin/crawl/start accepts the
// request (202) and reports the crawl as started when the crawler is idle.
func TestAdminStartCrawl(t *testing.T) {
	dir := setupTestSeedStore(t)
	h := &Handler{}
	router := gin.New()
	SetupRoutes(router, h, "test-key")
	// Fix: check the init error instead of discarding it.
	if err := InitSeedStore(dir); err != nil {
		t.Fatalf("InitSeedStore failed: %v", err)
	}
	// Reset crawl status so the start request is not rejected as a
	// duplicate (shared package-level state).
	crawlStatus = "idle"
	req, _ := http.NewRequest("POST", "/v1/admin/crawl/start", nil)
	req.Header.Set("Authorization", "Bearer test-key")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusAccepted {
		t.Errorf("Expected status 202, got %d: %s", w.Code, w.Body.String())
	}
	var response map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
		t.Fatalf("Failed to parse response: %v", err)
	}
	if response["status"] != "started" {
		t.Errorf("Expected status 'started', got %v", response["status"])
	}
}
// TestAdminStartCrawl_AlreadyRunning verifies that starting a crawl while
// one is already running is rejected with 409 Conflict.
func TestAdminStartCrawl_AlreadyRunning(t *testing.T) {
	dir := setupTestSeedStore(t)
	h := &Handler{}
	router := gin.New()
	SetupRoutes(router, h, "test-key")
	// Fix: check the init error instead of discarding it.
	if err := InitSeedStore(dir); err != nil {
		t.Fatalf("InitSeedStore failed: %v", err)
	}
	// Simulate a crawl in progress. Fix: restore the shared state via
	// defer so the reset runs even if the test aborts early.
	crawlStatus = "running"
	defer func() { crawlStatus = "idle" }()
	req, _ := http.NewRequest("POST", "/v1/admin/crawl/start", nil)
	req.Header.Set("Authorization", "Bearer test-key")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusConflict {
		t.Errorf("Expected status 409, got %d", w.Code)
	}
}
// TestConcurrentSeedAccess exercises the seed store under concurrent reads
// and writes. Running under `go test -race` detects data races; the done
// channel guarantees every goroutine finished before the test returns.
func TestConcurrentSeedAccess(t *testing.T) {
	setupTestSeedStore(t)
	// One slot per goroutine so senders never block.
	done := make(chan bool, 10)
	// Concurrent readers
	for i := 0; i < 5; i++ {
		go func() {
			seedStore.GetAllSeeds()
			done <- true
		}()
	}
	// Concurrent writers. Fix: report CreateSeed errors instead of
	// silently discarding them (t.Errorf is safe from goroutines,
	// unlike t.Fatal).
	for i := 0; i < 5; i++ {
		go func(n int) {
			seed := SeedURL{
				URL:      "https://concurrent-" + string(rune('A'+n)) + ".example.com",
				Name:     "Concurrent Test",
				Category: "test",
			}
			if _, err := seedStore.CreateSeed(seed); err != nil {
				t.Errorf("CreateSeed failed: %v", err)
			}
			done <- true
		}(i)
	}
	// Wait for all goroutines
	for i := 0; i < 10; i++ {
		<-done
	}
	// If we get here without deadlock or race, test passes
}

View File

@@ -0,0 +1,207 @@
package handlers
import (
"net/http"
"github.com/breakpilot/edu-search-service/internal/orchestrator"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// OrchestratorHandler handles orchestrator-related HTTP requests.
type OrchestratorHandler struct {
	orchestrator *orchestrator.Orchestrator // crawl orchestrator controlled by these endpoints
	repo         orchestrator.Repository    // queue persistence backing the orchestrator
}

// NewOrchestratorHandler creates a new orchestrator handler wired to the
// given orchestrator and repository.
func NewOrchestratorHandler(orch *orchestrator.Orchestrator, repo orchestrator.Repository) *OrchestratorHandler {
	return &OrchestratorHandler{
		orchestrator: orch,
		repo:         repo,
	}
}

// AddToQueueRequest represents a request to add a university to the crawl
// queue. Priority and InitiatedBy are optional; AddToQueue substitutes
// defaults (5 and "api") when they are zero-valued.
type AddToQueueRequest struct {
	UniversityID string `json:"university_id" binding:"required"`
	Priority     int    `json:"priority"`
	InitiatedBy  string `json:"initiated_by"`
}
// GetStatus reports the orchestrator's current state as JSON.
func (h *OrchestratorHandler) GetStatus(c *gin.Context) {
	st, err := h.orchestrator.Status(c.Request.Context())
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error":   "Failed to get status",
			"details": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, st)
}
// GetQueue lists every item currently in the crawl queue together with
// the item count.
func (h *OrchestratorHandler) GetQueue(c *gin.Context) {
	queue, err := h.orchestrator.GetQueue(c.Request.Context())
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error":   "Failed to get queue",
			"details": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{"queue": queue, "count": len(queue)})
}
// AddToQueue enqueues a university for crawling. Priority defaults to 5
// and initiated_by to "api" when the request omits them.
func (h *OrchestratorHandler) AddToQueue(c *gin.Context) {
	var req AddToQueueRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
		return
	}

	uid, err := uuid.Parse(req.UniversityID)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid university_id format"})
		return
	}

	// Substitute defaults for the optional fields.
	prio, by := req.Priority, req.InitiatedBy
	if prio == 0 {
		prio = 5
	}
	if by == "" {
		by = "api"
	}

	item, err := h.orchestrator.AddUniversity(c.Request.Context(), uid, prio, by)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to add to queue", "details": err.Error()})
		return
	}
	c.JSON(http.StatusCreated, item)
}
// RemoveFromQueue deletes a university from the crawl queue by its ID
// path parameter.
func (h *OrchestratorHandler) RemoveFromQueue(c *gin.Context) {
	raw := c.Param("id")
	if raw == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "University ID required"})
		return
	}
	uid, err := uuid.Parse(raw)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid university_id format"})
		return
	}
	if err := h.orchestrator.RemoveUniversity(c.Request.Context(), uid); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error":   "Failed to remove from queue",
			"details": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{"deleted": true, "university_id": raw})
}
// Start launches the orchestrator; an already-running orchestrator yields
// 409 Conflict.
func (h *OrchestratorHandler) Start(c *gin.Context) {
	err := h.orchestrator.Start()
	if err != nil {
		c.JSON(http.StatusConflict, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"status":  "started",
		"message": "Orchestrator started successfully",
	})
}
// Stop halts the orchestrator; a stopped orchestrator yields 409 Conflict.
func (h *OrchestratorHandler) Stop(c *gin.Context) {
	err := h.orchestrator.Stop()
	if err != nil {
		c.JSON(http.StatusConflict, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"status":  "stopped",
		"message": "Orchestrator stopped successfully",
	})
}
// PauseUniversity suspends crawling for the university given by the ID
// path parameter.
func (h *OrchestratorHandler) PauseUniversity(c *gin.Context) {
	raw := c.Param("id")
	if raw == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "University ID required"})
		return
	}
	uid, err := uuid.Parse(raw)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid university_id format"})
		return
	}
	if err := h.orchestrator.PauseUniversity(c.Request.Context(), uid); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error":   "Failed to pause crawl",
			"details": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{"status": "paused", "university_id": raw})
}
// ResumeUniversity restarts crawling for a previously paused university
// given by the ID path parameter.
func (h *OrchestratorHandler) ResumeUniversity(c *gin.Context) {
	raw := c.Param("id")
	if raw == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "University ID required"})
		return
	}
	uid, err := uuid.Parse(raw)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid university_id format"})
		return
	}
	if err := h.orchestrator.ResumeUniversity(c.Request.Context(), uid); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error":   "Failed to resume crawl",
			"details": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{"status": "resumed", "university_id": raw})
}
// SetupOrchestratorRoutes configures orchestrator API routes under the
// given router group. All endpoints are mounted beneath /crawl; any
// middleware (e.g. auth) must already be attached to r by the caller.
func SetupOrchestratorRoutes(r *gin.RouterGroup, h *OrchestratorHandler) {
	crawl := r.Group("/crawl")
	{
		// Orchestrator control: global start/stop and status reporting.
		crawl.GET("/status", h.GetStatus)
		crawl.POST("/start", h.Start)
		crawl.POST("/stop", h.Stop)
		// Queue management: list, enqueue, and remove universities.
		crawl.GET("/queue", h.GetQueue)
		crawl.POST("/queue", h.AddToQueue)
		crawl.DELETE("/queue/:id", h.RemoveFromQueue)
		// Individual university control: pause/resume a single entry.
		crawl.POST("/queue/:id/pause", h.PauseUniversity)
		crawl.POST("/queue/:id/resume", h.ResumeUniversity)
	}
}

View File

@@ -0,0 +1,659 @@
package handlers
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/breakpilot/edu-search-service/internal/orchestrator"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// init puts gin into test mode so routers created in these tests do not
// emit debug logging.
func init() {
	gin.SetMode(gin.TestMode)
}
// MockRepository implements orchestrator.Repository for testing. It keeps
// queue items in an in-memory slice; the failOn* flags force errors to
// exercise failure paths. It is not safe for concurrent use.
type MockRepository struct {
	items        []orchestrator.CrawlQueueItem // in-memory stand-in for the queue table
	failOnAdd    bool                          // when true, AddToQueue returns an error
	failOnUpdate bool                          // when true, UpdateQueueItem returns an error
}

// NewMockRepository returns an empty mock queue repository.
func NewMockRepository() *MockRepository {
	return &MockRepository{
		items: make([]orchestrator.CrawlQueueItem, 0),
	}
}

// GetQueueItems returns all queued items in insertion order.
func (m *MockRepository) GetQueueItems(ctx context.Context) ([]orchestrator.CrawlQueueItem, error) {
	return m.items, nil
}

// GetNextInQueue returns the first item that is still actionable, i.e. not
// completed, failed, or paused. Returns (nil, nil) when nothing is pending.
func (m *MockRepository) GetNextInQueue(ctx context.Context) (*orchestrator.CrawlQueueItem, error) {
	for i := range m.items {
		if m.items[i].CurrentPhase != orchestrator.PhaseCompleted &&
			m.items[i].CurrentPhase != orchestrator.PhaseFailed &&
			m.items[i].CurrentPhase != orchestrator.PhasePaused {
			return &m.items[i], nil
		}
	}
	return nil, nil
}

// AddToQueue appends a new pending item at the end of the queue; it fails
// with context.DeadlineExceeded when failOnAdd is set.
func (m *MockRepository) AddToQueue(ctx context.Context, universityID uuid.UUID, priority int, initiatedBy string) (*orchestrator.CrawlQueueItem, error) {
	if m.failOnAdd {
		return nil, context.DeadlineExceeded
	}
	position := len(m.items) + 1
	item := orchestrator.CrawlQueueItem{
		ID:            uuid.New(),
		UniversityID:  universityID,
		QueuePosition: &position,
		Priority:      priority,
		CurrentPhase:  orchestrator.PhasePending,
		CreatedAt:     time.Now(),
		UpdatedAt:     time.Now(),
	}
	m.items = append(m.items, item)
	return &item, nil
}

// RemoveFromQueue deletes the item for universityID; removing a missing
// item is a no-op (idempotent delete).
func (m *MockRepository) RemoveFromQueue(ctx context.Context, universityID uuid.UUID) error {
	for i, item := range m.items {
		if item.UniversityID == universityID {
			m.items = append(m.items[:i], m.items[i+1:]...)
			return nil
		}
	}
	return nil
}

// UpdateQueueItem replaces the stored item matching item.UniversityID; it
// fails with context.DeadlineExceeded when failOnUpdate is set. Unknown
// IDs are a no-op.
func (m *MockRepository) UpdateQueueItem(ctx context.Context, item *orchestrator.CrawlQueueItem) error {
	if m.failOnUpdate {
		return context.DeadlineExceeded
	}
	for i, existing := range m.items {
		if existing.UniversityID == item.UniversityID {
			m.items[i] = *item
			return nil
		}
	}
	return nil
}

// PauseQueueItem marks the matching item as paused; unknown IDs are a no-op.
func (m *MockRepository) PauseQueueItem(ctx context.Context, universityID uuid.UUID) error {
	for i, item := range m.items {
		if item.UniversityID == universityID {
			m.items[i].CurrentPhase = orchestrator.PhasePaused
			return nil
		}
	}
	return nil
}

// ResumeQueueItem moves a paused item back to pending; items in any other
// phase, and unknown IDs, are left untouched.
func (m *MockRepository) ResumeQueueItem(ctx context.Context, universityID uuid.UUID) error {
	for i, item := range m.items {
		if item.UniversityID == universityID && m.items[i].CurrentPhase == orchestrator.PhasePaused {
			m.items[i].CurrentPhase = orchestrator.PhasePending
			return nil
		}
	}
	return nil
}

// CompletePhase is a no-op stub satisfying the Repository interface.
func (m *MockRepository) CompletePhase(ctx context.Context, universityID uuid.UUID, phase orchestrator.CrawlPhase, count int) error {
	return nil
}

// FailPhase is a no-op stub satisfying the Repository interface.
func (m *MockRepository) FailPhase(ctx context.Context, universityID uuid.UUID, phase orchestrator.CrawlPhase, errMsg string) error {
	return nil
}

// GetCompletedTodayCount counts items completed after the current 24-hour
// boundary. NOTE(review): Truncate(24h) snaps to absolute-time (UTC-like)
// day boundaries, not local midnight — acceptable for a test mock.
func (m *MockRepository) GetCompletedTodayCount(ctx context.Context) (int, error) {
	count := 0
	today := time.Now().Truncate(24 * time.Hour)
	for _, item := range m.items {
		if item.CurrentPhase == orchestrator.PhaseCompleted &&
			item.CompletedAt != nil &&
			item.CompletedAt.After(today) {
			count++
		}
	}
	return count, nil
}

// GetTotalProcessedCount counts all items that reached PhaseCompleted.
func (m *MockRepository) GetTotalProcessedCount(ctx context.Context) (int, error) {
	count := 0
	for _, item := range m.items {
		if item.CurrentPhase == orchestrator.PhaseCompleted {
			count++
		}
	}
	return count, nil
}
// MockStaffCrawler implements orchestrator.StaffCrawlerInterface with
// fixed, deterministic results for each phase.
type MockStaffCrawler struct{}

// DiscoverSampleProfessor reports a fixed single discovery hit.
func (m *MockStaffCrawler) DiscoverSampleProfessor(ctx context.Context, universityID uuid.UUID) (*orchestrator.CrawlProgress, error) {
	return &orchestrator.CrawlProgress{
		Phase:      orchestrator.PhaseDiscovery,
		ItemsFound: 1,
	}, nil
}

// CrawlProfessors reports a fixed count of 10 professors found.
func (m *MockStaffCrawler) CrawlProfessors(ctx context.Context, universityID uuid.UUID) (*orchestrator.CrawlProgress, error) {
	return &orchestrator.CrawlProgress{
		Phase:      orchestrator.PhaseProfessors,
		ItemsFound: 10,
	}, nil
}

// CrawlAllStaff reports a fixed count of 50 staff members found.
func (m *MockStaffCrawler) CrawlAllStaff(ctx context.Context, universityID uuid.UUID) (*orchestrator.CrawlProgress, error) {
	return &orchestrator.CrawlProgress{
		Phase:      orchestrator.PhaseAllStaff,
		ItemsFound: 50,
	}, nil
}
// MockPubCrawler implements orchestrator.PublicationCrawlerInterface with
// a fixed, deterministic result.
type MockPubCrawler struct{}

// CrawlPublicationsForUniversity reports a fixed count of 100 publications.
func (m *MockPubCrawler) CrawlPublicationsForUniversity(ctx context.Context, universityID uuid.UUID) (*orchestrator.CrawlProgress, error) {
	return &orchestrator.CrawlProgress{
		Phase:      orchestrator.PhasePublications,
		ItemsFound: 100,
	}, nil
}
// setupOrchestratorTestRouter wires an orchestrator handler behind the
// /v1 group protected by AuthMiddleware, mirroring production routing.
func setupOrchestratorTestRouter(orch *orchestrator.Orchestrator, repo orchestrator.Repository, apiKey string) *gin.Engine {
	r := gin.New()
	group := r.Group("/v1")
	group.Use(AuthMiddleware(apiKey))
	SetupOrchestratorRoutes(group, NewOrchestratorHandler(orch, repo))
	return r
}
// TestOrchestratorGetStatus verifies that GET /v1/crawl/status returns a
// parseable status and that a fresh orchestrator reports not-running.
func TestOrchestratorGetStatus(t *testing.T) {
	repo := NewMockRepository()
	staffCrawler := &MockStaffCrawler{}
	pubCrawler := &MockPubCrawler{}
	orch := orchestrator.NewOrchestrator(repo, staffCrawler, pubCrawler)
	router := setupOrchestratorTestRouter(orch, repo, "test-key")
	req, _ := http.NewRequest("GET", "/v1/crawl/status", nil)
	req.Header.Set("Authorization", "Bearer test-key")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String())
	}
	var status orchestrator.OrchestratorStatus
	if err := json.Unmarshal(w.Body.Bytes(), &status); err != nil {
		t.Fatalf("Failed to parse response: %v", err)
	}
	// Fix: idiomatic boolean check instead of `!= false`.
	if status.IsRunning {
		t.Error("Expected orchestrator to not be running initially")
	}
}
// TestOrchestratorGetQueue verifies GET /v1/crawl/queue reports an empty
// queue with a zero count on a fresh repository.
func TestOrchestratorGetQueue(t *testing.T) {
	repo := NewMockRepository()
	orch := orchestrator.NewOrchestrator(repo, &MockStaffCrawler{}, &MockPubCrawler{})
	router := setupOrchestratorTestRouter(orch, repo, "test-key")

	req, _ := http.NewRequest("GET", "/v1/crawl/queue", nil)
	req.Header.Set("Authorization", "Bearer test-key")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d: %s", rec.Code, rec.Body.String())
	}

	var resp struct {
		Queue []orchestrator.CrawlQueueItem `json:"queue"`
		Count int                           `json:"count"`
	}
	if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
		t.Fatalf("Failed to parse response: %v", err)
	}
	if resp.Count != 0 {
		t.Errorf("Expected empty queue, got %d items", resp.Count)
	}
}
// TestOrchestratorAddToQueue verifies POST /v1/crawl/queue creates a queue
// item carrying the requested university ID and priority.
func TestOrchestratorAddToQueue(t *testing.T) {
	repo := NewMockRepository()
	orch := orchestrator.NewOrchestrator(repo, &MockStaffCrawler{}, &MockPubCrawler{})
	router := setupOrchestratorTestRouter(orch, repo, "test-key")

	universityID := uuid.New()
	payload, _ := json.Marshal(AddToQueueRequest{
		UniversityID: universityID.String(),
		Priority:     7,
		InitiatedBy:  "test-user",
	})
	req, _ := http.NewRequest("POST", "/v1/crawl/queue", bytes.NewBuffer(payload))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer test-key")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusCreated {
		t.Errorf("Expected status 201, got %d: %s", rec.Code, rec.Body.String())
	}

	var item orchestrator.CrawlQueueItem
	if err := json.Unmarshal(rec.Body.Bytes(), &item); err != nil {
		t.Fatalf("Failed to parse response: %v", err)
	}
	if item.UniversityID != universityID {
		t.Errorf("Expected universityID %s, got %s", universityID, item.UniversityID)
	}
	if item.Priority != 7 {
		t.Errorf("Expected priority 7, got %d", item.Priority)
	}
}
// TestOrchestratorAddToQueue_InvalidUUID verifies that a malformed
// university_id is rejected with 400 Bad Request.
func TestOrchestratorAddToQueue_InvalidUUID(t *testing.T) {
	repo := NewMockRepository()
	orch := orchestrator.NewOrchestrator(repo, &MockStaffCrawler{}, &MockPubCrawler{})
	router := setupOrchestratorTestRouter(orch, repo, "test-key")

	payload, _ := json.Marshal(map[string]interface{}{
		"university_id": "not-a-valid-uuid",
		"priority":      5,
	})
	req, _ := http.NewRequest("POST", "/v1/crawl/queue", bytes.NewBuffer(payload))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer test-key")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusBadRequest {
		t.Errorf("Expected status 400, got %d: %s", rec.Code, rec.Body.String())
	}
}
// TestOrchestratorAddToQueue_MissingUniversityID verifies that omitting
// the required university_id field is rejected with 400 Bad Request
// (enforced by the binding:"required" tag).
func TestOrchestratorAddToQueue_MissingUniversityID(t *testing.T) {
	repo := NewMockRepository()
	orch := orchestrator.NewOrchestrator(repo, &MockStaffCrawler{}, &MockPubCrawler{})
	router := setupOrchestratorTestRouter(orch, repo, "test-key")

	payload, _ := json.Marshal(map[string]interface{}{"priority": 5})
	req, _ := http.NewRequest("POST", "/v1/crawl/queue", bytes.NewBuffer(payload))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer test-key")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusBadRequest {
		t.Errorf("Expected status 400, got %d: %s", rec.Code, rec.Body.String())
	}
}
// TestOrchestratorRemoveFromQueue verifies DELETE /v1/crawl/queue/:id
// removes the queued university from the repository.
func TestOrchestratorRemoveFromQueue(t *testing.T) {
	repo := NewMockRepository()
	staffCrawler := &MockStaffCrawler{}
	pubCrawler := &MockPubCrawler{}
	orch := orchestrator.NewOrchestrator(repo, staffCrawler, pubCrawler)
	// Add an item first. Fix: the setup error was previously discarded;
	// failing fast here gives a clearer diagnosis than the later assert.
	universityID := uuid.New()
	if _, err := repo.AddToQueue(context.Background(), universityID, 5, "test"); err != nil {
		t.Fatalf("AddToQueue setup failed: %v", err)
	}
	router := setupOrchestratorTestRouter(orch, repo, "test-key")
	req, _ := http.NewRequest("DELETE", "/v1/crawl/queue/"+universityID.String(), nil)
	req.Header.Set("Authorization", "Bearer test-key")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String())
	}
	// Verify it was removed
	items, _ := repo.GetQueueItems(context.Background())
	if len(items) != 0 {
		t.Errorf("Expected queue to be empty, got %d items", len(items))
	}
}
// TestOrchestratorRemoveFromQueue_InvalidUUID verifies that a malformed ID
// path parameter is rejected with 400 Bad Request.
func TestOrchestratorRemoveFromQueue_InvalidUUID(t *testing.T) {
	repo := NewMockRepository()
	orch := orchestrator.NewOrchestrator(repo, &MockStaffCrawler{}, &MockPubCrawler{})
	router := setupOrchestratorTestRouter(orch, repo, "test-key")

	req, _ := http.NewRequest("DELETE", "/v1/crawl/queue/invalid-uuid", nil)
	req.Header.Set("Authorization", "Bearer test-key")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusBadRequest {
		t.Errorf("Expected status 400, got %d: %s", rec.Code, rec.Body.String())
	}
}
// TestOrchestratorStartStop walks the full start/stop lifecycle: start
// succeeds once, a second start conflicts, stop succeeds once, and a
// second stop conflicts. The request sequencing is intentional; the steps
// are order-dependent and share one router.
func TestOrchestratorStartStop(t *testing.T) {
	repo := NewMockRepository()
	staffCrawler := &MockStaffCrawler{}
	pubCrawler := &MockPubCrawler{}
	orch := orchestrator.NewOrchestrator(repo, staffCrawler, pubCrawler)
	router := setupOrchestratorTestRouter(orch, repo, "test-key")
	// Start orchestrator
	req, _ := http.NewRequest("POST", "/v1/crawl/start", nil)
	req.Header.Set("Authorization", "Bearer test-key")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("Expected status 200 on start, got %d: %s", w.Code, w.Body.String())
	}
	// Try to start again (should fail with 409)
	req, _ = http.NewRequest("POST", "/v1/crawl/start", nil)
	req.Header.Set("Authorization", "Bearer test-key")
	w = httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusConflict {
		t.Errorf("Expected status 409 on duplicate start, got %d", w.Code)
	}
	// Stop orchestrator
	req, _ = http.NewRequest("POST", "/v1/crawl/stop", nil)
	req.Header.Set("Authorization", "Bearer test-key")
	w = httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("Expected status 200 on stop, got %d: %s", w.Code, w.Body.String())
	}
	// Try to stop again (should fail with 409)
	req, _ = http.NewRequest("POST", "/v1/crawl/stop", nil)
	req.Header.Set("Authorization", "Bearer test-key")
	w = httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusConflict {
		t.Errorf("Expected status 409 on duplicate stop, got %d", w.Code)
	}
}
// TestOrchestratorPauseResume verifies the pause and resume endpoints flip
// a queue item's phase to PhasePaused and back.
func TestOrchestratorPauseResume(t *testing.T) {
	repo := NewMockRepository()
	staffCrawler := &MockStaffCrawler{}
	pubCrawler := &MockPubCrawler{}
	orch := orchestrator.NewOrchestrator(repo, staffCrawler, pubCrawler)
	// Seed the queue with one university; fail fast on setup errors.
	universityID := uuid.New()
	if _, err := repo.AddToQueue(context.Background(), universityID, 5, "test"); err != nil {
		t.Fatalf("AddToQueue setup failed: %v", err)
	}
	router := setupOrchestratorTestRouter(orch, repo, "test-key")
	// Pause university
	req, _ := http.NewRequest("POST", "/v1/crawl/queue/"+universityID.String()+"/pause", nil)
	req.Header.Set("Authorization", "Bearer test-key")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("Expected status 200 on pause, got %d: %s", w.Code, w.Body.String())
	}
	// Verify it's paused. Fix: check the length with Fatalf before
	// indexing — the original evaluated items[0].CurrentPhase inside the
	// t.Errorf arguments, which panics when the slice is empty.
	items, _ := repo.GetQueueItems(context.Background())
	if len(items) != 1 {
		t.Fatalf("Expected 1 queue item, got %d", len(items))
	}
	if items[0].CurrentPhase != orchestrator.PhasePaused {
		t.Errorf("Expected item to be paused, got phase %s", items[0].CurrentPhase)
	}
	// Resume university
	req, _ = http.NewRequest("POST", "/v1/crawl/queue/"+universityID.String()+"/resume", nil)
	req.Header.Set("Authorization", "Bearer test-key")
	w = httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("Expected status 200 on resume, got %d: %s", w.Code, w.Body.String())
	}
	// Verify it's resumed (same guard against out-of-range indexing).
	items, _ = repo.GetQueueItems(context.Background())
	if len(items) != 1 {
		t.Fatalf("Expected 1 queue item, got %d", len(items))
	}
	if items[0].CurrentPhase == orchestrator.PhasePaused {
		t.Errorf("Expected item to not be paused, got phase %s", items[0].CurrentPhase)
	}
}
// TestOrchestratorPause_InvalidUUID verifies that pausing with a malformed
// ID path parameter is rejected with 400 Bad Request.
func TestOrchestratorPause_InvalidUUID(t *testing.T) {
	repo := NewMockRepository()
	orch := orchestrator.NewOrchestrator(repo, &MockStaffCrawler{}, &MockPubCrawler{})
	router := setupOrchestratorTestRouter(orch, repo, "test-key")

	req, _ := http.NewRequest("POST", "/v1/crawl/queue/invalid-uuid/pause", nil)
	req.Header.Set("Authorization", "Bearer test-key")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusBadRequest {
		t.Errorf("Expected status 400, got %d: %s", rec.Code, rec.Body.String())
	}
}
// TestOrchestratorNoAuth verifies that a request without an Authorization
// header is rejected with 401 by the auth middleware.
func TestOrchestratorNoAuth(t *testing.T) {
	repo := NewMockRepository()
	orch := orchestrator.NewOrchestrator(repo, &MockStaffCrawler{}, &MockPubCrawler{})
	router := setupOrchestratorTestRouter(orch, repo, "test-key")

	// Deliberately no Authorization header.
	req, _ := http.NewRequest("GET", "/v1/crawl/status", nil)
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusUnauthorized {
		t.Errorf("Expected status 401, got %d", rec.Code)
	}
}
// TestOrchestratorDefaultPriority verifies that omitting priority in the
// request body yields the server-side default of 5.
func TestOrchestratorDefaultPriority(t *testing.T) {
	repo := NewMockRepository()
	orch := orchestrator.NewOrchestrator(repo, &MockStaffCrawler{}, &MockPubCrawler{})
	router := setupOrchestratorTestRouter(orch, repo, "test-key")

	// Only university_id is supplied; priority and initiated_by are
	// deliberately left at their zero values.
	payload, _ := json.Marshal(AddToQueueRequest{UniversityID: uuid.New().String()})
	req, _ := http.NewRequest("POST", "/v1/crawl/queue", bytes.NewBuffer(payload))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer test-key")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusCreated {
		t.Errorf("Expected status 201, got %d: %s", rec.Code, rec.Body.String())
	}

	var item orchestrator.CrawlQueueItem
	if err := json.Unmarshal(rec.Body.Bytes(), &item); err != nil {
		t.Fatalf("Failed to parse response: %v", err)
	}
	if item.Priority != 5 {
		t.Errorf("Expected default priority 5, got %d", item.Priority)
	}
}
// TestOrchestratorQueueWithNullableFields tests that queue items with NULL values
// for optional fields (UniversityShort, LastError) are handled correctly.
// This tests the COALESCE fix in repository.go that prevents NULL scan errors.
func TestOrchestratorQueueWithNullableFields(t *testing.T) {
	repo := NewMockRepository()
	staffCrawler := &MockStaffCrawler{}
	pubCrawler := &MockPubCrawler{}
	orch := orchestrator.NewOrchestrator(repo, staffCrawler, pubCrawler)
	// Add item with empty optional fields (simulates NULL from DB)
	universityID := uuid.New()
	item := orchestrator.CrawlQueueItem{
		ID:              uuid.New(),
		UniversityID:    universityID,
		UniversityName:  "Test Universität",
		UniversityShort: "", // Empty string (COALESCE converts NULL to '')
		CurrentPhase:    orchestrator.PhasePending,
		LastError:       "", // Empty string (COALESCE converts NULL to '')
		CreatedAt:       time.Now(),
		UpdatedAt:       time.Now(),
	}
	position := 1
	item.QueuePosition = &position
	repo.items = append(repo.items, item)
	router := setupOrchestratorTestRouter(orch, repo, "test-key")
	req, _ := http.NewRequest("GET", "/v1/crawl/queue", nil)
	req.Header.Set("Authorization", "Bearer test-key")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String())
	}
	var response struct {
		Queue []orchestrator.CrawlQueueItem `json:"queue"`
		Count int                           `json:"count"`
	}
	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
		t.Fatalf("Failed to parse response: %v", err)
	}
	// Fix: Fatalf instead of Errorf — the Queue[0] accesses below would
	// panic with an index-out-of-range if the queue came back empty.
	if response.Count != 1 {
		t.Fatalf("Expected 1 item in queue, got %d", response.Count)
	}
	// Verify empty strings are preserved (not NULL)
	if response.Queue[0].UniversityShort != "" {
		t.Errorf("Expected empty UniversityShort, got %q", response.Queue[0].UniversityShort)
	}
	if response.Queue[0].LastError != "" {
		t.Errorf("Expected empty LastError, got %q", response.Queue[0].LastError)
	}
}
// TestOrchestratorQueueWithLastError tests that queue items with an error message
// are correctly serialized and returned.
func TestOrchestratorQueueWithLastError(t *testing.T) {
	repo := NewMockRepository()
	staffCrawler := &MockStaffCrawler{}
	pubCrawler := &MockPubCrawler{}
	orch := orchestrator.NewOrchestrator(repo, staffCrawler, pubCrawler)
	// Add item with an error
	universityID := uuid.New()
	item := orchestrator.CrawlQueueItem{
		ID:              uuid.New(),
		UniversityID:    universityID,
		UniversityName:  "Test Universität mit Fehler",
		UniversityShort: "TUmF",
		CurrentPhase:    orchestrator.PhaseFailed,
		LastError:       "connection timeout after 30s",
		RetryCount:      3,
		MaxRetries:      3,
		CreatedAt:       time.Now(),
		UpdatedAt:       time.Now(),
	}
	position := 1
	item.QueuePosition = &position
	repo.items = append(repo.items, item)
	router := setupOrchestratorTestRouter(orch, repo, "test-key")
	req, _ := http.NewRequest("GET", "/v1/crawl/queue", nil)
	req.Header.Set("Authorization", "Bearer test-key")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String())
	}
	var response struct {
		Queue []orchestrator.CrawlQueueItem `json:"queue"`
		Count int                           `json:"count"`
	}
	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
		t.Fatalf("Failed to parse response: %v", err)
	}
	// Fix: Fatalf instead of Errorf — the Queue[0] accesses below would
	// panic with an index-out-of-range if the queue came back empty.
	if response.Count != 1 {
		t.Fatalf("Expected 1 item in queue, got %d", response.Count)
	}
	// Verify error message is preserved
	if response.Queue[0].LastError != "connection timeout after 30s" {
		t.Errorf("Expected LastError to be 'connection timeout after 30s', got %q", response.Queue[0].LastError)
	}
	if response.Queue[0].UniversityShort != "TUmF" {
		t.Errorf("Expected UniversityShort 'TUmF', got %q", response.Queue[0].UniversityShort)
	}
}

View File

@@ -0,0 +1,700 @@
package handlers
import (
"net/http"
"time"
"github.com/breakpilot/edu-search-service/internal/policy"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// PolicyHandler contains all policy-related HTTP handlers.
// It bundles the persistence layer (store) with the rule engine (enforcer)
// so each handler can both read/write data and record audit entries.
type PolicyHandler struct {
	store    *policy.Store    // database access for policies, sources, rules, logs
	enforcer *policy.Enforcer // rule evaluation, PII detection, audit logging
}
// policyHandler is the package-level singleton, set once by InitPolicyHandler
// during startup and read by SetupPolicyRoutes and GetPolicyHandler.
var policyHandler *PolicyHandler
// InitPolicyHandler initializes the policy handler with a database pool.
// It must run before SetupPolicyRoutes, which is a no-op otherwise.
func InitPolicyHandler(store *policy.Store) {
	policyHandler = &PolicyHandler{
		store:    store,
		enforcer: policy.NewEnforcer(store),
	}
}
// GetPolicyHandler returns the policy handler instance.
// It returns nil when InitPolicyHandler has not been called yet.
func GetPolicyHandler() *PolicyHandler {
	return policyHandler
}
// =============================================================================
// POLICIES
// =============================================================================
// ListPolicies returns all source policies.
// Query parameters are bound into policy.PolicyListFilter; the page size
// is clamped to 1-100 with a default of 50.
func (h *PolicyHandler) ListPolicies(c *gin.Context) {
	filter := policy.PolicyListFilter{}
	if err := c.ShouldBindQuery(&filter); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters", "details": err.Error()})
		return
	}
	// Clamp the page size to a sane default.
	if filter.Limit <= 0 || filter.Limit > 100 {
		filter.Limit = 50
	}
	policies, total, err := h.store.ListPolicies(c.Request.Context(), &filter)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list policies", "details": err.Error()})
		return
	}
	payload := gin.H{
		"policies": policies,
		"total":    total,
		"limit":    filter.Limit,
		"offset":   filter.Offset,
	}
	c.JSON(http.StatusOK, payload)
}
// GetPolicy returns a single policy by ID.
// Responds 400 for a malformed UUID and 404 when no such policy exists.
func (h *PolicyHandler) GetPolicy(c *gin.Context) {
	policyID, parseErr := uuid.Parse(c.Param("id"))
	if parseErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid policy ID"})
		return
	}
	found, err := h.store.GetPolicy(c.Request.Context(), policyID)
	switch {
	case err != nil:
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get policy", "details": err.Error()})
	case found == nil:
		// Nil with no error means the row does not exist.
		c.JSON(http.StatusNotFound, gin.H{"error": "Policy not found"})
	default:
		c.JSON(http.StatusOK, found)
	}
}
// CreatePolicy creates a new source policy.
// On success the creation is written to the audit trail and the new record
// is returned with 201.
func (h *PolicyHandler) CreatePolicy(c *gin.Context) {
	var body policy.CreateSourcePolicyRequest
	if err := c.ShouldBindJSON(&body); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
		return
	}
	created, err := h.store.CreatePolicy(c.Request.Context(), &body)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create policy", "details": err.Error()})
		return
	}
	// Record the creation in the audit trail before responding.
	h.enforcer.LogChange(c.Request.Context(), policy.AuditActionCreate, policy.AuditEntitySourcePolicy, &created.ID, nil, created, getUserEmail(c))
	c.JSON(http.StatusCreated, created)
}
// UpdatePolicy updates an existing policy.
// The current record is loaded first so the audit entry can capture both
// the old and the new value of the change.
func (h *PolicyHandler) UpdatePolicy(c *gin.Context) {
	id, err := uuid.Parse(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid policy ID"})
		return
	}
	// Get old value for audit
	oldPolicy, err := h.store.GetPolicy(c.Request.Context(), id)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get policy", "details": err.Error()})
		return
	}
	if oldPolicy == nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "Policy not found"})
		return
	}
	var req policy.UpdateSourcePolicyRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
		return
	}
	p, err := h.store.UpdatePolicy(c.Request.Context(), id, &req)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update policy", "details": err.Error()})
		return
	}
	// Log audit with both old and new values
	userEmail := getUserEmail(c)
	h.enforcer.LogChange(c.Request.Context(), policy.AuditActionUpdate, policy.AuditEntitySourcePolicy, &p.ID, oldPolicy, p, userEmail)
	c.JSON(http.StatusOK, p)
}
// =============================================================================
// SOURCES (WHITELIST)
// =============================================================================
// ListSources returns all allowed sources.
// Query parameters are bound into policy.SourceListFilter; the page size
// is clamped to 1-100 with a default of 50.
func (h *PolicyHandler) ListSources(c *gin.Context) {
	var filter policy.SourceListFilter
	if err := c.ShouldBindQuery(&filter); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters", "details": err.Error()})
		return
	}
	// Set defaults
	if filter.Limit <= 0 || filter.Limit > 100 {
		filter.Limit = 50
	}
	sources, total, err := h.store.ListSources(c.Request.Context(), &filter)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list sources", "details": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"sources": sources,
		"total":   total,
		"limit":   filter.Limit,
		"offset":  filter.Offset,
	})
}
// GetSource returns a single source by ID.
// Responds 400 for a malformed UUID and 404 when no such source exists.
func (h *PolicyHandler) GetSource(c *gin.Context) {
	id, err := uuid.Parse(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid source ID"})
		return
	}
	source, err := h.store.GetSource(c.Request.Context(), id)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get source", "details": err.Error()})
		return
	}
	// A nil result with no error means the row does not exist.
	if source == nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "Source not found"})
		return
	}
	c.JSON(http.StatusOK, source)
}
// CreateSource creates a new allowed source.
// On success the creation is written to the audit trail and the new record
// is returned with 201.
func (h *PolicyHandler) CreateSource(c *gin.Context) {
	var req policy.CreateAllowedSourceRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
		return
	}
	source, err := h.store.CreateSource(c.Request.Context(), &req)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create source", "details": err.Error()})
		return
	}
	// Log audit
	userEmail := getUserEmail(c)
	h.enforcer.LogChange(c.Request.Context(), policy.AuditActionCreate, policy.AuditEntityAllowedSource, &source.ID, nil, source, userEmail)
	c.JSON(http.StatusCreated, source)
}
// UpdateSource updates an existing source.
// The current record is loaded first so the audit entry can capture both
// the old and the new value of the change.
func (h *PolicyHandler) UpdateSource(c *gin.Context) {
	id, err := uuid.Parse(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid source ID"})
		return
	}
	// Get old value for audit
	oldSource, err := h.store.GetSource(c.Request.Context(), id)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get source", "details": err.Error()})
		return
	}
	if oldSource == nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "Source not found"})
		return
	}
	var req policy.UpdateAllowedSourceRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
		return
	}
	source, err := h.store.UpdateSource(c.Request.Context(), id, &req)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update source", "details": err.Error()})
		return
	}
	// Log audit with both old and new values
	userEmail := getUserEmail(c)
	h.enforcer.LogChange(c.Request.Context(), policy.AuditActionUpdate, policy.AuditEntityAllowedSource, &source.ID, oldSource, source, userEmail)
	c.JSON(http.StatusOK, source)
}
// DeleteSource deletes a source.
// The record is fetched before deletion so its last value can be written
// to the audit log.
func (h *PolicyHandler) DeleteSource(c *gin.Context) {
	sourceID, parseErr := uuid.Parse(c.Param("id"))
	if parseErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid source ID"})
		return
	}
	victim, err := h.store.GetSource(c.Request.Context(), sourceID)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get source", "details": err.Error()})
		return
	}
	if victim == nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "Source not found"})
		return
	}
	if err := h.store.DeleteSource(c.Request.Context(), sourceID); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete source", "details": err.Error()})
		return
	}
	// Record the deletion (old value only) in the audit trail.
	h.enforcer.LogChange(c.Request.Context(), policy.AuditActionDelete, policy.AuditEntityAllowedSource, &sourceID, victim, nil, getUserEmail(c))
	c.JSON(http.StatusOK, gin.H{"deleted": true, "id": sourceID})
}
// =============================================================================
// OPERATIONS MATRIX
// =============================================================================
// GetOperationsMatrix returns all sources with their operation permissions.
func (h *PolicyHandler) GetOperationsMatrix(c *gin.Context) {
	sources, err := h.store.GetOperationsMatrix(c.Request.Context())
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get operations matrix", "details": err.Error()})
		return
	}
	// The fixed set of operation columns clients render against.
	operations := []string{
		string(policy.OperationLookup),
		string(policy.OperationRAG),
		string(policy.OperationTraining),
		string(policy.OperationExport),
	}
	c.JSON(http.StatusOK, gin.H{
		"sources":    sources,
		"operations": operations,
	})
}
// UpdateOperationPermission updates a single operation permission.
//
// SECURITY: enabling a training permission is always rejected. Previously
// the store error from the verification query was discarded, so a failed
// query silently bypassed the training check; the lookup now fails closed.
func (h *PolicyHandler) UpdateOperationPermission(c *gin.Context) {
	id, err := uuid.Parse(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid operation permission ID"})
		return
	}
	var req policy.UpdateOperationPermissionRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
		return
	}
	// SECURITY: Prevent enabling training
	if req.IsAllowed != nil && *req.IsAllowed {
		// Check if this is a training operation by querying.
		// NOTE(review): the lookup passes the permission ID where the name
		// suggests a source ID — confirm GetOperationsBySourceID's contract.
		ops, opsErr := h.store.GetOperationsBySourceID(c.Request.Context(), id)
		if opsErr != nil {
			// Fail closed: without the operation list we cannot prove this
			// is not a training permission, so refuse the update.
			c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to verify operation type", "details": opsErr.Error()})
			return
		}
		for _, op := range ops {
			if op.ID == id && op.Operation == policy.OperationTraining {
				c.JSON(http.StatusForbidden, gin.H{
					"error":   "Training operations cannot be enabled",
					"message": "Training with external data is FORBIDDEN by policy",
				})
				return
			}
		}
	}
	op, err := h.store.UpdateOperationPermission(c.Request.Context(), id, &req)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update operation permission", "details": err.Error()})
		return
	}
	// Log audit
	userEmail := getUserEmail(c)
	h.enforcer.LogChange(c.Request.Context(), policy.AuditActionUpdate, policy.AuditEntityOperationPermission, &op.ID, nil, op, userEmail)
	c.JSON(http.StatusOK, op)
}
// =============================================================================
// PII RULES
// =============================================================================
// ListPIIRules returns all PII detection rules.
// The optional ?active_only=true parameter restricts the listing to rules
// currently in force.
func (h *PolicyHandler) ListPIIRules(c *gin.Context) {
	rules, err := h.store.ListPIIRules(c.Request.Context(), c.Query("active_only") == "true")
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list PII rules", "details": err.Error()})
		return
	}
	payload := gin.H{
		"rules": rules,
		"total": len(rules),
	}
	c.JSON(http.StatusOK, payload)
}
// GetPIIRule returns a single PII rule by ID.
// Responds 400 for a malformed UUID and 404 when the rule does not exist.
func (h *PolicyHandler) GetPIIRule(c *gin.Context) {
	id, err := uuid.Parse(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid PII rule ID"})
		return
	}
	rule, err := h.store.GetPIIRule(c.Request.Context(), id)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get PII rule", "details": err.Error()})
		return
	}
	// A nil result with no error means the row does not exist.
	if rule == nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "PII rule not found"})
		return
	}
	c.JSON(http.StatusOK, rule)
}
// CreatePIIRule creates a new PII detection rule.
// On success the creation is written to the audit trail and the new rule
// is returned with 201.
func (h *PolicyHandler) CreatePIIRule(c *gin.Context) {
	var req policy.CreatePIIRuleRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
		return
	}
	rule, err := h.store.CreatePIIRule(c.Request.Context(), &req)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create PII rule", "details": err.Error()})
		return
	}
	// Log audit
	userEmail := getUserEmail(c)
	h.enforcer.LogChange(c.Request.Context(), policy.AuditActionCreate, policy.AuditEntityPIIRule, &rule.ID, nil, rule, userEmail)
	c.JSON(http.StatusCreated, rule)
}
// UpdatePIIRule updates an existing PII rule.
// The current rule is loaded first so the audit entry can capture both
// the old and the new value of the change.
func (h *PolicyHandler) UpdatePIIRule(c *gin.Context) {
	id, err := uuid.Parse(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid PII rule ID"})
		return
	}
	// Get old value for audit
	oldRule, err := h.store.GetPIIRule(c.Request.Context(), id)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get PII rule", "details": err.Error()})
		return
	}
	if oldRule == nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "PII rule not found"})
		return
	}
	var req policy.UpdatePIIRuleRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
		return
	}
	rule, err := h.store.UpdatePIIRule(c.Request.Context(), id, &req)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update PII rule", "details": err.Error()})
		return
	}
	// Log audit with both old and new values
	userEmail := getUserEmail(c)
	h.enforcer.LogChange(c.Request.Context(), policy.AuditActionUpdate, policy.AuditEntityPIIRule, &rule.ID, oldRule, rule, userEmail)
	c.JSON(http.StatusOK, rule)
}
// DeletePIIRule deletes a PII rule.
// The rule is fetched before deletion so its last value can be written
// to the audit log.
func (h *PolicyHandler) DeletePIIRule(c *gin.Context) {
	id, err := uuid.Parse(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid PII rule ID"})
		return
	}
	// Get rule for audit before deletion
	rule, err := h.store.GetPIIRule(c.Request.Context(), id)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get PII rule", "details": err.Error()})
		return
	}
	if rule == nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "PII rule not found"})
		return
	}
	if err := h.store.DeletePIIRule(c.Request.Context(), id); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete PII rule", "details": err.Error()})
		return
	}
	// Log audit (old value only; the rule no longer exists)
	userEmail := getUserEmail(c)
	h.enforcer.LogChange(c.Request.Context(), policy.AuditActionDelete, policy.AuditEntityPIIRule, &id, rule, nil, userEmail)
	c.JSON(http.StatusOK, gin.H{"deleted": true, "id": id})
}
// TestPIIRules tests PII detection against sample text.
// The submitted text is run through the enforcer's live rule set and the
// detection result is returned unchanged.
func (h *PolicyHandler) TestPIIRules(c *gin.Context) {
	var body policy.PIITestRequest
	if err := c.ShouldBindJSON(&body); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
		return
	}
	result, err := h.enforcer.DetectPII(c.Request.Context(), body.Text)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to test PII detection", "details": err.Error()})
		return
	}
	c.JSON(http.StatusOK, result)
}
// =============================================================================
// AUDIT & COMPLIANCE
// =============================================================================
// ListAuditLogs returns audit log entries.
// Query parameters are bound into policy.AuditLogFilter; the page size is
// clamped to 1-500 with a default of 100.
func (h *PolicyHandler) ListAuditLogs(c *gin.Context) {
	var filter policy.AuditLogFilter
	if err := c.ShouldBindQuery(&filter); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters", "details": err.Error()})
		return
	}
	// Set defaults
	if filter.Limit <= 0 || filter.Limit > 500 {
		filter.Limit = 100
	}
	logs, total, err := h.store.ListAuditLogs(c.Request.Context(), &filter)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list audit logs", "details": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"logs":   logs,
		"total":  total,
		"limit":  filter.Limit,
		"offset": filter.Offset,
	})
}
// ListBlockedContent returns blocked content log entries.
// Query parameters are bound into policy.BlockedContentFilter; the page
// size is clamped to 1-500 with a default of 100.
func (h *PolicyHandler) ListBlockedContent(c *gin.Context) {
	var filter policy.BlockedContentFilter
	if err := c.ShouldBindQuery(&filter); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters", "details": err.Error()})
		return
	}
	// Set defaults
	if filter.Limit <= 0 || filter.Limit > 500 {
		filter.Limit = 100
	}
	logs, total, err := h.store.ListBlockedContent(c.Request.Context(), &filter)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list blocked content", "details": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"blocked": logs,
		"total":   total,
		"limit":   filter.Limit,
		"offset":  filter.Offset,
	})
}
// CheckCompliance performs a compliance check for a URL.
// The bound request is delegated to the policy enforcer, and its result
// is returned unchanged.
func (h *PolicyHandler) CheckCompliance(c *gin.Context) {
	var req policy.CheckComplianceRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
		return
	}
	response, err := h.enforcer.CheckCompliance(c.Request.Context(), &req)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check compliance", "details": err.Error()})
		return
	}
	c.JSON(http.StatusOK, response)
}
// GetPolicyStats returns aggregated statistics.
func (h *PolicyHandler) GetPolicyStats(c *gin.Context) {
	stats, statsErr := h.store.GetStats(c.Request.Context())
	if statsErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get stats", "details": statsErr.Error()})
		return
	}
	c.JSON(http.StatusOK, stats)
}
// GenerateComplianceReport generates an audit report.
//
// Optional query parameters "from" and "to" (format: 2006-01-02) bound the
// report period. Invalid dates are now rejected with 400 instead of being
// silently ignored, which previously produced an unfiltered full report.
// With ?format=download the response is served as a file attachment.
func (h *PolicyHandler) GenerateComplianceReport(c *gin.Context) {
	var auditFilter policy.AuditLogFilter
	var blockedFilter policy.BlockedContentFilter
	// Parse date filters; fail loudly on malformed input.
	if fromStr := c.Query("from"); fromStr != "" {
		from, err := time.Parse("2006-01-02", fromStr)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid 'from' date, expected YYYY-MM-DD"})
			return
		}
		auditFilter.FromDate = &from
		blockedFilter.FromDate = &from
	}
	if toStr := c.Query("to"); toStr != "" {
		to, err := time.Parse("2006-01-02", toStr)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid 'to' date, expected YYYY-MM-DD"})
			return
		}
		// Add 1 day to include the end date
		to = to.Add(24 * time.Hour)
		auditFilter.ToDate = &to
		blockedFilter.ToDate = &to
	}
	// No limit for report
	auditFilter.Limit = 10000
	blockedFilter.Limit = 10000
	auditor := policy.NewAuditor(h.store)
	report, err := auditor.GenerateAuditReport(c.Request.Context(), &auditFilter, &blockedFilter)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate report", "details": err.Error()})
		return
	}
	// Set filename for download
	if c.Query("format") == "download" {
		filename := "compliance-report-" + time.Now().Format("2006-01-02") + ".json"
		c.Header("Content-Disposition", "attachment; filename="+filename)
		c.Header("Content-Type", "application/json")
	}
	c.JSON(http.StatusOK, report)
}
// =============================================================================
// HELPERS
// =============================================================================
// getUserEmail extracts user email from context or headers.
// It returns nil when no email can be determined.
func getUserEmail(c *gin.Context) *string {
	// An auth proxy in front of the service sets this header.
	if email := c.GetHeader("X-User-Email"); email != "" {
		return &email
	}
	// Otherwise fall back to the value an auth middleware stored in the context.
	raw, exists := c.Get("user_email")
	if !exists {
		return nil
	}
	emailStr, ok := raw.(string)
	if !ok {
		return nil
	}
	return &emailStr
}
// =============================================================================
// ROUTE SETUP
// =============================================================================
// SetupPolicyRoutes configures all policy-related routes.
// It is a no-op when InitPolicyHandler has not been called.
func SetupPolicyRoutes(r *gin.RouterGroup) {
	if policyHandler == nil {
		return
	}
	h := policyHandler
	// Policies
	r.GET("/policies", h.ListPolicies)
	r.GET("/policies/:id", h.GetPolicy)
	r.POST("/policies", h.CreatePolicy)
	r.PUT("/policies/:id", h.UpdatePolicy)
	// Sources (Whitelist)
	r.GET("/sources", h.ListSources)
	r.GET("/sources/:id", h.GetSource)
	r.POST("/sources", h.CreateSource)
	r.PUT("/sources/:id", h.UpdateSource)
	r.DELETE("/sources/:id", h.DeleteSource)
	// Operations Matrix
	r.GET("/operations-matrix", h.GetOperationsMatrix)
	r.PUT("/operations/:id", h.UpdateOperationPermission)
	// PII Rules
	r.GET("/pii-rules", h.ListPIIRules)
	r.GET("/pii-rules/:id", h.GetPIIRule)
	r.POST("/pii-rules", h.CreatePIIRule)
	r.PUT("/pii-rules/:id", h.UpdatePIIRule)
	r.DELETE("/pii-rules/:id", h.DeletePIIRule)
	r.POST("/pii-rules/test", h.TestPIIRules)
	// Audit & Compliance
	r.GET("/policy-audit", h.ListAuditLogs)
	r.GET("/blocked-content", h.ListBlockedContent)
	r.POST("/check-compliance", h.CheckCompliance)
	r.GET("/policy-stats", h.GetPolicyStats)
	r.GET("/compliance-report", h.GenerateComplianceReport)
}

View File

@@ -0,0 +1,374 @@
package handlers
import (
	"context"
	"fmt"
	"net/http"
	"strconv"

	"github.com/breakpilot/edu-search-service/internal/database"
	"github.com/breakpilot/edu-search-service/internal/publications"
	"github.com/breakpilot/edu-search-service/internal/staff"
	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
)
// StaffHandlers handles staff-related API endpoints.
type StaffHandlers struct {
	repo       *database.Repository             // persistence layer for staff, publications, universities
	crawler    *staff.StaffCrawler              // crawls university staff pages
	pubCrawler *publications.PublicationCrawler // resolves DOIs and crawls publications
}
// NewStaffHandlers creates new staff handlers.
// email is the contact address handed to the publication crawler
// (used for the CrossRef polite pool, per the config's StaffCrawlerEmail).
func NewStaffHandlers(repo *database.Repository, email string) *StaffHandlers {
	return &StaffHandlers{
		repo:       repo,
		crawler:    staff.NewStaffCrawler(repo),
		pubCrawler: publications.NewPublicationCrawler(repo, email),
	}
}
// SearchStaff searches for university staff
// GET /api/v1/staff/search?q=...&university_id=...&state=...&position_type=...&is_professor=...
//
// All filters are optional. Malformed UUID filters are silently dropped
// rather than rejected (NOTE(review): confirm this lenient behavior is
// intended — returning 400 might be friendlier to API clients).
func (h *StaffHandlers) SearchStaff(c *gin.Context) {
	params := database.StaffSearchParams{
		Query:  c.Query("q"),
		Limit:  parseIntDefault(c.Query("limit"), 20),
		Offset: parseIntDefault(c.Query("offset"), 0),
	}
	// Optional filters
	if uniID := c.Query("university_id"); uniID != "" {
		id, err := uuid.Parse(uniID)
		if err == nil {
			params.UniversityID = &id
		}
	}
	if deptID := c.Query("department_id"); deptID != "" {
		id, err := uuid.Parse(deptID)
		if err == nil {
			params.DepartmentID = &id
		}
	}
	if state := c.Query("state"); state != "" {
		params.State = &state
	}
	if uniType := c.Query("uni_type"); uniType != "" {
		params.UniType = &uniType
	}
	if posType := c.Query("position_type"); posType != "" {
		params.PositionType = &posType
	}
	// is_professor treats "true" and "1" as true; any other value is false.
	if isProfStr := c.Query("is_professor"); isProfStr != "" {
		isProf := isProfStr == "true" || isProfStr == "1"
		params.IsProfessor = &isProf
	}
	result, err := h.repo.SearchStaff(c.Request.Context(), params)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, result)
}
// GetStaff gets a single staff member by ID
// GET /api/v1/staff/:id
//
// The result variable is named "member" (was "staff") so it no longer
// shadows the imported staff package.
func (h *StaffHandlers) GetStaff(c *gin.Context) {
	idStr := c.Param("id")
	id, err := uuid.Parse(idStr)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid staff ID"})
		return
	}
	member, err := h.repo.GetStaff(c.Request.Context(), id)
	if err != nil {
		// NOTE(review): every repository error is reported as 404 — confirm
		// GetStaff only fails when the row is missing.
		c.JSON(http.StatusNotFound, gin.H{"error": "Staff not found"})
		return
	}
	c.JSON(http.StatusOK, member)
}
// GetStaffPublications gets publications for a staff member
// GET /api/v1/staff/:id/publications
func (h *StaffHandlers) GetStaffPublications(c *gin.Context) {
	staffID, parseErr := uuid.Parse(c.Param("id"))
	if parseErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid staff ID"})
		return
	}
	pubs, err := h.repo.GetStaffPublications(c.Request.Context(), staffID)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	payload := gin.H{
		"publications": pubs,
		"total":        len(pubs),
		"staff_id":     staffID,
	}
	c.JSON(http.StatusOK, payload)
}
// SearchPublications searches for publications
// GET /api/v1/publications/search?q=...&year=...&pub_type=...
//
// year selects a single year; year_from/year_to bound a range (presumably
// inclusive — confirm in the repository). Non-numeric or non-positive year
// values are silently ignored.
func (h *StaffHandlers) SearchPublications(c *gin.Context) {
	params := database.PublicationSearchParams{
		Query:  c.Query("q"),
		Limit:  parseIntDefault(c.Query("limit"), 20),
		Offset: parseIntDefault(c.Query("offset"), 0),
	}
	if staffID := c.Query("staff_id"); staffID != "" {
		id, err := uuid.Parse(staffID)
		if err == nil {
			params.StaffID = &id
		}
	}
	if year := c.Query("year"); year != "" {
		y := parseIntDefault(year, 0)
		if y > 0 {
			params.Year = &y
		}
	}
	if yearFrom := c.Query("year_from"); yearFrom != "" {
		y := parseIntDefault(yearFrom, 0)
		if y > 0 {
			params.YearFrom = &y
		}
	}
	if yearTo := c.Query("year_to"); yearTo != "" {
		y := parseIntDefault(yearTo, 0)
		if y > 0 {
			params.YearTo = &y
		}
	}
	if pubType := c.Query("pub_type"); pubType != "" {
		params.PubType = &pubType
	}
	result, err := h.repo.SearchPublications(c.Request.Context(), params)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, result)
}
// GetStaffStats gets statistics about staff data
// GET /api/v1/staff/stats
func (h *StaffHandlers) GetStaffStats(c *gin.Context) {
	stats, statsErr := h.repo.GetStaffStats(c.Request.Context())
	if statsErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": statsErr.Error()})
		return
	}
	c.JSON(http.StatusOK, stats)
}
// ListUniversities lists all universities
// GET /api/v1/universities
func (h *StaffHandlers) ListUniversities(c *gin.Context) {
	unis, err := h.repo.ListUniversities(c.Request.Context())
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	payload := gin.H{
		"universities": unis,
		"total":        len(unis),
	}
	c.JSON(http.StatusOK, payload)
}
// StartStaffCrawl starts a staff crawl for a university
// POST /api/v1/admin/crawl/staff
//
// Responds 202 immediately; the crawl itself runs in a goroutine.
func (h *StaffHandlers) StartStaffCrawl(c *gin.Context) {
	var req struct {
		UniversityID string `json:"university_id"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
		return
	}
	uniID, err := uuid.Parse(req.UniversityID)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid university ID"})
		return
	}
	uni, err := h.repo.GetUniversity(c.Request.Context(), uniID)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "University not found"})
		return
	}
	// Start crawl in background. BUGFIX: the request context is canceled as
	// soon as this handler returns, which would abort the crawl immediately,
	// so the goroutine runs on a detached background context instead.
	go func() {
		result, err := h.crawler.CrawlUniversity(context.Background(), uni)
		if err != nil {
			// Log error
			return
		}
		_ = result
	}()
	c.JSON(http.StatusAccepted, gin.H{
		"status":        "started",
		"university_id": uniID,
		"message":       "Staff crawl started in background",
	})
}
// StartPublicationCrawl starts a publication crawl for a university
// POST /api/v1/admin/crawl/publications
//
// limit defaults to 50 when omitted or non-positive. Responds 202
// immediately; the crawl itself runs in a goroutine.
func (h *StaffHandlers) StartPublicationCrawl(c *gin.Context) {
	var req struct {
		UniversityID string `json:"university_id"`
		Limit        int    `json:"limit"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
		return
	}
	uniID, err := uuid.Parse(req.UniversityID)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid university ID"})
		return
	}
	limit := req.Limit
	if limit <= 0 {
		limit = 50
	}
	// Start crawl in background. BUGFIX: c.Request.Context() is canceled when
	// this handler returns, which would cancel the crawl right away — use a
	// detached background context for the goroutine instead.
	go func() {
		status, err := h.pubCrawler.CrawlForUniversity(context.Background(), uniID, limit)
		if err != nil {
			// Log error
			return
		}
		_ = status
	}()
	c.JSON(http.StatusAccepted, gin.H{
		"status":        "started",
		"university_id": uniID,
		"message":       "Publication crawl started in background",
	})
}
// ResolveDOI resolves a DOI and saves the publication
// POST /api/v1/publications/resolve-doi
//
// If staff_id is supplied, the publication is additionally linked to that
// staff member on a best-effort basis: an invalid staff_id and a failed
// link write are both silently ignored (NOTE(review): the error returned
// by LinkStaffPublication is discarded — confirm best-effort linking is
// intended and consider surfacing the failure to the client).
func (h *StaffHandlers) ResolveDOI(c *gin.Context) {
	var req struct {
		DOI     string `json:"doi"`
		StaffID string `json:"staff_id,omitempty"`
	}
	if err := c.ShouldBindJSON(&req); err != nil || req.DOI == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "DOI is required"})
		return
	}
	pub, err := h.pubCrawler.ResolveDOI(c.Request.Context(), req.DOI)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	// Link to staff if provided
	if req.StaffID != "" {
		staffID, err := uuid.Parse(req.StaffID)
		if err == nil {
			link := &database.StaffPublication{
				StaffID:       staffID,
				PublicationID: pub.ID,
			}
			h.repo.LinkStaffPublication(c.Request.Context(), link)
		}
	}
	c.JSON(http.StatusOK, pub)
}
// GetCrawlStatus gets crawl status for a university
// GET /api/v1/admin/crawl/status/:university_id
//
// When no status row exists yet, a synthetic "never" status is returned
// with 200 instead of a 404, so clients can treat "not yet crawled"
// uniformly.
func (h *StaffHandlers) GetCrawlStatus(c *gin.Context) {
	idStr := c.Param("university_id")
	id, err := uuid.Parse(idStr)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid university ID"})
		return
	}
	status, err := h.repo.GetCrawlStatus(c.Request.Context(), id)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	if status == nil {
		c.JSON(http.StatusOK, gin.H{
			"university_id":      id,
			"staff_crawl_status": "never",
			"pub_crawl_status":   "never",
		})
		return
	}
	c.JSON(http.StatusOK, status)
}
// parseIntDefault parses s as a base-10 integer, returning def when s is
// empty or not a valid integer. strconv.Atoi replaces the previous
// fmt.Sscanf so that trailing garbage (e.g. "12abc") is rejected rather
// than silently truncated to 12.
func parseIntDefault(s string, def int) int {
	if s == "" {
		return def
	}
	n, err := strconv.Atoi(s)
	if err != nil {
		return def
	}
	return n
}
// RegisterRoutes registers staff-related routes on the given router group.
// NOTE(review): the /admin endpoints are registered on the same group as the
// public ones — confirm the caller applies auth middleware to this group.
func (h *StaffHandlers) RegisterRoutes(r *gin.RouterGroup) {
	// Public endpoints
	r.GET("/staff/search", h.SearchStaff)
	r.GET("/staff/stats", h.GetStaffStats)
	r.GET("/staff/:id", h.GetStaff)
	r.GET("/staff/:id/publications", h.GetStaffPublications)
	r.GET("/publications/search", h.SearchPublications)
	r.POST("/publications/resolve-doi", h.ResolveDOI)
	r.GET("/universities", h.ListUniversities)
	// Admin endpoints
	r.POST("/admin/crawl/staff", h.StartStaffCrawl)
	r.POST("/admin/crawl/publications", h.StartPublicationCrawl)
	r.GET("/admin/crawl/status/:university_id", h.GetCrawlStatus)
}

View File

@@ -0,0 +1,127 @@
package config
import (
"os"
"strconv"
)
// Config holds all runtime configuration for the edu-search-service.
// It is populated from environment variables by Load; see Load for the
// variable names and defaults.
type Config struct {
	// Server
	Port string
	// OpenSearch
	OpenSearchURL      string
	OpenSearchUsername string
	OpenSearchPassword string
	IndexName          string
	// Crawler
	UserAgent       string
	RateLimitPerSec float64
	MaxDepth        int
	MaxPagesPerRun  int
	// Paths
	SeedsDir string
	RulesDir string
	// API
	APIKey string
	// Backend Integration
	BackendURL   string // URL to Python Backend for Seeds API
	SeedsFromAPI bool   // If true, fetch seeds from API instead of files
	// Embedding/Semantic Search
	EmbeddingProvider     string // "openai", "ollama", or "none"
	OpenAIAPIKey          string // API Key for OpenAI embeddings
	EmbeddingModel        string // Model name (e.g., "text-embedding-3-small")
	EmbeddingDimension    int    // Vector dimension (1536 for OpenAI small)
	OllamaURL             string // Ollama base URL for local embeddings
	SemanticSearchEnabled bool   // Enable semantic search features
	// Scheduler
	SchedulerEnabled  bool   // Enable automatic crawl scheduling
	SchedulerInterval string // Crawl interval (e.g., "24h", "168h" for weekly)
	// PostgreSQL (for Staff/Publications database)
	DBHost     string
	DBPort     string
	DBUser     string
	DBPassword string
	DBName     string
	DBSSLMode  string
	// Staff Crawler
	StaffCrawlerEmail string // Contact email for CrossRef polite pool
}
// Load builds a Config from environment variables, applying the default
// shown here for every variable that is unset or empty.
func Load() *Config {
	return &Config{
		Port:               getEnv("PORT", "8084"),
		OpenSearchURL:      getEnv("OPENSEARCH_URL", "http://opensearch:9200"),
		OpenSearchUsername: getEnv("OPENSEARCH_USERNAME", "admin"),
		OpenSearchPassword: getEnv("OPENSEARCH_PASSWORD", "admin"),
		IndexName:          getEnv("INDEX_NAME", "bp_documents_v1"),
		UserAgent:          getEnv("USER_AGENT", "BreakpilotEduCrawler/1.0 (+contact: security@breakpilot.com)"),
		RateLimitPerSec:    getEnvFloat("RATE_LIMIT_PER_SEC", 0.2),
		MaxDepth:           getEnvInt("MAX_DEPTH", 4),
		MaxPagesPerRun:     getEnvInt("MAX_PAGES_PER_RUN", 500),
		SeedsDir:           getEnv("SEEDS_DIR", "./seeds"),
		RulesDir:           getEnv("RULES_DIR", "./rules"),
		APIKey:             getEnv("EDU_SEARCH_API_KEY", ""),
		BackendURL:         getEnv("BACKEND_URL", "http://backend:8000"),
		SeedsFromAPI:       getEnvBool("SEEDS_FROM_API", true),
		// Embedding/Semantic Search
		EmbeddingProvider:     getEnv("EMBEDDING_PROVIDER", "none"), // "openai", "ollama", or "none"
		OpenAIAPIKey:          getEnv("OPENAI_API_KEY", ""),
		EmbeddingModel:        getEnv("EMBEDDING_MODEL", "text-embedding-3-small"),
		EmbeddingDimension:    getEnvInt("EMBEDDING_DIMENSION", 1536),
		OllamaURL:             getEnv("OLLAMA_URL", "http://ollama:11434"),
		SemanticSearchEnabled: getEnvBool("SEMANTIC_SEARCH_ENABLED", false),
		// Scheduler
		SchedulerEnabled:  getEnvBool("SCHEDULER_ENABLED", false),
		SchedulerInterval: getEnv("SCHEDULER_INTERVAL", "24h"),
		// PostgreSQL
		DBHost:     getEnv("DB_HOST", "postgres"),
		DBPort:     getEnv("DB_PORT", "5432"),
		DBUser:     getEnv("DB_USER", "postgres"),
		DBPassword: getEnv("DB_PASSWORD", "postgres"),
		DBName:     getEnv("DB_NAME", "breakpilot"),
		DBSSLMode:  getEnv("DB_SSLMODE", "disable"),
		// Staff Crawler
		StaffCrawlerEmail: getEnv("STAFF_CRAWLER_EMAIL", "crawler@breakpilot.de"),
	}
}
func getEnvBool(key string, fallback bool) bool {
if value := os.Getenv(key); value != "" {
return value == "true" || value == "1" || value == "yes"
}
return fallback
}
func getEnv(key, fallback string) string {
if value := os.Getenv(key); value != "" {
return value
}
return fallback
}
func getEnvInt(key string, fallback int) int {
if value := os.Getenv(key); value != "" {
if i, err := strconv.Atoi(value); err == nil {
return i
}
}
return fallback
}
func getEnvFloat(key string, fallback float64) float64 {
if value := os.Getenv(key); value != "" {
if f, err := strconv.ParseFloat(value, 64); err == nil {
return f
}
}
return fallback
}

View File

@@ -0,0 +1,183 @@
package crawler
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
// SeedFromAPI represents a single seed URL as delivered by the Backend
// seeds-export endpoint. The JSON field names mirror the Backend schema.
type SeedFromAPI struct {
	URL      string  `json:"url"`
	Trust    float64 `json:"trust"`    // trust boost applied to documents from this seed
	Source   string  `json:"source"`   // GOV, EDU, UNI, etc.
	Scope    string  `json:"scope"`    // FEDERAL, STATE, etc.
	State    string  `json:"state"`    // BW, BY, etc. (optional)
	Depth    int     `json:"depth"`    // Crawl depth for this seed
	Category string  `json:"category"` // Category name
}
// SeedsExportResponse represents the API response from
// /v1/edu-search/seeds/export/for-crawler.
type SeedsExportResponse struct {
	Seeds      []SeedFromAPI `json:"seeds"`
	Total      int           `json:"total"`       // number of seeds in the export
	ExportedAt string        `json:"exported_at"` // server-side export timestamp
}
// APIClient handles communication with the Python Backend (seed export
// and crawl-status reporting).
type APIClient struct {
	baseURL    string       // Backend base URL, e.g. "http://backend:8000"
	httpClient *http.Client // shared client with a fixed overall timeout
}
// NewAPIClient constructs an APIClient that talks to the Backend at
// backendURL, using a 30-second overall request timeout.
func NewAPIClient(backendURL string) *APIClient {
	httpClient := &http.Client{Timeout: 30 * time.Second}
	return &APIClient{baseURL: backendURL, httpClient: httpClient}
}
// FetchSeeds retrieves the enabled seeds from the Backend API and decodes
// the export payload. A non-200 status or a malformed body yields an error.
func (c *APIClient) FetchSeeds(ctx context.Context) (*SeedsExportResponse, error) {
	endpoint := fmt.Sprintf("%s/v1/edu-search/seeds/export/for-crawler", c.baseURL)
	req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Accept", "application/json")
	req.Header.Set("User-Agent", "EduSearchCrawler/1.0")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch seeds: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Include the body (best effort) in the error for diagnosis.
		msg, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(msg))
	}

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	var export SeedsExportResponse
	if err := json.Unmarshal(raw, &export); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}
	return &export, nil
}
// CrawlStatusReport represents the outcome of crawling one seed, reported
// back to the Backend after a crawl run.
type CrawlStatusReport struct {
	SeedURL          string  `json:"seed_url"`
	Status           string  `json:"status"` // "success", "error", "partial"
	DocumentsCrawled int     `json:"documents_crawled"`
	ErrorMessage     string  `json:"error_message,omitempty"`
	CrawlDuration    float64 `json:"crawl_duration_seconds"` // wall-clock duration in seconds
}
// CrawlStatusResponse represents the Backend's reply to a single
// crawl-status report.
type CrawlStatusResponse struct {
	Success bool   `json:"success"`
	SeedURL string `json:"seed_url"`
	Message string `json:"message"`
}
// BulkCrawlStatusResponse represents the Backend's reply to a bulk
// crawl-status report: counts plus per-item error messages.
type BulkCrawlStatusResponse struct {
	Updated int      `json:"updated"`
	Failed  int      `json:"failed"`
	Errors  []string `json:"errors"`
}
// ReportStatus sends the crawl status for a single seed to the Backend.
// A non-200 response is returned as an error that includes the response
// body for diagnosis.
func (c *APIClient) ReportStatus(ctx context.Context, report *CrawlStatusReport) error {
	url := fmt.Sprintf("%s/v1/edu-search/seeds/crawl-status", c.baseURL)
	body, err := json.Marshal(report)
	if err != nil {
		return fmt.Errorf("failed to marshal report: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	req.Header.Set("User-Agent", "EduSearchCrawler/1.0")
	resp, err := c.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to report status: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		respBody, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(respBody))
	}
	// Drain the (unused) success body so the HTTP transport can reuse the
	// underlying connection for the next report. Errors here are harmless.
	io.Copy(io.Discard, resp.Body)
	return nil
}
// ReportStatusBulk sends crawl-status updates for multiple seeds in a
// single request and returns the per-item success/failure summary.
func (c *APIClient) ReportStatusBulk(ctx context.Context, reports []*CrawlStatusReport) (*BulkCrawlStatusResponse, error) {
	endpoint := fmt.Sprintf("%s/v1/edu-search/seeds/crawl-status/bulk", c.baseURL)
	payload := struct {
		Updates []*CrawlStatusReport `json:"updates"`
	}{Updates: reports}
	encoded, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal reports: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(encoded))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	req.Header.Set("User-Agent", "EduSearchCrawler/1.0")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to report status: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(respBody))
	}
	var summary BulkCrawlStatusResponse
	if err := json.Unmarshal(respBody, &summary); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}
	return &summary, nil
}

View File

@@ -0,0 +1,428 @@
package crawler
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
)
// TestNewAPIClient verifies the constructor wires base URL and HTTP client.
func TestNewAPIClient(t *testing.T) {
	c := NewAPIClient("http://backend:8000")
	if c == nil {
		t.Fatal("Expected non-nil client")
	}
	if got := c.baseURL; got != "http://backend:8000" {
		t.Errorf("Expected baseURL 'http://backend:8000', got '%s'", got)
	}
	if c.httpClient == nil {
		t.Fatal("Expected non-nil httpClient")
	}
}
// TestFetchSeeds_Success exercises the happy path: a mock Backend serves
// two seeds and the client must send the right path/headers and decode
// all fields correctly.
func TestFetchSeeds_Success(t *testing.T) {
	// Create mock server
	mockResponse := SeedsExportResponse{
		Seeds: []SeedFromAPI{
			{
				URL:      "https://www.kmk.org",
				Trust:    0.8,
				Source:   "GOV",
				Scope:    "FEDERAL",
				State:    "",
				Depth:    3,
				Category: "federal",
			},
			{
				URL:      "https://www.km-bw.de",
				Trust:    0.7,
				Source:   "GOV",
				Scope:    "STATE",
				State:    "BW",
				Depth:    2,
				Category: "states",
			},
		},
		Total:      2,
		ExportedAt: "2025-01-17T10:00:00Z",
	}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Verify request path
		if r.URL.Path != "/v1/edu-search/seeds/export/for-crawler" {
			t.Errorf("Expected path '/v1/edu-search/seeds/export/for-crawler', got '%s'", r.URL.Path)
		}
		// Verify headers
		if r.Header.Get("Accept") != "application/json" {
			t.Errorf("Expected Accept header 'application/json', got '%s'", r.Header.Get("Accept"))
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(mockResponse)
	}))
	defer server.Close()
	// Test
	client := NewAPIClient(server.URL)
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	result, err := client.FetchSeeds(ctx)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	if result.Total != 2 {
		t.Errorf("Expected 2 seeds, got %d", result.Total)
	}
	if len(result.Seeds) != 2 {
		t.Fatalf("Expected 2 seeds in array, got %d", len(result.Seeds))
	}
	// Verify first seed
	if result.Seeds[0].URL != "https://www.kmk.org" {
		t.Errorf("Expected URL 'https://www.kmk.org', got '%s'", result.Seeds[0].URL)
	}
	if result.Seeds[0].Trust != 0.8 {
		t.Errorf("Expected Trust 0.8, got %f", result.Seeds[0].Trust)
	}
	if result.Seeds[0].Source != "GOV" {
		t.Errorf("Expected Source 'GOV', got '%s'", result.Seeds[0].Source)
	}
	// Verify second seed with state
	if result.Seeds[1].State != "BW" {
		t.Errorf("Expected State 'BW', got '%s'", result.Seeds[1].State)
	}
}
// TestFetchSeeds_ServerError expects an error when the Backend returns 500.
func TestFetchSeeds_ServerError(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte("Internal server error"))
	}))
	defer server.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if _, err := NewAPIClient(server.URL).FetchSeeds(ctx); err == nil {
		t.Fatal("Expected error for server error response")
	}
}
// TestFetchSeeds_InvalidJSON expects a parse error when the Backend sends
// a non-JSON body with a 200 status.
func TestFetchSeeds_InvalidJSON(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.Write([]byte("not valid json"))
	}))
	defer server.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if _, err := NewAPIClient(server.URL).FetchSeeds(ctx); err == nil {
		t.Fatal("Expected error for invalid JSON response")
	}
}
// TestFetchSeeds_Timeout verifies that a context deadline aborts a slow
// request. The handler waits on the request context instead of sleeping a
// fixed 2s, so server.Close() no longer blocks for the full sleep after
// the client has already given up (the old version made this test ~2s).
func TestFetchSeeds_Timeout(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Simulate a slow response, but stop stalling as soon as the
		// client disconnects.
		select {
		case <-r.Context().Done():
			return
		case <-time.After(2 * time.Second):
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()
	client := NewAPIClient(server.URL)
	// Very short timeout
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()
	_, err := client.FetchSeeds(ctx)
	if err == nil {
		t.Fatal("Expected timeout error")
	}
}
// TestFetchSeeds_EmptyResponse verifies that an empty seed export decodes
// into a zero-length result without error.
func TestFetchSeeds_EmptyResponse(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(SeedsExportResponse{
			Seeds:      []SeedFromAPI{},
			Total:      0,
			ExportedAt: "2025-01-17T10:00:00Z",
		})
	}))
	defer server.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	result, err := NewAPIClient(server.URL).FetchSeeds(ctx)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	if result.Total != 0 {
		t.Errorf("Expected 0 seeds, got %d", result.Total)
	}
	if len(result.Seeds) != 0 {
		t.Errorf("Expected empty seeds array, got %d", len(result.Seeds))
	}
}
// Tests for Crawl Status Reporting
// TestReportStatus_Success posts a status report and verifies both the
// request shape and the payload the server received.
func TestReportStatus_Success(t *testing.T) {
	// The handler runs in the server's goroutine; hand the decoded payload
	// back through a buffered channel instead of a shared variable so the
	// assertions below don't race with the handler (the old shared
	// receivedReport variable was a data race under -race).
	reports := make(chan CrawlStatusReport, 1)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Verify request method and path
		if r.Method != "POST" {
			t.Errorf("Expected POST method, got %s", r.Method)
		}
		if r.URL.Path != "/v1/edu-search/seeds/crawl-status" {
			t.Errorf("Expected path '/v1/edu-search/seeds/crawl-status', got '%s'", r.URL.Path)
		}
		if r.Header.Get("Content-Type") != "application/json" {
			t.Errorf("Expected Content-Type 'application/json', got '%s'", r.Header.Get("Content-Type"))
		}
		// Parse body
		var rep CrawlStatusReport
		json.NewDecoder(r.Body).Decode(&rep)
		reports <- rep
		// Send response
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(CrawlStatusResponse{
			Success: true,
			SeedURL: rep.SeedURL,
			Message: "Status updated",
		})
	}))
	defer server.Close()
	client := NewAPIClient(server.URL)
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	report := &CrawlStatusReport{
		SeedURL:          "https://www.kmk.org",
		Status:           "success",
		DocumentsCrawled: 42,
		CrawlDuration:    15.5,
	}
	err := client.ReportStatus(ctx, report)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	// Verify the report was sent correctly
	received := <-reports
	if received.SeedURL != "https://www.kmk.org" {
		t.Errorf("Expected SeedURL 'https://www.kmk.org', got '%s'", received.SeedURL)
	}
	if received.Status != "success" {
		t.Errorf("Expected Status 'success', got '%s'", received.Status)
	}
	if received.DocumentsCrawled != 42 {
		t.Errorf("Expected DocumentsCrawled 42, got %d", received.DocumentsCrawled)
	}
}
func TestReportStatus_ServerError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("Internal server error"))
}))
defer server.Close()
client := NewAPIClient(server.URL)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
report := &CrawlStatusReport{
SeedURL: "https://www.kmk.org",
Status: "success",
}
err := client.ReportStatus(ctx, report)
if err == nil {
t.Fatal("Expected error for server error response")
}
}
func TestReportStatus_NotFound(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte(`{"detail": "Seed nicht gefunden"}`))
}))
defer server.Close()
client := NewAPIClient(server.URL)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
report := &CrawlStatusReport{
SeedURL: "https://unknown.example.com",
Status: "error",
}
err := client.ReportStatus(ctx, report)
if err == nil {
t.Fatal("Expected error for 404 response")
}
}
// TestReportStatusBulk_Success posts two reports in one bulk request and
// verifies the request shape and that the summary echoes the batch size.
func TestReportStatusBulk_Success(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Verify request method and path
		if r.Method != "POST" {
			t.Errorf("Expected POST method, got %s", r.Method)
		}
		if r.URL.Path != "/v1/edu-search/seeds/crawl-status/bulk" {
			t.Errorf("Expected path '/v1/edu-search/seeds/crawl-status/bulk', got '%s'", r.URL.Path)
		}
		// Parse body
		var payload struct {
			Updates []*CrawlStatusReport `json:"updates"`
		}
		json.NewDecoder(r.Body).Decode(&payload)
		// Send response (echo the number of updates received)
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(BulkCrawlStatusResponse{
			Updated: len(payload.Updates),
			Failed:  0,
			Errors:  []string{},
		})
	}))
	defer server.Close()
	client := NewAPIClient(server.URL)
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	reports := []*CrawlStatusReport{
		{
			SeedURL:          "https://www.kmk.org",
			Status:           "success",
			DocumentsCrawled: 42,
		},
		{
			SeedURL:          "https://www.km-bw.de",
			Status:           "partial",
			DocumentsCrawled: 15,
		},
	}
	result, err := client.ReportStatusBulk(ctx, reports)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	if result.Updated != 2 {
		t.Errorf("Expected 2 updated, got %d", result.Updated)
	}
	if result.Failed != 0 {
		t.Errorf("Expected 0 failed, got %d", result.Failed)
	}
}
// TestReportStatusBulk_PartialFailure verifies that a mixed bulk response
// (one updated, one failed) is surfaced as data, not as an error.
func TestReportStatusBulk_PartialFailure(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(BulkCrawlStatusResponse{
			Updated: 1,
			Failed:  1,
			Errors:  []string{"Seed nicht gefunden: https://unknown.example.com"},
		})
	}))
	defer server.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	batch := []*CrawlStatusReport{
		{SeedURL: "https://www.kmk.org", Status: "success"},
		{SeedURL: "https://unknown.example.com", Status: "error"},
	}
	result, err := NewAPIClient(server.URL).ReportStatusBulk(ctx, batch)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	if result.Updated != 1 {
		t.Errorf("Expected 1 updated, got %d", result.Updated)
	}
	if result.Failed != 1 {
		t.Errorf("Expected 1 failed, got %d", result.Failed)
	}
	if len(result.Errors) != 1 {
		t.Errorf("Expected 1 error, got %d", len(result.Errors))
	}
}
// TestCrawlStatusReport_Struct round-trips a CrawlStatusReport through
// JSON and checks that every field survives.
func TestCrawlStatusReport_Struct(t *testing.T) {
	original := CrawlStatusReport{
		SeedURL:          "https://www.example.com",
		Status:           "success",
		DocumentsCrawled: 100,
		ErrorMessage:     "",
		CrawlDuration:    25.5,
	}
	// Test JSON marshaling
	data, err := json.Marshal(original)
	if err != nil {
		t.Fatalf("Failed to marshal: %v", err)
	}
	var decoded CrawlStatusReport
	if err := json.Unmarshal(data, &decoded); err != nil {
		t.Fatalf("Failed to unmarshal: %v", err)
	}
	if decoded.SeedURL != original.SeedURL {
		t.Errorf("SeedURL mismatch")
	}
	if decoded.Status != original.Status {
		t.Errorf("Status mismatch")
	}
	if decoded.DocumentsCrawled != original.DocumentsCrawled {
		t.Errorf("DocumentsCrawled mismatch")
	}
	if decoded.CrawlDuration != original.CrawlDuration {
		t.Errorf("CrawlDuration mismatch")
	}
}

View File

@@ -0,0 +1,364 @@
package crawler
import (
"bufio"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/google/uuid"
)
// Note: API client is in the same package (api_client.go)
// FetchResult contains the result of fetching a single URL.
type FetchResult struct {
	URL          string    // the URL as requested
	CanonicalURL string    // final URL after redirects
	ContentType  string    // Content-Type response header
	StatusCode   int       // HTTP status code (0 if the request never completed)
	Body         []byte    // response body (capped at 20MB by Fetch)
	ContentHash  string    // hex-encoded SHA-256 of Body, used for deduplication
	FetchTime    time.Time // when the fetch started
	Error        error     // non-nil if the fetch failed at any stage
}
// Seed represents a URL to crawl together with its crawl metadata.
type Seed struct {
	URL        string
	TrustBoost float64 // trust boost applied to documents from this seed
	Source     string  // GOV, EDU, UNI, etc.
	Scope      string  // FEDERAL, STATE, etc.
	State      string  // BW, BY, etc. (optional)
	MaxDepth   int     // Custom crawl depth for this seed
	Category   string  // Category name
}
// Crawler handles URL fetching with per-domain rate limiting and a
// domain denylist.
type Crawler struct {
	userAgent       string
	rateLimitPerSec float64
	maxDepth        int
	timeout         time.Duration
	client          *http.Client
	denylist        map[string]bool      // lowercased domains that must not be fetched
	lastFetch       map[string]time.Time // per-host time of the last fetch (rate limiting)
	mu              sync.Mutex           // guards lastFetch
	apiClient       *APIClient // API client for fetching seeds from Backend (optional)
}
// NewCrawler creates a crawler with the given politeness settings. The
// HTTP client caps each request at 30 seconds and follows at most five
// redirects.
func NewCrawler(userAgent string, rateLimitPerSec float64, maxDepth int) *Crawler {
	httpClient := &http.Client{
		Timeout: 30 * time.Second,
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			if len(via) >= 5 {
				return fmt.Errorf("too many redirects")
			}
			return nil
		},
	}
	return &Crawler{
		userAgent:       userAgent,
		rateLimitPerSec: rateLimitPerSec,
		maxDepth:        maxDepth,
		timeout:         30 * time.Second,
		client:          httpClient,
		denylist:        map[string]bool{},
		lastFetch:       map[string]time.Time{},
	}
}
// SetAPIClient configures the crawler to fetch seeds from the Backend at
// backendURL. Must be called before LoadSeedsFromAPI.
func (c *Crawler) SetAPIClient(backendURL string) {
	c.apiClient = NewAPIClient(backendURL)
}
// LoadSeedsFromAPI fetches the enabled seeds from the Backend API and
// converts them to Seed values. SetAPIClient must have been called first.
// Seeds without a positive depth fall back to the crawler's default depth.
func (c *Crawler) LoadSeedsFromAPI(ctx context.Context) ([]Seed, error) {
	if c.apiClient == nil {
		return nil, fmt.Errorf("API client not initialized - call SetAPIClient first")
	}
	response, err := c.apiClient.FetchSeeds(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch seeds from API: %w", err)
	}
	seeds := make([]Seed, 0, len(response.Seeds))
	for _, s := range response.Seeds {
		depth := s.Depth
		if depth <= 0 {
			// Use default depth if not specified
			depth = c.maxDepth
		}
		seeds = append(seeds, Seed{
			URL:        s.URL,
			TrustBoost: s.Trust,
			Source:     s.Source,
			Scope:      s.Scope,
			State:      s.State,
			MaxDepth:   depth,
			Category:   s.Category,
		})
	}
	log.Printf("Loaded %d seeds from API (exported at: %s)", len(seeds), response.ExportedAt)
	return seeds, nil
}
// LoadSeeds loads seed URLs from all *.txt files in seedsDir (legacy
// file-based configuration). Files whose name contains "denylist" are
// loaded into the crawler's denylist instead of the seed list.
func (c *Crawler) LoadSeeds(seedsDir string) ([]string, error) {
	files, err := filepath.Glob(filepath.Join(seedsDir, "*.txt"))
	if err != nil {
		return nil, err
	}
	var seeds []string
	for _, file := range files {
		if strings.Contains(file, "denylist") {
			// Load denylist
			if err := c.loadDenylist(file); err != nil {
				log.Printf("Warning: Could not load denylist %s: %v", file, err)
			}
			continue
		}
		entries, err := c.loadSeedFile(file)
		if err != nil {
			log.Printf("Warning: Could not load seed file %s: %v", file, err)
			continue
		}
		seeds = append(seeds, entries...)
	}
	log.Printf("Loaded %d seeds from files, %d domains in denylist", len(seeds), len(c.denylist))
	return seeds, nil
}
// LoadSeedsWithMetadata loads seeds from files and wraps each URL in a
// Seed carrying default metadata, for backward compatibility with the
// file-based configuration.
func (c *Crawler) LoadSeedsWithMetadata(seedsDir string) ([]Seed, error) {
	urls, err := c.LoadSeeds(seedsDir)
	if err != nil {
		return nil, err
	}
	seeds := make([]Seed, 0, len(urls))
	for _, u := range urls {
		seeds = append(seeds, Seed{
			URL:        u,
			TrustBoost: 0.5, // Default trust boost
			MaxDepth:   c.maxDepth,
		})
	}
	return seeds, nil
}
// loadSeedFile reads one seed file: one URL per line, blank lines and
// "#" comment lines skipped, anything after the first space ignored.
func (c *Crawler) loadSeedFile(filename string) ([]string, error) {
	f, err := os.Open(filename)
	if err != nil {
		return nil, err
	}
	defer f.Close()
	var seeds []string
	sc := bufio.NewScanner(f)
	for sc.Scan() {
		line := strings.TrimSpace(sc.Text())
		// Skip comments and empty lines
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		// Keep only the URL (drop inline comments after the first space)
		fields := strings.SplitN(line, " ", 2)
		if u := strings.TrimSpace(fields[0]); u != "" {
			seeds = append(seeds, u)
		}
	}
	return seeds, sc.Err()
}
// loadDenylist reads a denylist file (one domain per line, "#" comments
// allowed) into the crawler's denylist, lowercasing each entry.
func (c *Crawler) loadDenylist(filename string) error {
	f, err := os.Open(filename)
	if err != nil {
		return err
	}
	defer f.Close()
	sc := bufio.NewScanner(f)
	for sc.Scan() {
		entry := strings.TrimSpace(sc.Text())
		if entry == "" || strings.HasPrefix(entry, "#") {
			continue
		}
		c.denylist[strings.ToLower(entry)] = true
	}
	return sc.Err()
}
// IsDenied reports whether urlStr must not be crawled. A URL is denied
// when it cannot be parsed, has no host, or when its hostname (or any
// parent domain of it) appears in the denylist.
func (c *Crawler) IsDenied(urlStr string) bool {
	u, err := url.Parse(urlStr)
	if err != nil {
		// Unparseable URLs are treated as denied.
		return true
	}
	// Hostname() strips any ":port", so "facebook.com:8080" still matches
	// a denylist entry for "facebook.com" (the previous u.Host comparison
	// missed hosts that carry an explicit port).
	host := strings.ToLower(u.Hostname())
	if host == "" {
		// No host (relative URL, mailto:, ...) - nothing we could fetch.
		return true
	}
	// Check exact match
	if c.denylist[host] {
		return true
	}
	// Check parent domains: www.facebook.com -> facebook.com
	parts := strings.Split(host, ".")
	for i := 1; i < len(parts)-1; i++ {
		if c.denylist[strings.Join(parts[i:], ".")] {
			return true
		}
	}
	return false
}
// Fetch downloads a single URL, honoring the denylist and the per-domain
// rate limit. The response body is capped at 20MB and hashed with SHA-256
// for deduplication. The returned FetchResult is populated even on error.
func (c *Crawler) Fetch(ctx context.Context, urlStr string) (*FetchResult, error) {
	result := &FetchResult{
		URL:       urlStr,
		FetchTime: time.Now(),
	}
	// Check denylist
	if c.IsDenied(urlStr) {
		result.Error = fmt.Errorf("domain denied")
		return result, result.Error
	}
	// Parse URL
	u, err := url.Parse(urlStr)
	if err != nil {
		result.Error = err
		return result, err
	}
	// Rate limiting per domain
	c.waitForRateLimit(u.Host)
	// Apply the crawler's configured timeout to this request. The timeout
	// field was previously set but never used; it now bounds each fetch in
	// addition to any deadline already carried by ctx.
	if c.timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, c.timeout)
		defer cancel()
	}
	// Create request
	req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
	if err != nil {
		result.Error = err
		return result, err
	}
	req.Header.Set("User-Agent", c.userAgent)
	req.Header.Set("Accept", "text/html,application/pdf,application/xhtml+xml")
	req.Header.Set("Accept-Language", "de-DE,de;q=0.9,en;q=0.8")
	// Execute request
	resp, err := c.client.Do(req)
	if err != nil {
		result.Error = err
		return result, err
	}
	defer resp.Body.Close()
	result.StatusCode = resp.StatusCode
	result.ContentType = resp.Header.Get("Content-Type")
	result.CanonicalURL = resp.Request.URL.String()
	if resp.StatusCode != http.StatusOK {
		// Drain a bounded amount of the body so the transport can reuse
		// the connection; the content itself is not needed.
		io.Copy(io.Discard, io.LimitReader(resp.Body, 64*1024))
		result.Error = fmt.Errorf("HTTP %d", resp.StatusCode)
		return result, result.Error
	}
	// Read body (limit to 20MB)
	body, err := io.ReadAll(io.LimitReader(resp.Body, 20*1024*1024))
	if err != nil {
		result.Error = err
		return result, err
	}
	result.Body = body
	// Calculate content hash for deduplication
	hash := sha256.Sum256(body)
	result.ContentHash = hex.EncodeToString(hash[:])
	return result, nil
}
// waitForRateLimit blocks until a request to host is allowed under the
// per-domain rate limit. The next allowed slot is reserved while holding
// the lock, but the actual sleep happens outside it, so one slow domain
// no longer stalls fetches to every other domain (the previous version
// slept while holding c.mu, serializing all hosts). A non-positive rate
// disables limiting and avoids a division by zero.
func (c *Crawler) waitForRateLimit(host string) {
	if c.rateLimitPerSec <= 0 {
		return
	}
	minInterval := time.Duration(float64(time.Second) / c.rateLimitPerSec)

	c.mu.Lock()
	next := time.Now()
	if last, ok := c.lastFetch[host]; ok {
		if earliest := last.Add(minInterval); earliest.After(next) {
			next = earliest
		}
	}
	c.lastFetch[host] = next
	c.mu.Unlock()

	if wait := time.Until(next); wait > 0 {
		time.Sleep(wait)
	}
}
// ExtractDomain returns the host (including any port) of urlStr, or the
// empty string when the URL cannot be parsed.
func ExtractDomain(urlStr string) string {
	if u, err := url.Parse(urlStr); err == nil {
		return u.Host
	}
	return ""
}
// GenerateDocID generates a unique document ID (a random UUID string).
func GenerateDocID() string {
	return uuid.New().String()
}
// NormalizeURL canonicalizes a URL for deduplication: trailing slashes
// are trimmed, known tracking parameters (utm_*, ref, source, fbclid,
// gclid) are stripped, and the host is lowercased. Unparseable input is
// returned unchanged.
func NormalizeURL(urlStr string) string {
	u, err := url.Parse(urlStr)
	if err != nil {
		return urlStr
	}
	u.Path = strings.TrimSuffix(u.Path, "/")
	u.Host = strings.ToLower(u.Host)

	// isTracking reports whether a query parameter is a known tracker.
	isTracking := func(name string) bool {
		name = strings.ToLower(name)
		switch name {
		case "ref", "source", "fbclid", "gclid":
			return true
		}
		return strings.HasPrefix(name, "utm_")
	}
	params := u.Query()
	for name := range params {
		if isTracking(name) {
			params.Del(name)
		}
	}
	u.RawQuery = params.Encode()
	return u.String()
}

View File

@@ -0,0 +1,639 @@
package crawler
import (
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
// TestNewCrawler verifies that the constructor stores the politeness
// settings and creates an HTTP client.
func TestNewCrawler(t *testing.T) {
	c := NewCrawler("TestBot/1.0", 1.0, 3)
	if c == nil {
		t.Fatal("Expected non-nil crawler")
	}
	if got := c.userAgent; got != "TestBot/1.0" {
		t.Errorf("Expected userAgent 'TestBot/1.0', got %q", got)
	}
	if got := c.rateLimitPerSec; got != 1.0 {
		t.Errorf("Expected rateLimitPerSec 1.0, got %f", got)
	}
	if got := c.maxDepth; got != 3 {
		t.Errorf("Expected maxDepth 3, got %d", got)
	}
	if c.client == nil {
		t.Error("Expected non-nil HTTP client")
	}
}
// TestCrawler_LoadSeeds builds a temp seeds directory (two seed files and
// a denylist) and checks that comments/blank lines are skipped, all five
// URLs load, and the denylist is populated.
func TestCrawler_LoadSeeds(t *testing.T) {
	// Create temp directory with seed files
	dir := t.TempDir()
	// Create a seed file
	seedContent := `# Federal education sources
https://www.kmk.org
https://www.bildungsserver.de
# Comment line
https://www.bpb.de # with inline comment
`
	if err := os.WriteFile(filepath.Join(dir, "federal.txt"), []byte(seedContent), 0644); err != nil {
		t.Fatal(err)
	}
	// Create another seed file
	stateContent := `https://www.km.bayern.de
https://www.schulministerium.nrw.de
`
	if err := os.WriteFile(filepath.Join(dir, "states.txt"), []byte(stateContent), 0644); err != nil {
		t.Fatal(err)
	}
	// Create denylist
	denylistContent := `# Denylist
facebook.com
twitter.com
instagram.com
`
	if err := os.WriteFile(filepath.Join(dir, "denylist.txt"), []byte(denylistContent), 0644); err != nil {
		t.Fatal(err)
	}
	crawler := NewCrawler("TestBot/1.0", 1.0, 3)
	seeds, err := crawler.LoadSeeds(dir)
	if err != nil {
		t.Fatalf("LoadSeeds failed: %v", err)
	}
	// Check seeds loaded
	if len(seeds) != 5 {
		t.Errorf("Expected 5 seeds, got %d", len(seeds))
	}
	// Check expected URLs
	expectedURLs := []string{
		"https://www.kmk.org",
		"https://www.bildungsserver.de",
		"https://www.bpb.de",
		"https://www.km.bayern.de",
		"https://www.schulministerium.nrw.de",
	}
	for _, expected := range expectedURLs {
		found := false
		for _, seed := range seeds {
			if seed == expected {
				found = true
				break
			}
		}
		if !found {
			t.Errorf("Expected seed %q not found", expected)
		}
	}
	// Check denylist loaded
	if len(crawler.denylist) != 3 {
		t.Errorf("Expected 3 denylist entries, got %d", len(crawler.denylist))
	}
}
// TestCrawler_IsDenied is a table test for the denylist logic: exact
// matches, subdomains of denied parents, allowed domains, and invalid URLs.
func TestCrawler_IsDenied(t *testing.T) {
	crawler := NewCrawler("TestBot/1.0", 1.0, 3)
	crawler.denylist = map[string]bool{
		"facebook.com":    true,
		"twitter.com":     true,
		"ads.example.com": true,
	}
	tests := []struct {
		name     string
		url      string
		expected bool
	}{
		{
			name:     "Exact domain match",
			url:      "https://facebook.com/page",
			expected: true,
		},
		{
			name:     "Subdomain of denied domain",
			url:      "https://www.facebook.com/page",
			expected: true,
		},
		{
			name:     "Allowed domain",
			url:      "https://www.kmk.org/bildung",
			expected: false,
		},
		{
			name:     "Denied subdomain",
			url:      "https://ads.example.com/banner",
			expected: true,
		},
		{
			name:     "Parent domain allowed",
			url:      "https://example.com/page",
			expected: false,
		},
		{
			name:     "Invalid URL scheme",
			url:      "://invalid",
			expected: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := crawler.IsDenied(tt.url)
			if result != tt.expected {
				t.Errorf("IsDenied(%q) = %v, expected %v", tt.url, result, tt.expected)
			}
		})
	}
}
// TestCrawler_Fetch_Success fetches from a local test server and checks
// the User-Agent header plus every field of the returned FetchResult.
func TestCrawler_Fetch_Success(t *testing.T) {
	// Create test server
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Check user agent
		if r.Header.Get("User-Agent") != "TestBot/1.0" {
			t.Errorf("Expected User-Agent 'TestBot/1.0', got %q", r.Header.Get("User-Agent"))
		}
		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("<html><body>Test content</body></html>"))
	}))
	defer server.Close()
	crawler := NewCrawler("TestBot/1.0", 100.0, 3) // High rate limit for testing
	ctx := context.Background()
	result, err := crawler.Fetch(ctx, server.URL+"/page")
	if err != nil {
		t.Fatalf("Fetch failed: %v", err)
	}
	if result.StatusCode != 200 {
		t.Errorf("Expected status 200, got %d", result.StatusCode)
	}
	if result.Error != nil {
		t.Errorf("Expected no error, got %v", result.Error)
	}
	if !strings.Contains(result.ContentType, "text/html") {
		t.Errorf("Expected Content-Type to contain 'text/html', got %q", result.ContentType)
	}
	if len(result.Body) == 0 {
		t.Error("Expected non-empty body")
	}
	if result.ContentHash == "" {
		t.Error("Expected non-empty content hash")
	}
	if result.FetchTime.IsZero() {
		t.Error("Expected non-zero fetch time")
	}
}
// TestCrawler_Fetch_DeniedDomain verifies that Fetch refuses denylisted
// domains before doing any network I/O.
func TestCrawler_Fetch_DeniedDomain(t *testing.T) {
	c := NewCrawler("TestBot/1.0", 100.0, 3)
	c.denylist = map[string]bool{
		"denied.com": true,
	}
	result, err := c.Fetch(context.Background(), "https://denied.com/page")
	if err == nil {
		t.Error("Expected error for denied domain")
	}
	if result.Error == nil {
		t.Error("Expected error in result")
	}
	if !strings.Contains(result.Error.Error(), "denied") {
		t.Errorf("Expected 'denied' in error message, got %v", result.Error)
	}
}
// TestCrawler_Fetch_HTTPError verifies that a non-200 response is surfaced
// as an error while the status code is still recorded in the result.
func TestCrawler_Fetch_HTTPError(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusNotFound)
	}))
	defer server.Close()

	c := NewCrawler("TestBot/1.0", 100.0, 3)
	result, err := c.Fetch(context.Background(), server.URL+"/notfound")
	if err == nil {
		t.Error("Expected error for 404 response")
	}
	if result.StatusCode != 404 {
		t.Errorf("Expected status 404, got %d", result.StatusCode)
	}
}
// TestCrawler_Fetch_Redirect verifies that redirects are followed and that
// CanonicalURL reflects the final URL. (The previous version kept a
// write-only redirectCount variable that was never asserted; it is
// replaced by a status-code assertion on the final response.)
func TestCrawler_Fetch_Redirect(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/redirect" {
			http.Redirect(w, r, "/final", http.StatusFound)
			return
		}
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("Final content"))
	}))
	defer server.Close()
	crawler := NewCrawler("TestBot/1.0", 100.0, 3)
	result, err := crawler.Fetch(context.Background(), server.URL+"/redirect")
	if err != nil {
		t.Fatalf("Fetch failed: %v", err)
	}
	if result.StatusCode != http.StatusOK {
		t.Errorf("Expected status 200 after redirect, got %d", result.StatusCode)
	}
	// CanonicalURL should be the final URL after redirect
	if !strings.HasSuffix(result.CanonicalURL, "/final") {
		t.Errorf("Expected canonical URL to end with '/final', got %q", result.CanonicalURL)
	}
}
// TestCrawler_Fetch_Timeout verifies that a short deadline aborts a slow
// fetch. The handler waits on the request context instead of sleeping a
// fixed 2s, so server.Close() returns right after the client gives up
// (the old fixed sleep made this test take ~2s).
func TestCrawler_Fetch_Timeout(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Delay the response, but stop stalling once the client is gone.
		select {
		case <-r.Context().Done():
			return
		case <-time.After(2 * time.Second):
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()
	crawler := NewCrawler("TestBot/1.0", 100.0, 3)
	crawler.timeout = 100 * time.Millisecond // Very short timeout
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()
	_, err := crawler.Fetch(ctx, server.URL+"/slow")
	if err == nil {
		t.Error("Expected timeout error")
	}
}
// TestExtractDomain checks host extraction, including ports, subdomains,
// and input without a host.
func TestExtractDomain(t *testing.T) {
	cases := map[string]string{
		"https://www.example.com/page":  "www.example.com",
		"https://example.com:8080/path": "example.com:8080",
		"http://subdomain.example.com":  "subdomain.example.com",
		"invalid-url":                   "",
	}
	for input, want := range cases {
		t.Run(input, func(t *testing.T) {
			if got := ExtractDomain(input); got != want {
				t.Errorf("ExtractDomain(%q) = %q, expected %q", input, got, want)
			}
		})
	}
}
// TestGenerateDocID checks that generated IDs are non-empty, unique, and
// UUID-shaped (36 characters).
func TestGenerateDocID(t *testing.T) {
	first, second := GenerateDocID(), GenerateDocID()
	if first == "" {
		t.Error("Expected non-empty ID")
	}
	if first == second {
		t.Error("Expected unique IDs")
	}
	// UUID format check (basic)
	if len(first) != 36 {
		t.Errorf("Expected UUID length 36, got %d", len(first))
	}
}
// TestNormalizeURL covers trailing-slash stripping, tracking-parameter
// removal, host lowercasing, query-key sorting, and pass-through of
// unparseable input.
func TestNormalizeURL(t *testing.T) {
	cases := []struct {
		label string
		in    string
		want  string
	}{
		{label: "Remove trailing slash", in: "https://example.com/page/", want: "https://example.com/page"},
		{label: "Remove UTM parameters", in: "https://example.com/page?utm_source=google&utm_medium=cpc", want: "https://example.com/page"},
		{label: "Remove multiple tracking params", in: "https://example.com/page?id=123&utm_campaign=test&fbclid=abc", want: "https://example.com/page?id=123"},
		{label: "Keep non-tracking params", in: "https://example.com/search?q=test&page=2", want: "https://example.com/search?page=2&q=test"},
		{label: "Lowercase host", in: "https://EXAMPLE.COM/Page", want: "https://example.com/Page"},
		{label: "Invalid URL returns as-is", in: "not-a-url", want: "not-a-url"},
	}
	for _, c := range cases {
		t.Run(c.label, func(t *testing.T) {
			if got := NormalizeURL(c.in); got != c.want {
				t.Errorf("NormalizeURL(%q) = %q, expected %q", c.in, got, c.want)
			}
		})
	}
}
// TestCrawler_RateLimit verifies that sequential fetches are throttled to the
// configured requests-per-second rate.
func TestCrawler_RateLimit(t *testing.T) {
	requestCount := 0
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		requestCount++
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("OK"))
	}))
	defer server.Close()

	// 2 requests per second = 500ms between requests
	crawler := NewCrawler("TestBot/1.0", 2.0, 3)
	ctx := context.Background()

	start := time.Now()
	// Make 3 requests. Fix: the original discarded Fetch errors, so a
	// completely broken Fetch would still pass the timing assertion.
	for i := 0; i < 3; i++ {
		if _, err := crawler.Fetch(ctx, server.URL+"/page"); err != nil {
			t.Fatalf("Fetch %d failed: %v", i, err)
		}
	}
	elapsed := time.Since(start)

	// With 2 req/sec, 3 requests should take at least 1 second (2 intervals);
	// 800ms leaves slack for limiter jitter.
	if elapsed < 800*time.Millisecond {
		t.Errorf("Rate limiting not working: 3 requests took only %v", elapsed)
	}
	// Fix: requestCount was previously written but never checked.
	// Requests are sequential, so reading it after the loop is safe.
	if requestCount != 3 {
		t.Errorf("Expected 3 requests to reach the server, got %d", requestCount)
	}
}
// TestLoadSeedFile_EmptyLines ensures blank lines and '#' comment lines in a
// seed file are skipped by LoadSeeds.
func TestLoadSeedFile_EmptyLines(t *testing.T) {
	dir := t.TempDir()
	data := "\nhttps://example.com\n\n# comment\nhttps://example.org\n\n"
	seedFile := filepath.Join(dir, "seeds.txt")
	if err := os.WriteFile(seedFile, []byte(data), 0644); err != nil {
		t.Fatal(err)
	}

	c := NewCrawler("TestBot/1.0", 1.0, 3)
	seeds, err := c.LoadSeeds(dir)
	if err != nil {
		t.Fatal(err)
	}
	if got := len(seeds); got != 2 {
		t.Errorf("Expected 2 seeds (ignoring empty lines and comments), got %d", got)
	}
}
// TestCrawler_Fetch_LargeBody verifies that a 1MB response body (below the
// crawler's size limit) is fetched in full.
func TestCrawler_Fetch_LargeBody(t *testing.T) {
	payload := strings.Repeat("A", 1024*1024) // 1MB, under the limit
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/plain")
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(payload))
	}))
	defer server.Close()

	c := NewCrawler("TestBot/1.0", 100.0, 3)
	result, err := c.Fetch(context.Background(), server.URL+"/large")
	if err != nil {
		t.Fatalf("Fetch failed: %v", err)
	}
	if len(result.Body) != len(payload) {
		t.Errorf("Expected body length %d, got %d", len(payload), len(result.Body))
	}
}
// Tests for API Integration (new functionality)

// TestCrawler_SetAPIClient verifies the API client starts out nil and is
// populated only after an explicit SetAPIClient call.
func TestCrawler_SetAPIClient(t *testing.T) {
	c := NewCrawler("TestBot/1.0", 1.0, 3)
	if c.apiClient != nil {
		t.Error("Expected nil apiClient initially")
	}
	c.SetAPIClient("http://backend:8000")
	if c.apiClient == nil {
		t.Error("Expected non-nil apiClient after SetAPIClient")
	}
}
// TestCrawler_LoadSeedsFromAPI_NotInitialized checks that loading seeds from
// the API without a configured client fails instead of silently returning.
func TestCrawler_LoadSeedsFromAPI_NotInitialized(t *testing.T) {
	c := NewCrawler("TestBot/1.0", 1.0, 3)
	if _, err := c.LoadSeedsFromAPI(context.Background()); err == nil {
		t.Error("Expected error when API client not initialized")
	}
}
// TestCrawler_LoadSeedsFromAPI_Success exercises the happy path: two seeds
// served by a mock backend are decoded with all metadata intact.
func TestCrawler_LoadSeedsFromAPI_Success(t *testing.T) {
	// Mock backend serving a fixed seed-export payload.
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.Write([]byte(`{
"seeds": [
{"url": "https://www.kmk.org", "trust": 0.8, "source": "GOV", "scope": "FEDERAL", "state": "", "depth": 3, "category": "federal"},
{"url": "https://www.km-bw.de", "trust": 0.7, "source": "GOV", "scope": "STATE", "state": "BW", "depth": 2, "category": "states"}
],
"total": 2,
"exported_at": "2025-01-17T10:00:00Z"
}`))
	}))
	defer backend.Close()

	c := NewCrawler("TestBot/1.0", 1.0, 4)
	c.SetAPIClient(backend.URL)

	seeds, err := c.LoadSeedsFromAPI(context.Background())
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	if len(seeds) != 2 {
		t.Fatalf("Expected 2 seeds, got %d", len(seeds))
	}

	first, second := seeds[0], seeds[1]

	// First seed: federal-level, no state.
	if first.URL != "https://www.kmk.org" {
		t.Errorf("Expected URL 'https://www.kmk.org', got '%s'", first.URL)
	}
	if first.TrustBoost != 0.8 {
		t.Errorf("Expected TrustBoost 0.8, got %f", first.TrustBoost)
	}
	if first.Source != "GOV" {
		t.Errorf("Expected Source 'GOV', got '%s'", first.Source)
	}
	if first.MaxDepth != 3 {
		t.Errorf("Expected MaxDepth 3, got %d", first.MaxDepth)
	}

	// Second seed carries state-level metadata.
	if second.State != "BW" {
		t.Errorf("Expected State 'BW', got '%s'", second.State)
	}
	if second.Category != "states" {
		t.Errorf("Expected Category 'states', got '%s'", second.Category)
	}
}
// TestCrawler_LoadSeedsFromAPI_DefaultDepth verifies that a seed exported
// with depth 0 falls back to the crawler's configured default depth.
func TestCrawler_LoadSeedsFromAPI_DefaultDepth(t *testing.T) {
	// Create mock server with seed that has no depth
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.Write([]byte(`{
"seeds": [
{"url": "https://www.example.com", "trust": 0.5, "source": "EDU", "scope": "FEDERAL", "state": "", "depth": 0, "category": "edu"}
],
"total": 1,
"exported_at": "2025-01-17T10:00:00Z"
}`))
	}))
	defer server.Close()

	defaultDepth := 5
	crawler := NewCrawler("TestBot/1.0", 1.0, defaultDepth)
	crawler.SetAPIClient(server.URL)

	ctx := context.Background()
	seeds, err := crawler.LoadSeedsFromAPI(ctx)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	// Fix: guard the slice access — the original indexed seeds[0] without
	// checking the length, so a decoding regression would panic the test
	// runner instead of failing cleanly.
	if len(seeds) != 1 {
		t.Fatalf("Expected 1 seed, got %d", len(seeds))
	}
	// When depth is 0 or not specified, it should use crawler's default
	if seeds[0].MaxDepth != defaultDepth {
		t.Errorf("Expected default MaxDepth %d, got %d", defaultDepth, seeds[0].MaxDepth)
	}
}
// TestCrawler_LoadSeedsWithMetadata checks that plain seed files (URLs only)
// produce seeds carrying the default trust boost and the crawler's default
// depth.
func TestCrawler_LoadSeedsWithMetadata(t *testing.T) {
	dir := t.TempDir()
	urls := "https://www.kmk.org\nhttps://www.bildungsserver.de"
	if err := os.WriteFile(filepath.Join(dir, "seeds.txt"), []byte(urls), 0644); err != nil {
		t.Fatal(err)
	}

	const defaultDepth = 4
	c := NewCrawler("TestBot/1.0", 1.0, defaultDepth)

	seeds, err := c.LoadSeedsWithMetadata(dir)
	if err != nil {
		t.Fatalf("LoadSeedsWithMetadata failed: %v", err)
	}
	if len(seeds) != 2 {
		t.Fatalf("Expected 2 seeds, got %d", len(seeds))
	}

	// Without explicit metadata every seed should carry the defaults.
	for _, s := range seeds {
		if s.TrustBoost != 0.5 {
			t.Errorf("Expected default TrustBoost 0.5, got %f", s.TrustBoost)
		}
		if s.MaxDepth != defaultDepth {
			t.Errorf("Expected default MaxDepth %d, got %d", defaultDepth, s.MaxDepth)
		}
	}
}
// TestSeed_Struct is a smoke test that Seed's fields round-trip through a
// composite literal unchanged.
func TestSeed_Struct(t *testing.T) {
	seed := Seed{
		URL:        "https://www.example.com",
		TrustBoost: 0.75,
		Source:     "GOV",
		Scope:      "STATE",
		State:      "BY",
		MaxDepth:   3,
		Category:   "states",
	}

	checks := []struct {
		name string
		ok   bool
	}{
		{"URL", seed.URL == "https://www.example.com"},
		{"TrustBoost", seed.TrustBoost == 0.75},
		{"Source", seed.Source == "GOV"},
		{"Scope", seed.Scope == "STATE"},
		{"State", seed.State == "BY"},
		{"MaxDepth", seed.MaxDepth == 3},
		{"Category", seed.Category == "states"},
	}
	for _, c := range checks {
		if !c.ok {
			t.Errorf("%s mismatch", c.name)
		}
	}
}

View File

@@ -0,0 +1,133 @@
package database
import (
"context"
"fmt"
"log"
"os"
"path/filepath"
"time"
"github.com/jackc/pgx/v5/pgxpool"
)
// DB holds the database connection pool
type DB struct {
	// Pool is the shared pgx connection pool; pgxpool.Pool is safe for
	// concurrent use, so one DB value can be shared across goroutines.
	Pool *pgxpool.Pool
}
// Config holds database configuration
// for building a PostgreSQL connection string (see ConnectionString).
type Config struct {
	Host     string // database host (default "localhost")
	Port     string // database port, kept as a string for DSN assembly
	User     string
	Password string
	DBName   string
	SSLMode  string // e.g. "disable", "require"
}
// NewConfig creates a new database config from environment variables,
// falling back to local-development defaults for any unset variable.
func NewConfig() *Config {
	cfg := Config{
		Host:     getEnv("DB_HOST", "localhost"),
		Port:     getEnv("DB_PORT", "5432"),
		User:     getEnv("DB_USER", "postgres"),
		Password: getEnv("DB_PASSWORD", "postgres"),
		DBName:   getEnv("DB_NAME", "breakpilot"),
		SSLMode:  getEnv("DB_SSLMODE", "disable"),
	}
	return &cfg
}
func getEnv(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
// ConnectionString returns the PostgreSQL connection string
// in URL form: postgres://user:pass@host:port/dbname?sslmode=mode.
func (c *Config) ConnectionString() string {
	dsn := "postgres://" + c.User + ":" + c.Password +
		"@" + c.Host + ":" + c.Port +
		"/" + c.DBName + "?sslmode=" + c.SSLMode
	return dsn
}
// New creates a new database connection pool from cfg, verifies it with a
// ping, and returns a DB wrapper. The pool is closed again if the ping fails.
func New(ctx context.Context, cfg *Config) (*DB, error) {
	poolCfg, err := pgxpool.ParseConfig(cfg.ConnectionString())
	if err != nil {
		return nil, fmt.Errorf("failed to parse database config: %w", err)
	}

	// Pool sizing and connection-recycling policy.
	poolCfg.MaxConns = 10
	poolCfg.MinConns = 2
	poolCfg.MaxConnLifetime = time.Hour
	poolCfg.MaxConnIdleTime = 30 * time.Minute

	pool, err := pgxpool.NewWithConfig(ctx, poolCfg)
	if err != nil {
		return nil, fmt.Errorf("failed to create connection pool: %w", err)
	}

	// Fail fast on unreachable/misconfigured databases.
	if err := pool.Ping(ctx); err != nil {
		pool.Close()
		return nil, fmt.Errorf("failed to ping database: %w", err)
	}

	log.Printf("Connected to database %s on %s:%s", cfg.DBName, cfg.Host, cfg.Port)
	return &DB{Pool: pool}, nil
}
// Close closes the database connection pool.
// Safe to call on a DB whose pool was never initialized.
func (db *DB) Close() {
	if db.Pool == nil {
		return
	}
	db.Pool.Close()
}
// RunMigrations executes all SQL migrations.
//
// The migration file is searched at several relative locations so the binary
// works whether it runs from the repo root or from a nested build/test
// directory. Returns an error when no candidate path can be read or when the
// SQL itself fails to execute.
func (db *DB) RunMigrations(ctx context.Context) error {
	// Try multiple paths for migration file
	migrationPaths := []string{
		"migrations/001_university_staff.sql",
		"../migrations/001_university_staff.sql",
		"../../migrations/001_university_staff.sql",
	}

	var (
		content   []byte
		foundPath string
		lastErr   error
	)
	for _, path := range migrationPaths {
		absPath, err := filepath.Abs(path)
		if err != nil {
			// Fix: the original discarded this error with `_`.
			lastErr = err
			continue
		}
		data, err := os.ReadFile(absPath)
		if err != nil {
			lastErr = err
			continue
		}
		content = data
		foundPath = absPath
		break
	}
	// Fix: test foundPath instead of `content == nil` — a zero-byte
	// migration file would otherwise be misreported as "not found".
	if foundPath == "" {
		return fmt.Errorf("failed to read migration file from any path: %w", lastErr)
	}
	log.Printf("Running migrations from: %s", foundPath)

	// Execute migration
	if _, err := db.Pool.Exec(ctx, string(content)); err != nil {
		return fmt.Errorf("failed to execute migration: %w", err)
	}
	log.Println("Database migrations completed successfully")
	return nil
}
// Health checks if the database is healthy
// by issuing a round-trip ping on the pool; a nil error means reachable.
func (db *DB) Health(ctx context.Context) error {
	return db.Pool.Ping(ctx)
}

View File

@@ -0,0 +1,205 @@
package database
import (
"time"
"github.com/google/uuid"
)
// University represents a German university/Hochschule.
// Pointer fields are nullable database columns and are omitted from JSON
// when nil.
type University struct {
	ID               uuid.UUID `json:"id"`
	Name             string    `json:"name"`
	ShortName        *string   `json:"short_name,omitempty"` // common abbreviation, if known
	URL              string    `json:"url"`
	State            *string   `json:"state,omitempty"`              // presumably the German federal state (Bundesland) — confirm against data
	UniType          *string   `json:"uni_type,omitempty"`           // institution category; allowed values not visible here
	StaffPagePattern *string   `json:"staff_page_pattern,omitempty"` // pattern used to locate staff pages — TODO confirm format
	CreatedAt        time.Time `json:"created_at"`
	UpdatedAt        time.Time `json:"updated_at"`
}
// Department represents a faculty/department at a university.
type Department struct {
	ID           uuid.UUID  `json:"id"`
	UniversityID uuid.UUID  `json:"university_id"` // owning university
	Name         string     `json:"name"`
	NameEN       *string    `json:"name_en,omitempty"` // optional English name
	URL          *string    `json:"url,omitempty"`
	Category     *string    `json:"category,omitempty"`
	ParentID     *uuid.UUID `json:"parent_id,omitempty"` // parent department for nested org units; nil at top level
	CreatedAt    time.Time  `json:"created_at"`
	UpdatedAt    time.Time  `json:"updated_at"`
}
// UniversityStaff represents a staff member at a university.
// Pointer fields map to nullable columns; the "joined fields" at the bottom
// are populated only by queries against views/joins, not by the base table.
type UniversityStaff struct {
	ID                uuid.UUID  `json:"id"`
	UniversityID      uuid.UUID  `json:"university_id"`
	DepartmentID      *uuid.UUID `json:"department_id,omitempty"`
	FirstName         *string    `json:"first_name,omitempty"`
	LastName          string     `json:"last_name"` // only required name component
	FullName          *string    `json:"full_name,omitempty"`
	Title             *string    `json:"title,omitempty"`
	AcademicTitle     *string    `json:"academic_title,omitempty"`
	Position          *string    `json:"position,omitempty"`
	PositionType      *string    `json:"position_type,omitempty"`
	IsProfessor       bool       `json:"is_professor"`
	Email             *string    `json:"email,omitempty"`
	Phone             *string    `json:"phone,omitempty"`
	Office            *string    `json:"office,omitempty"`
	ProfileURL        *string    `json:"profile_url,omitempty"`
	PhotoURL          *string    `json:"photo_url,omitempty"`
	ORCID             *string    `json:"orcid,omitempty"`
	GoogleScholarID   *string    `json:"google_scholar_id,omitempty"`
	ResearchgateURL   *string    `json:"researchgate_url,omitempty"`
	LinkedInURL       *string    `json:"linkedin_url,omitempty"`
	PersonalWebsite   *string    `json:"personal_website,omitempty"`
	ResearchInterests []string   `json:"research_interests,omitempty"`
	ResearchSummary   *string    `json:"research_summary,omitempty"`
	SupervisorID      *uuid.UUID `json:"supervisor_id,omitempty"`
	TeamRole          *string    `json:"team_role,omitempty"` // leitung, mitarbeiter, sekretariat, hiwi, doktorand
	CrawledAt         time.Time  `json:"crawled_at"`
	LastVerified      *time.Time `json:"last_verified,omitempty"`
	IsActive          bool       `json:"is_active"`
	SourceURL         *string    `json:"source_url,omitempty"` // page this record was extracted from
	CreatedAt         time.Time  `json:"created_at"`
	UpdatedAt         time.Time  `json:"updated_at"`
	// Joined fields (from views)
	UniversityName  *string `json:"university_name,omitempty"`
	UniversityShort *string `json:"university_short,omitempty"`
	DepartmentName  *string `json:"department_name,omitempty"`
	// Note: omitempty on an int drops the field from JSON when the count is 0.
	PublicationCount int     `json:"publication_count,omitempty"`
	SupervisorName   *string `json:"supervisor_name,omitempty"`
}
// Publication represents an academic publication.
// Identifier fields (DOI, ISBN, ArxivID, ...) are nullable because sources
// rarely provide all of them.
type Publication struct {
	ID            uuid.UUID `json:"id"`
	Title         string    `json:"title"`
	TitleEN       *string   `json:"title_en,omitempty"`
	Abstract      *string   `json:"abstract,omitempty"`
	AbstractEN    *string   `json:"abstract_en,omitempty"`
	Year          *int      `json:"year,omitempty"`
	Month         *int      `json:"month,omitempty"`
	PubType       *string   `json:"pub_type,omitempty"` // e.g. article/book — exact vocabulary not visible here
	Venue         *string   `json:"venue,omitempty"`
	VenueShort    *string   `json:"venue_short,omitempty"`
	Publisher     *string   `json:"publisher,omitempty"`
	DOI           *string   `json:"doi,omitempty"`
	ISBN          *string   `json:"isbn,omitempty"`
	ISSN          *string   `json:"issn,omitempty"`
	ArxivID       *string   `json:"arxiv_id,omitempty"`
	PubmedID      *string   `json:"pubmed_id,omitempty"`
	URL           *string   `json:"url,omitempty"`
	PDFURL        *string   `json:"pdf_url,omitempty"`
	CitationCount int       `json:"citation_count"`
	Keywords      []string  `json:"keywords,omitempty"`
	Topics        []string  `json:"topics,omitempty"`
	Source        *string   `json:"source,omitempty"`   // where this record was crawled from
	RawData       []byte    `json:"raw_data,omitempty"` // presumably the source's raw payload — confirm against crawler
	CrawledAt     time.Time `json:"crawled_at"`
	CreatedAt     time.Time `json:"created_at"`
	UpdatedAt     time.Time `json:"updated_at"`
	// Joined fields
	Authors     []string `json:"authors,omitempty"`
	AuthorCount int      `json:"author_count,omitempty"`
}
// StaffPublication represents the N:M relationship between staff and
// publications (the join table row).
type StaffPublication struct {
	StaffID         uuid.UUID `json:"staff_id"`
	PublicationID   uuid.UUID `json:"publication_id"`
	AuthorPosition  *int      `json:"author_position,omitempty"` // position in the author list — base index not visible here, TODO confirm
	IsCorresponding bool      `json:"is_corresponding"`
	CreatedAt       time.Time `json:"created_at"`
}
// UniversityCrawlStatus tracks crawl progress for a university,
// with separate bookkeeping for the staff crawl and the publication crawl.
type UniversityCrawlStatus struct {
	UniversityID       uuid.UUID  `json:"university_id"`
	LastStaffCrawl     *time.Time `json:"last_staff_crawl,omitempty"`
	StaffCrawlStatus   string     `json:"staff_crawl_status"` // status vocabulary defined elsewhere (DB schema)
	StaffCount         int        `json:"staff_count"`
	StaffErrors        []string   `json:"staff_errors,omitempty"`
	LastPubCrawl       *time.Time `json:"last_pub_crawl,omitempty"`
	PubCrawlStatus     string     `json:"pub_crawl_status"`
	PubCount           int        `json:"pub_count"`
	PubErrors          []string   `json:"pub_errors,omitempty"`
	NextScheduledCrawl *time.Time `json:"next_scheduled_crawl,omitempty"`
	CrawlPriority      int        `json:"crawl_priority"` // ordering semantics not visible here — TODO confirm
	CreatedAt          time.Time  `json:"created_at"`
	UpdatedAt          time.Time  `json:"updated_at"`
}
// CrawlHistory represents a crawl audit log entry.
type CrawlHistory struct {
	ID           uuid.UUID  `json:"id"`
	UniversityID *uuid.UUID `json:"university_id,omitempty"` // nil for crawls not tied to one university
	CrawlType    string     `json:"crawl_type"`
	Status       string     `json:"status"`
	StartedAt    time.Time  `json:"started_at"`
	CompletedAt  *time.Time `json:"completed_at,omitempty"` // nil while the crawl is still running
	ItemsFound   int        `json:"items_found"`
	ItemsNew     int        `json:"items_new"`
	ItemsUpdated int        `json:"items_updated"`
	Errors       []byte     `json:"errors,omitempty"`   // presumably a JSON blob — confirm against schema
	Metadata     []byte     `json:"metadata,omitempty"` // presumably a JSON blob — confirm against schema
}
// StaffSearchParams contains parameters for searching staff.
// All filters are optional; nil/zero values are ignored by SearchStaff,
// which also clamps Limit to [1,100] (default 20) and Offset to >= 0.
type StaffSearchParams struct {
	Query        string     `json:"query,omitempty"` // free-text query (full-text + name ILIKE)
	UniversityID *uuid.UUID `json:"university_id,omitempty"`
	DepartmentID *uuid.UUID `json:"department_id,omitempty"`
	State        *string    `json:"state,omitempty"`
	UniType      *string    `json:"uni_type,omitempty"`
	PositionType *string    `json:"position_type,omitempty"`
	IsProfessor  *bool      `json:"is_professor,omitempty"`
	Limit        int        `json:"limit,omitempty"`
	Offset       int        `json:"offset,omitempty"`
}
// StaffSearchResult contains search results for staff.
// Total is the full match count before pagination; Limit/Offset echo the
// effective (clamped) pagination values.
type StaffSearchResult struct {
	Staff  []UniversityStaff `json:"staff"`
	Total  int               `json:"total"`
	Limit  int               `json:"limit"`
	Offset int               `json:"offset"`
	Query  string            `json:"query,omitempty"`
}
// PublicationSearchParams contains parameters for searching publications.
// All filters are optional and combined with AND when set.
type PublicationSearchParams struct {
	Query    string     `json:"query,omitempty"`    // free-text query over title + abstract
	StaffID  *uuid.UUID `json:"staff_id,omitempty"` // restrict to one author's publications
	Year     *int       `json:"year,omitempty"`     // exact year; YearFrom/YearTo give a range
	YearFrom *int       `json:"year_from,omitempty"`
	YearTo   *int       `json:"year_to,omitempty"`
	PubType  *string    `json:"pub_type,omitempty"`
	Limit    int        `json:"limit,omitempty"`
	Offset   int        `json:"offset,omitempty"`
}
// PublicationSearchResult contains search results for publications.
// Total is the full match count before pagination.
type PublicationSearchResult struct {
	Publications []Publication `json:"publications"`
	Total        int           `json:"total"`
	Limit        int           `json:"limit"`
	Offset       int           `json:"offset"`
	Query        string        `json:"query,omitempty"`
}
// StaffStats contains statistics about staff data,
// aggregated overall and broken down by state / institution type / position.
type StaffStats struct {
	TotalStaff        int            `json:"total_staff"`
	TotalProfessors   int            `json:"total_professors"`
	TotalPublications int            `json:"total_publications"`
	TotalUniversities int            `json:"total_universities"`
	ByState           map[string]int `json:"by_state,omitempty"`
	ByUniType         map[string]int `json:"by_uni_type,omitempty"`
	ByPositionType    map[string]int `json:"by_position_type,omitempty"`
	RecentCrawls      []CrawlHistory `json:"recent_crawls,omitempty"`
}

View File

@@ -0,0 +1,684 @@
package database
import (
	"context"
	"errors"
	"fmt"
	"strings"

	"github.com/google/uuid"
	"github.com/jackc/pgx/v5"
)
// Repository provides database operations for staff and publications
type Repository struct {
	db *DB // shared connection-pool wrapper; all queries go through db.Pool
}
// NewRepository creates a new repository backed by the given database handle.
func NewRepository(db *DB) *Repository {
	repo := Repository{db: db}
	return &repo
}
// ============================================================================
// UNIVERSITIES
// ============================================================================
// CreateUniversity creates a new university.
// The statement upserts on the url column: an existing row with the same URL
// is updated in place. On return u is mutated with the row's ID and
// created_at/updated_at timestamps.
func (r *Repository) CreateUniversity(ctx context.Context, u *University) error {
	query := `
INSERT INTO universities (name, short_name, url, state, uni_type, staff_page_pattern)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (url) DO UPDATE SET
name = EXCLUDED.name,
short_name = EXCLUDED.short_name,
state = EXCLUDED.state,
uni_type = EXCLUDED.uni_type,
staff_page_pattern = EXCLUDED.staff_page_pattern,
updated_at = NOW()
RETURNING id, created_at, updated_at
`
	return r.db.Pool.QueryRow(ctx, query,
		u.Name, u.ShortName, u.URL, u.State, u.UniType, u.StaffPagePattern,
	).Scan(&u.ID, &u.CreatedAt, &u.UpdatedAt)
}
// GetUniversity retrieves a university by ID.
// Returns (nil, nil) when no row matches, so callers can distinguish
// "not found" from a genuine query error.
func (r *Repository) GetUniversity(ctx context.Context, id uuid.UUID) (*University, error) {
	query := `SELECT id, name, short_name, url, state, uni_type, staff_page_pattern, created_at, updated_at
FROM universities WHERE id = $1`
	u := &University{}
	err := r.db.Pool.QueryRow(ctx, query, id).Scan(
		&u.ID, &u.Name, &u.ShortName, &u.URL, &u.State, &u.UniType,
		&u.StaffPagePattern, &u.CreatedAt, &u.UpdatedAt,
	)
	// Fix: use errors.Is instead of == so a wrapped pgx.ErrNoRows is also
	// recognized as "not found" rather than surfacing as an error.
	if errors.Is(err, pgx.ErrNoRows) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return u, nil
}
// GetUniversityByID is an alias for GetUniversity (for interface compatibility).
// It shares GetUniversity's contract, including (nil, nil) when no row matches.
func (r *Repository) GetUniversityByID(ctx context.Context, id uuid.UUID) (*University, error) {
	return r.GetUniversity(ctx, id)
}
// GetUniversityByURL retrieves a university by URL.
//
// NOTE(review): unlike GetUniversity, a missing row here is returned as the
// raw driver error (pgx.ErrNoRows) rather than (nil, nil) — confirm whether
// callers rely on that difference before unifying the two.
func (r *Repository) GetUniversityByURL(ctx context.Context, url string) (*University, error) {
	query := `SELECT id, name, short_name, url, state, uni_type, staff_page_pattern, created_at, updated_at
FROM universities WHERE url = $1`
	u := &University{}
	err := r.db.Pool.QueryRow(ctx, query, url).Scan(
		&u.ID, &u.Name, &u.ShortName, &u.URL, &u.State, &u.UniType,
		&u.StaffPagePattern, &u.CreatedAt, &u.UpdatedAt,
	)
	if err != nil {
		return nil, err
	}
	return u, nil
}
// ListUniversities lists all universities, ordered alphabetically by name.
func (r *Repository) ListUniversities(ctx context.Context) ([]University, error) {
	query := `SELECT id, name, short_name, url, state, uni_type, staff_page_pattern, created_at, updated_at
FROM universities ORDER BY name`
	rows, err := r.db.Pool.Query(ctx, query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []University
	for rows.Next() {
		var uni University
		scanErr := rows.Scan(
			&uni.ID, &uni.Name, &uni.ShortName, &uni.URL, &uni.State, &uni.UniType,
			&uni.StaffPagePattern, &uni.CreatedAt, &uni.UpdatedAt,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		result = append(result, uni)
	}
	// rows.Err surfaces any iteration error that ended the loop early.
	return result, rows.Err()
}
// ============================================================================
// DEPARTMENTS
// ============================================================================
// CreateDepartment creates or updates a department.
// Upserts on (university_id, name); d is mutated with the row's ID and
// timestamps on return.
func (r *Repository) CreateDepartment(ctx context.Context, d *Department) error {
	query := `
INSERT INTO departments (university_id, name, name_en, url, category, parent_id)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (university_id, name) DO UPDATE SET
name_en = EXCLUDED.name_en,
url = EXCLUDED.url,
category = EXCLUDED.category,
parent_id = EXCLUDED.parent_id,
updated_at = NOW()
RETURNING id, created_at, updated_at
`
	return r.db.Pool.QueryRow(ctx, query,
		d.UniversityID, d.Name, d.NameEN, d.URL, d.Category, d.ParentID,
	).Scan(&d.ID, &d.CreatedAt, &d.UpdatedAt)
}
// GetDepartmentByName retrieves a department by university and name.
// A missing row surfaces as the driver's error (e.g. pgx.ErrNoRows).
func (r *Repository) GetDepartmentByName(ctx context.Context, uniID uuid.UUID, name string) (*Department, error) {
	query := `SELECT id, university_id, name, name_en, url, category, parent_id, created_at, updated_at
FROM departments WHERE university_id = $1 AND name = $2`
	dept := &Department{}
	if err := r.db.Pool.QueryRow(ctx, query, uniID, name).Scan(
		&dept.ID, &dept.UniversityID, &dept.Name, &dept.NameEN, &dept.URL, &dept.Category,
		&dept.ParentID, &dept.CreatedAt, &dept.UpdatedAt,
	); err != nil {
		return nil, err
	}
	return dept, nil
}
// ============================================================================
// STAFF
// ============================================================================
// CreateStaff creates or updates a staff member.
//
// Upserts on (university_id, first_name, last_name, department_id) with a
// NULL department collapsed to the all-zero UUID so NULLs conflict with each
// other. On conflict, contact/profile fields use COALESCE(EXCLUDED.x, old.x):
// a NULL in the incoming record keeps the previously stored value, so repeat
// crawls never erase data they did not see. s is mutated with the row's ID
// and timestamps on return.
func (r *Repository) CreateStaff(ctx context.Context, s *UniversityStaff) error {
	query := `
INSERT INTO university_staff (
university_id, department_id, first_name, last_name, full_name,
title, academic_title, position, position_type, is_professor,
email, phone, office, profile_url, photo_url,
orcid, google_scholar_id, researchgate_url, linkedin_url, personal_website,
research_interests, research_summary, supervisor_id, team_role, source_url
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
$11, $12, $13, $14, $15, $16, $17, $18, $19, $20,
$21, $22, $23, $24, $25
)
ON CONFLICT (university_id, first_name, last_name, COALESCE(department_id, '00000000-0000-0000-0000-000000000000'::uuid))
DO UPDATE SET
full_name = EXCLUDED.full_name,
title = EXCLUDED.title,
academic_title = EXCLUDED.academic_title,
position = EXCLUDED.position,
position_type = EXCLUDED.position_type,
is_professor = EXCLUDED.is_professor,
email = COALESCE(EXCLUDED.email, university_staff.email),
phone = COALESCE(EXCLUDED.phone, university_staff.phone),
office = COALESCE(EXCLUDED.office, university_staff.office),
profile_url = COALESCE(EXCLUDED.profile_url, university_staff.profile_url),
photo_url = COALESCE(EXCLUDED.photo_url, university_staff.photo_url),
orcid = COALESCE(EXCLUDED.orcid, university_staff.orcid),
google_scholar_id = COALESCE(EXCLUDED.google_scholar_id, university_staff.google_scholar_id),
researchgate_url = COALESCE(EXCLUDED.researchgate_url, university_staff.researchgate_url),
linkedin_url = COALESCE(EXCLUDED.linkedin_url, university_staff.linkedin_url),
personal_website = COALESCE(EXCLUDED.personal_website, university_staff.personal_website),
research_interests = COALESCE(EXCLUDED.research_interests, university_staff.research_interests),
research_summary = COALESCE(EXCLUDED.research_summary, university_staff.research_summary),
supervisor_id = COALESCE(EXCLUDED.supervisor_id, university_staff.supervisor_id),
team_role = COALESCE(EXCLUDED.team_role, university_staff.team_role),
source_url = COALESCE(EXCLUDED.source_url, university_staff.source_url),
crawled_at = NOW(),
updated_at = NOW()
RETURNING id, crawled_at, created_at, updated_at
`
	return r.db.Pool.QueryRow(ctx, query,
		s.UniversityID, s.DepartmentID, s.FirstName, s.LastName, s.FullName,
		s.Title, s.AcademicTitle, s.Position, s.PositionType, s.IsProfessor,
		s.Email, s.Phone, s.Office, s.ProfileURL, s.PhotoURL,
		s.ORCID, s.GoogleScholarID, s.ResearchgateURL, s.LinkedInURL, s.PersonalWebsite,
		s.ResearchInterests, s.ResearchSummary, s.SupervisorID, s.TeamRole, s.SourceURL,
	).Scan(&s.ID, &s.CrawledAt, &s.CreatedAt, &s.UpdatedAt)
}
// GetStaff retrieves a staff member by ID from the v_staff_full view.
//
// The Scan target list must match the view's column order exactly; the nil
// destinations skip view columns this struct does not map (presumably pgx
// accepts nil to discard a column — TODO confirm, and confirm the positions
// against the view definition, since `SELECT *` makes this order-fragile).
func (r *Repository) GetStaff(ctx context.Context, id uuid.UUID) (*UniversityStaff, error) {
	query := `SELECT * FROM v_staff_full WHERE id = $1`
	s := &UniversityStaff{}
	err := r.db.Pool.QueryRow(ctx, query, id).Scan(
		&s.ID, &s.UniversityID, &s.DepartmentID, &s.FirstName, &s.LastName, &s.FullName,
		&s.Title, &s.AcademicTitle, &s.Position, &s.PositionType, &s.IsProfessor,
		&s.Email, &s.Phone, &s.Office, &s.ProfileURL, &s.PhotoURL,
		&s.ORCID, &s.GoogleScholarID, &s.ResearchgateURL, &s.LinkedInURL, &s.PersonalWebsite,
		&s.ResearchInterests, &s.ResearchSummary, &s.CrawledAt, &s.LastVerified, &s.IsActive, &s.SourceURL,
		&s.CreatedAt, &s.UpdatedAt, &s.UniversityName, &s.UniversityShort, nil, nil,
		&s.DepartmentName, nil, &s.PublicationCount,
	)
	if err != nil {
		return nil, err
	}
	return s, nil
}
// SearchStaff searches for staff members.
//
// A WHERE clause is assembled from the optional filters in params, keeping
// SQL parameter placeholders ($1, $2, ...) in sync with the args slice via
// argNum. The free-text Query matches either a German full-text index over
// full_name + research_summary or a case-insensitive substring of the name.
// Results are counted first (for Total), then fetched with Limit clamped to
// [1,100] (default 20) and Offset clamped to >= 0, ordered professors-first
// then by last name.
func (r *Repository) SearchStaff(ctx context.Context, params StaffSearchParams) (*StaffSearchResult, error) {
	// Build query dynamically
	var conditions []string
	var args []interface{}
	argNum := 1
	baseQuery := `
SELECT s.id, s.university_id, s.department_id, s.first_name, s.last_name, s.full_name,
s.title, s.academic_title, s.position, s.position_type, s.is_professor,
s.email, s.profile_url, s.photo_url, s.orcid,
s.research_interests, s.crawled_at, s.is_active,
u.name as university_name, u.short_name as university_short, u.state as university_state,
d.name as department_name,
(SELECT COUNT(*) FROM staff_publications sp WHERE sp.staff_id = s.id) as publication_count
FROM university_staff s
JOIN universities u ON s.university_id = u.id
LEFT JOIN departments d ON s.department_id = d.id
`
	if params.Query != "" {
		// The same placeholder number is reused three times for one arg.
		conditions = append(conditions, fmt.Sprintf(
			`(to_tsvector('german', COALESCE(s.full_name, '') || ' ' || COALESCE(s.research_summary, '')) @@ plainto_tsquery('german', $%d)
OR s.full_name ILIKE '%%' || $%d || '%%'
OR s.last_name ILIKE '%%' || $%d || '%%')`,
			argNum, argNum, argNum))
		args = append(args, params.Query)
		argNum++
	}
	if params.UniversityID != nil {
		conditions = append(conditions, fmt.Sprintf("s.university_id = $%d", argNum))
		args = append(args, *params.UniversityID)
		argNum++
	}
	if params.DepartmentID != nil {
		conditions = append(conditions, fmt.Sprintf("s.department_id = $%d", argNum))
		args = append(args, *params.DepartmentID)
		argNum++
	}
	if params.State != nil {
		conditions = append(conditions, fmt.Sprintf("u.state = $%d", argNum))
		args = append(args, *params.State)
		argNum++
	}
	if params.UniType != nil {
		conditions = append(conditions, fmt.Sprintf("u.uni_type = $%d", argNum))
		args = append(args, *params.UniType)
		argNum++
	}
	if params.PositionType != nil {
		conditions = append(conditions, fmt.Sprintf("s.position_type = $%d", argNum))
		args = append(args, *params.PositionType)
		argNum++
	}
	if params.IsProfessor != nil {
		conditions = append(conditions, fmt.Sprintf("s.is_professor = $%d", argNum))
		args = append(args, *params.IsProfessor)
		argNum++
	}
	// Build WHERE clause
	whereClause := ""
	if len(conditions) > 0 {
		whereClause = "WHERE " + strings.Join(conditions, " AND ")
	}
	// Count total (same joins/filters as the page query, without pagination)
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM university_staff s JOIN universities u ON s.university_id = u.id LEFT JOIN departments d ON s.department_id = d.id %s", whereClause)
	var total int
	if err := r.db.Pool.QueryRow(ctx, countQuery, args...).Scan(&total); err != nil {
		return nil, err
	}
	// Apply pagination: default 20, hard cap 100, offset never negative.
	limit := params.Limit
	if limit <= 0 {
		limit = 20
	}
	if limit > 100 {
		limit = 100
	}
	offset := params.Offset
	if offset < 0 {
		offset = 0
	}
	// Full query with pagination (limit/offset are clamped ints, safe to inline)
	fullQuery := fmt.Sprintf("%s %s ORDER BY s.is_professor DESC, s.last_name ASC LIMIT %d OFFSET %d",
		baseQuery, whereClause, limit, offset)
	rows, err := r.db.Pool.Query(ctx, fullQuery, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var staff []UniversityStaff
	for rows.Next() {
		var s UniversityStaff
		// university_state is selected but has no struct field; scanned and discarded.
		var uniState *string
		if err := rows.Scan(
			&s.ID, &s.UniversityID, &s.DepartmentID, &s.FirstName, &s.LastName, &s.FullName,
			&s.Title, &s.AcademicTitle, &s.Position, &s.PositionType, &s.IsProfessor,
			&s.Email, &s.ProfileURL, &s.PhotoURL, &s.ORCID,
			&s.ResearchInterests, &s.CrawledAt, &s.IsActive,
			&s.UniversityName, &s.UniversityShort, &uniState,
			&s.DepartmentName, &s.PublicationCount,
		); err != nil {
			return nil, err
		}
		staff = append(staff, s)
	}
	return &StaffSearchResult{
		Staff:  staff,
		Total:  total,
		Limit:  limit,
		Offset: offset,
		Query:  params.Query,
	}, rows.Err()
}
// ============================================================================
// PUBLICATIONS
// ============================================================================
// CreatePublication creates or updates a publication.
//
// Upserts on the doi column (partial index: only rows with a non-NULL DOI
// conflict). Because records without a DOI cannot take that upsert path, a
// "duplicate" error from some other unique constraint is handled by a
// fallback lookup on (title, year), which fills p.ID with the existing row.
// On success p is mutated with ID and timestamps.
func (r *Repository) CreatePublication(ctx context.Context, p *Publication) error {
	query := `
INSERT INTO publications (
title, title_en, abstract, abstract_en, year, month,
pub_type, venue, venue_short, publisher,
doi, isbn, issn, arxiv_id, pubmed_id,
url, pdf_url, citation_count, keywords, topics, source, raw_data
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
$11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22
)
ON CONFLICT (doi) WHERE doi IS NOT NULL DO UPDATE SET
title = EXCLUDED.title,
abstract = EXCLUDED.abstract,
year = EXCLUDED.year,
venue = EXCLUDED.venue,
citation_count = EXCLUDED.citation_count,
updated_at = NOW()
RETURNING id, crawled_at, created_at, updated_at
`
	// Handle potential duplicate without DOI
	err := r.db.Pool.QueryRow(ctx, query,
		p.Title, p.TitleEN, p.Abstract, p.AbstractEN, p.Year, p.Month,
		p.PubType, p.Venue, p.VenueShort, p.Publisher,
		p.DOI, p.ISBN, p.ISSN, p.ArxivID, p.PubmedID,
		p.URL, p.PDFURL, p.CitationCount, p.Keywords, p.Topics, p.Source, p.RawData,
	).Scan(&p.ID, &p.CrawledAt, &p.CreatedAt, &p.UpdatedAt)
	// NOTE(review): substring-matching the error text is fragile; matching
	// the Postgres unique-violation code (23505) would be more robust.
	if err != nil && strings.Contains(err.Error(), "duplicate") {
		// Try to find existing publication by title and year
		findQuery := `SELECT id FROM publications WHERE title = $1 AND year = $2`
		err = r.db.Pool.QueryRow(ctx, findQuery, p.Title, p.Year).Scan(&p.ID)
	}
	return err
}
// LinkStaffPublication creates a link between staff and publication.
// Re-linking an existing pair updates the author position and the
// corresponding-author flag instead of failing.
func (r *Repository) LinkStaffPublication(ctx context.Context, sp *StaffPublication) error {
	const query = `
INSERT INTO staff_publications (staff_id, publication_id, author_position, is_corresponding)
VALUES ($1, $2, $3, $4)
ON CONFLICT (staff_id, publication_id) DO UPDATE SET
author_position = EXCLUDED.author_position,
is_corresponding = EXCLUDED.is_corresponding
`
	if _, err := r.db.Pool.Exec(ctx, query,
		sp.StaffID, sp.PublicationID, sp.AuthorPosition, sp.IsCorresponding,
	); err != nil {
		return err
	}
	return nil
}
// GetStaffPublications retrieves all publications for a staff member,
// newest year first (rows without a year sort last), then by title.
func (r *Repository) GetStaffPublications(ctx context.Context, staffID uuid.UUID) ([]Publication, error) {
	const query = `
SELECT p.id, p.title, p.abstract, p.year, p.pub_type, p.venue, p.doi, p.url, p.citation_count
FROM publications p
JOIN staff_publications sp ON p.id = sp.publication_id
WHERE sp.staff_id = $1
ORDER BY p.year DESC NULLS LAST, p.title
`
	rows, err := r.db.Pool.Query(ctx, query, staffID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []Publication
	for rows.Next() {
		var pub Publication
		scanErr := rows.Scan(
			&pub.ID, &pub.Title, &pub.Abstract, &pub.Year, &pub.PubType, &pub.Venue, &pub.DOI, &pub.URL, &pub.CitationCount,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		result = append(result, pub)
	}
	return result, rows.Err()
}
// SearchPublications searches for publications with optional full-text,
// staff, year/year-range, and type filters, returning one page of results
// plus the total match count for the same filters.
func (r *Repository) SearchPublications(ctx context.Context, params PublicationSearchParams) (*PublicationSearchResult, error) {
	// WHERE fragments and their bind arguments are built in lockstep;
	// argNum tracks the next positional placeholder ($1, $2, ...), so the
	// order of the if-blocks below must match the order of args appends.
	var conditions []string
	var args []interface{}
	argNum := 1
	if params.Query != "" {
		// German full-text search over title and abstract.
		conditions = append(conditions, fmt.Sprintf(
			`to_tsvector('german', COALESCE(title, '') || ' ' || COALESCE(abstract, '')) @@ plainto_tsquery('german', $%d)`,
			argNum))
		args = append(args, params.Query)
		argNum++
	}
	if params.StaffID != nil {
		conditions = append(conditions, fmt.Sprintf(
			`id IN (SELECT publication_id FROM staff_publications WHERE staff_id = $%d)`,
			argNum))
		args = append(args, *params.StaffID)
		argNum++
	}
	if params.Year != nil {
		conditions = append(conditions, fmt.Sprintf("year = $%d", argNum))
		args = append(args, *params.Year)
		argNum++
	}
	if params.YearFrom != nil {
		conditions = append(conditions, fmt.Sprintf("year >= $%d", argNum))
		args = append(args, *params.YearFrom)
		argNum++
	}
	if params.YearTo != nil {
		conditions = append(conditions, fmt.Sprintf("year <= $%d", argNum))
		args = append(args, *params.YearTo)
		argNum++
	}
	if params.PubType != nil {
		conditions = append(conditions, fmt.Sprintf("pub_type = $%d", argNum))
		args = append(args, *params.PubType)
		argNum++
	}
	whereClause := ""
	if len(conditions) > 0 {
		whereClause = "WHERE " + strings.Join(conditions, " AND ")
	}
	// Count: total matches under the same filters, without pagination.
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM publications %s", whereClause)
	var total int
	if err := r.db.Pool.QueryRow(ctx, countQuery, args...).Scan(&total); err != nil {
		return nil, err
	}
	// Pagination: default page size is 20 when the caller passes <= 0.
	limit := params.Limit
	if limit <= 0 {
		limit = 20
	}
	offset := params.Offset
	// Query: limit/offset are server-derived ints interpolated directly;
	// all user-supplied values go through bind parameters.
	query := fmt.Sprintf(`
		SELECT id, title, abstract, year, pub_type, venue, doi, url, citation_count, keywords
		FROM publications %s
		ORDER BY year DESC NULLS LAST, citation_count DESC
		LIMIT %d OFFSET %d
	`, whereClause, limit, offset)
	rows, err := r.db.Pool.Query(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var pubs []Publication
	for rows.Next() {
		var p Publication
		if err := rows.Scan(
			&p.ID, &p.Title, &p.Abstract, &p.Year, &p.PubType, &p.Venue, &p.DOI, &p.URL, &p.CitationCount, &p.Keywords,
		); err != nil {
			return nil, err
		}
		pubs = append(pubs, p)
	}
	return &PublicationSearchResult{
		Publications: pubs,
		Total: total,
		Limit: limit,
		Offset: offset,
		Query: params.Query,
	}, rows.Err()
}
// ============================================================================
// CRAWL STATUS
// ============================================================================
// UpdateCrawlStatus upserts the crawl-status row for a university: an
// existing row (keyed by university_id) has all tracked fields overwritten
// and updated_at refreshed.
func (r *Repository) UpdateCrawlStatus(ctx context.Context, status *UniversityCrawlStatus) error {
	query := `
		INSERT INTO university_crawl_status (
			university_id, last_staff_crawl, staff_crawl_status, staff_count, staff_errors,
			last_pub_crawl, pub_crawl_status, pub_count, pub_errors,
			next_scheduled_crawl, crawl_priority
		) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
		ON CONFLICT (university_id) DO UPDATE SET
			last_staff_crawl = EXCLUDED.last_staff_crawl,
			staff_crawl_status = EXCLUDED.staff_crawl_status,
			staff_count = EXCLUDED.staff_count,
			staff_errors = EXCLUDED.staff_errors,
			last_pub_crawl = EXCLUDED.last_pub_crawl,
			pub_crawl_status = EXCLUDED.pub_crawl_status,
			pub_count = EXCLUDED.pub_count,
			pub_errors = EXCLUDED.pub_errors,
			next_scheduled_crawl = EXCLUDED.next_scheduled_crawl,
			crawl_priority = EXCLUDED.crawl_priority,
			updated_at = NOW()
	`
	// Argument order must match the column list above.
	_, err := r.db.Pool.Exec(ctx, query,
		status.UniversityID, status.LastStaffCrawl, status.StaffCrawlStatus, status.StaffCount, status.StaffErrors,
		status.LastPubCrawl, status.PubCrawlStatus, status.PubCount, status.PubErrors,
		status.NextScheduledCrawl, status.CrawlPriority,
	)
	return err
}
// GetCrawlStatus retrieves the crawl-status row for a university.
// It returns (nil, nil) when no row exists for the given id.
func (r *Repository) GetCrawlStatus(ctx context.Context, uniID uuid.UUID) (*UniversityCrawlStatus, error) {
	// List columns explicitly (instead of SELECT *) so the positional Scan
	// below cannot silently break when the table gains or reorders columns.
	// Column names mirror the INSERT in UpdateCrawlStatus plus timestamps.
	query := `
		SELECT university_id, last_staff_crawl, staff_crawl_status, staff_count, staff_errors,
		       last_pub_crawl, pub_crawl_status, pub_count, pub_errors,
		       next_scheduled_crawl, crawl_priority, created_at, updated_at
		FROM university_crawl_status
		WHERE university_id = $1
	`
	s := &UniversityCrawlStatus{}
	err := r.db.Pool.QueryRow(ctx, query, uniID).Scan(
		&s.UniversityID, &s.LastStaffCrawl, &s.StaffCrawlStatus, &s.StaffCount, &s.StaffErrors,
		&s.LastPubCrawl, &s.PubCrawlStatus, &s.PubCount, &s.PubErrors,
		&s.NextScheduledCrawl, &s.CrawlPriority, &s.CreatedAt, &s.UpdatedAt,
	)
	if err == pgx.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return s, nil
}
// ============================================================================
// STATS
// ============================================================================
// GetStaffStats retrieves aggregate statistics about staff and publication
// data: scalar totals plus active-staff breakdowns by federal state,
// university type, and position type.
//
// Fix: the original three group-by loops never checked rows.Err(), so an
// error during row iteration was silently dropped; the shared helper below
// now surfaces it.
func (r *Repository) GetStaffStats(ctx context.Context) (*StaffStats, error) {
	stats := &StaffStats{
		ByState:        make(map[string]int),
		ByUniType:      make(map[string]int),
		ByPositionType: make(map[string]int),
	}
	// Basic scalar counts.
	queries := []struct {
		query string
		dest  *int
	}{
		{"SELECT COUNT(*) FROM university_staff WHERE is_active = true", &stats.TotalStaff},
		{"SELECT COUNT(*) FROM university_staff WHERE is_professor = true AND is_active = true", &stats.TotalProfessors},
		{"SELECT COUNT(*) FROM publications", &stats.TotalPublications},
		{"SELECT COUNT(*) FROM universities", &stats.TotalUniversities},
	}
	for _, q := range queries {
		if err := r.db.Pool.QueryRow(ctx, q.query).Scan(q.dest); err != nil {
			return nil, err
		}
	}
	// Grouped breakdowns: each query must return (label TEXT, count) rows.
	groupQueries := []struct {
		query string
		dest  map[string]int
	}{
		{`
			SELECT COALESCE(u.state, 'unknown'), COUNT(*)
			FROM university_staff s
			JOIN universities u ON s.university_id = u.id
			WHERE s.is_active = true
			GROUP BY u.state
		`, stats.ByState},
		{`
			SELECT COALESCE(u.uni_type, 'unknown'), COUNT(*)
			FROM university_staff s
			JOIN universities u ON s.university_id = u.id
			WHERE s.is_active = true
			GROUP BY u.uni_type
		`, stats.ByUniType},
		{`
			SELECT COALESCE(position_type, 'unknown'), COUNT(*)
			FROM university_staff
			WHERE is_active = true
			GROUP BY position_type
		`, stats.ByPositionType},
	}
	for _, gq := range groupQueries {
		if err := r.scanGroupCounts(ctx, gq.query, gq.dest); err != nil {
			return nil, err
		}
	}
	return stats, nil
}

// scanGroupCounts runs a (label, count) aggregation query and stores each
// row into dest keyed by label, propagating scan and iteration errors.
func (r *Repository) scanGroupCounts(ctx context.Context, query string, dest map[string]int) error {
	rows, err := r.db.Pool.Query(ctx, query)
	if err != nil {
		return err
	}
	defer rows.Close()
	for rows.Next() {
		var label string
		var count int
		if err := rows.Scan(&label, &count); err != nil {
			return err
		}
		dest[label] = count
	}
	return rows.Err()
}

View File

@@ -0,0 +1,332 @@
package embedding
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"
)
// EmbeddingProvider defines the interface for embedding services.
type EmbeddingProvider interface {
	// Embed generates the embedding vector for a single text.
	Embed(ctx context.Context, text string) ([]float32, error)
	// EmbedBatch generates embeddings for multiple texts, one vector per
	// input, in input order.
	EmbedBatch(ctx context.Context, texts []string) ([][]float32, error)
	// Dimension returns the embedding vector dimension.
	Dimension() int
}
// Service wraps an embedding provider and gates every call on the enabled
// flag, so callers can treat "embeddings off" as a normal configuration.
type Service struct {
	provider EmbeddingProvider // nil when embeddings are disabled
	dimension int // configured vector dimension (reported even when disabled)
	enabled bool // semantic-search on/off switch
}
// NewService builds an embedding Service for the configured provider.
// Supported providers: "openai" (requires apiKey), "ollama", and "none"/""
// (semantic search disabled). When enabled is false the provider string is
// ignored and a disabled service is returned.
func NewService(provider, apiKey, model, ollamaURL string, dimension int, enabled bool) (*Service, error) {
	disabled := &Service{provider: nil, dimension: dimension, enabled: false}
	if !enabled {
		return disabled, nil
	}
	switch provider {
	case "openai":
		if apiKey == "" {
			return nil, errors.New("OpenAI API key required for openai provider")
		}
		openai := NewOpenAIProvider(apiKey, model, dimension)
		return &Service{provider: openai, dimension: dimension, enabled: true}, nil
	case "ollama":
		ollama, err := NewOllamaProvider(ollamaURL, model, dimension)
		if err != nil {
			return nil, err
		}
		return &Service{provider: ollama, dimension: dimension, enabled: true}, nil
	case "none", "":
		return disabled, nil
	default:
		return nil, fmt.Errorf("unknown embedding provider: %s", provider)
	}
}
// IsEnabled reports whether semantic search is active: the service must be
// enabled and have a concrete provider configured.
func (s *Service) IsEnabled() bool {
	if s.provider == nil {
		return false
	}
	return s.enabled
}
// Embed returns the embedding vector for text, or an error when the
// service is disabled.
func (s *Service) Embed(ctx context.Context, text string) ([]float32, error) {
	if s.IsEnabled() {
		return s.provider.Embed(ctx, text)
	}
	return nil, errors.New("embedding service not enabled")
}
// EmbedBatch returns one embedding per input text, or an error when the
// service is disabled.
func (s *Service) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
	if s.IsEnabled() {
		return s.provider.EmbedBatch(ctx, texts)
	}
	return nil, errors.New("embedding service not enabled")
}
// Dimension returns the configured embedding dimension. It reflects the
// configuration value, not a value reported by the provider, and is valid
// even when the service is disabled.
func (s *Service) Dimension() int {
	return s.dimension
}
// =====================================================
// OpenAI Embedding Provider
// =====================================================
// OpenAIProvider implements EmbeddingProvider using OpenAI's
// /v1/embeddings REST API.
type OpenAIProvider struct {
	apiKey string // bearer token for the Authorization header
	model string // e.g. "text-embedding-3-small"
	dimension int // requested vector size (only sent for text-embedding-3-* models)
	httpClient *http.Client // shared client with a 60s timeout
}
// NewOpenAIProvider creates an embedding provider backed by the OpenAI API,
// using a 60-second HTTP timeout for embedding requests.
func NewOpenAIProvider(apiKey, model string, dimension int) *OpenAIProvider {
	client := &http.Client{Timeout: 60 * time.Second}
	p := OpenAIProvider{
		apiKey:     apiKey,
		model:      model,
		dimension:  dimension,
		httpClient: client,
	}
	return &p
}
// openAIEmbeddingRequest represents the OpenAI /v1/embeddings request body.
type openAIEmbeddingRequest struct {
	Model string `json:"model"`
	Input []string `json:"input"`
	// Dimensions is only honored by text-embedding-3-* models; omitted
	// from the JSON when zero.
	Dimensions int `json:"dimensions,omitempty"`
}
// openAIEmbeddingResponse represents the OpenAI /v1/embeddings response
// body. Error is non-nil when the API returned a structured error payload
// instead of data.
type openAIEmbeddingResponse struct {
	Data []struct {
		Embedding []float32 `json:"embedding"`
		Index int `json:"index"` // position of this vector in the input batch
	} `json:"data"`
	Usage struct {
		PromptTokens int `json:"prompt_tokens"`
		TotalTokens int `json:"total_tokens"`
	} `json:"usage"`
	Error *struct {
		Message string `json:"message"`
		Type string `json:"type"`
	} `json:"error,omitempty"`
}
// Embed generates the embedding for one text by delegating to EmbedBatch
// with a single-element batch.
func (p *OpenAIProvider) Embed(ctx context.Context, text string) ([]float32, error) {
	batch, err := p.EmbedBatch(ctx, []string{text})
	switch {
	case err != nil:
		return nil, err
	case len(batch) == 0:
		return nil, errors.New("no embedding returned")
	default:
		return batch[0], nil
	}
}
// EmbedBatch generates embeddings for multiple texts in a single API call.
// An empty input returns (nil, nil) without any network traffic. Each
// input is truncated to 30000 bytes (~7500 tokens at ~4 chars/token) to
// stay under the model's token limit.
//
// Fixes over the original: a non-200 response whose body carried no
// structured error previously fell through to a confusing "expected N
// embeddings, got 0" message, and an out-of-range index from the API
// would have caused a panic.
func (p *OpenAIProvider) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
	if len(texts) == 0 {
		return nil, nil
	}
	truncatedTexts := make([]string, len(texts))
	for i, text := range texts {
		if len(text) > 30000 {
			text = text[:30000]
		}
		truncatedTexts[i] = text
	}
	reqBody := openAIEmbeddingRequest{
		Model: p.model,
		Input: truncatedTexts,
	}
	// Only set dimensions for models that support it (text-embedding-3-*).
	if p.model == "text-embedding-3-small" || p.model == "text-embedding-3-large" {
		reqBody.Dimensions = p.dimension
	}
	body, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, "POST", "https://api.openai.com/v1/embeddings", bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+p.apiKey)
	req.Header.Set("Content-Type", "application/json")
	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to call OpenAI API: %w", err)
	}
	defer resp.Body.Close()
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	var apiResp openAIEmbeddingResponse
	if err := json.Unmarshal(respBody, &apiResp); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}
	// Prefer the structured error message when present.
	if apiResp.Error != nil {
		return nil, fmt.Errorf("OpenAI API error: %s", apiResp.Error.Message)
	}
	// Surface non-OK responses that carried no structured error field.
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("OpenAI API error (status %d): %s", resp.StatusCode, string(respBody))
	}
	if len(apiResp.Data) != len(texts) {
		return nil, fmt.Errorf("expected %d embeddings, got %d", len(texts), len(apiResp.Data))
	}
	// Reorder by index to maintain input order; guard against bad indices.
	result := make([][]float32, len(texts))
	for _, item := range apiResp.Data {
		if item.Index < 0 || item.Index >= len(result) {
			return nil, fmt.Errorf("embedding index %d out of range", item.Index)
		}
		result[item.Index] = item.Embedding
	}
	return result, nil
}
// Dimension returns the configured embedding dimension for this provider.
func (p *OpenAIProvider) Dimension() int {
	return p.dimension
}
// =====================================================
// Ollama Embedding Provider (for local models)
// =====================================================
// OllamaProvider implements EmbeddingProvider using a local Ollama server's
// /api/embeddings endpoint.
type OllamaProvider struct {
	baseURL string // e.g. "http://localhost:11434"
	model string // e.g. "nomic-embed-text"
	dimension int // expected vector size (not enforced against responses)
	httpClient *http.Client // shared client with a 120s timeout
}
// NewOllamaProvider creates an embedding provider talking to an Ollama
// server at baseURL. The error return is reserved for future validation
// and is currently always nil.
func NewOllamaProvider(baseURL, model string, dimension int) (*OllamaProvider, error) {
	// First inference after a model load can be slow, hence the long timeout.
	client := &http.Client{Timeout: 120 * time.Second}
	provider := OllamaProvider{
		baseURL:    baseURL,
		model:      model,
		dimension:  dimension,
		httpClient: client,
	}
	return &provider, nil
}
// ollamaEmbeddingRequest represents the Ollama /api/embeddings request body.
type ollamaEmbeddingRequest struct {
	Model string `json:"model"`
	Prompt string `json:"prompt"` // the text to embed
}
// ollamaEmbeddingResponse represents the Ollama /api/embeddings response body.
type ollamaEmbeddingResponse struct {
	Embedding []float32 `json:"embedding"`
}
// Embed requests an embedding vector for text from the Ollama server.
// Inputs longer than 30000 bytes are truncated before sending.
func (p *OllamaProvider) Embed(ctx context.Context, text string) ([]float32, error) {
	const maxLen = 30000
	if len(text) > maxLen {
		text = text[:maxLen]
	}
	payload, err := json.Marshal(ollamaEmbeddingRequest{Model: p.model, Prompt: text})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/api/embeddings", bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to call Ollama API: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		msg, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("Ollama API error (status %d): %s", resp.StatusCode, string(msg))
	}
	var decoded ollamaEmbeddingResponse
	if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}
	return decoded.Embedding, nil
}
// EmbedBatch embeds each text one after another — the Ollama embeddings
// endpoint accepts a single prompt per request, so there is no true batch
// call to delegate to.
func (p *OllamaProvider) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
	vectors := make([][]float32, len(texts))
	for i := range texts {
		vec, err := p.Embed(ctx, texts[i])
		if err != nil {
			return nil, fmt.Errorf("failed to embed text %d: %w", i, err)
		}
		vectors[i] = vec
	}
	return vectors, nil
}
// Dimension returns the configured embedding dimension for this provider.
func (p *OllamaProvider) Dimension() int {
	return p.dimension
}

View File

@@ -0,0 +1,319 @@
package embedding
import (
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"
)
// TestNewService_Disabled verifies that a service built with enabled=false
// reports IsEnabled()==false but still exposes the configured dimension.
func TestNewService_Disabled(t *testing.T) {
	service, err := NewService("none", "", "", "", 1536, false)
	if err != nil {
		t.Fatalf("NewService failed: %v", err)
	}
	if service.IsEnabled() {
		t.Error("Service should not be enabled")
	}
	if service.Dimension() != 1536 {
		t.Errorf("Expected dimension 1536, got %d", service.Dimension())
	}
}
// TestNewService_DisabledByProvider verifies that provider "none" yields a
// disabled service even when the enabled flag is true.
func TestNewService_DisabledByProvider(t *testing.T) {
	service, err := NewService("none", "", "", "", 1536, true)
	if err != nil {
		t.Fatalf("NewService failed: %v", err)
	}
	if service.IsEnabled() {
		t.Error("Service should not be enabled when provider is 'none'")
	}
}
// TestNewService_OpenAIMissingKey verifies that the openai provider rejects
// an empty API key at construction time.
func TestNewService_OpenAIMissingKey(t *testing.T) {
	_, err := NewService("openai", "", "", "", 1536, true)
	if err == nil {
		t.Error("Expected error for missing OpenAI API key")
	}
}
// TestNewService_UnknownProvider verifies that an unrecognized provider
// name is rejected.
func TestNewService_UnknownProvider(t *testing.T) {
	_, err := NewService("unknown", "", "", "", 1536, true)
	if err == nil {
		t.Error("Expected error for unknown provider")
	}
}
// TestService_EmbedWhenDisabled verifies that Embed on a disabled service
// returns an error instead of calling a provider.
func TestService_EmbedWhenDisabled(t *testing.T) {
	service, _ := NewService("none", "", "", "", 1536, false)
	_, err := service.Embed(context.Background(), "test text")
	if err == nil {
		t.Error("Expected error when embedding with disabled service")
	}
}
// TestService_EmbedBatchWhenDisabled verifies that EmbedBatch on a disabled
// service returns an error instead of calling a provider.
func TestService_EmbedBatchWhenDisabled(t *testing.T) {
	service, _ := NewService("none", "", "", "", 1536, false)
	_, err := service.EmbedBatch(context.Background(), []string{"test1", "test2"})
	if err == nil {
		t.Error("Expected error when embedding batch with disabled service")
	}
}
// =====================================================
// OpenAI Provider Tests with Mock Server
// =====================================================
func TestOpenAIProvider_Embed(t *testing.T) {
// Create mock server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify request
if r.Method != "POST" {
t.Errorf("Expected POST, got %s", r.Method)
}
if r.Header.Get("Authorization") != "Bearer test-api-key" {
t.Errorf("Expected correct Authorization header")
}
if r.Header.Get("Content-Type") != "application/json" {
t.Errorf("Expected Content-Type application/json")
}
// Parse request body
var reqBody openAIEmbeddingRequest
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
t.Fatalf("Failed to parse request body: %v", err)
}
if reqBody.Model != "text-embedding-3-small" {
t.Errorf("Expected model text-embedding-3-small, got %s", reqBody.Model)
}
// Send mock response
resp := openAIEmbeddingResponse{
Data: []struct {
Embedding []float32 `json:"embedding"`
Index int `json:"index"`
}{
{
Embedding: make([]float32, 1536),
Index: 0,
},
},
}
resp.Data[0].Embedding[0] = 0.1
resp.Data[0].Embedding[1] = 0.2
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
// Create provider with mock server (we need to override the URL)
provider := &OpenAIProvider{
apiKey: "test-api-key",
model: "text-embedding-3-small",
dimension: 1536,
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
}
// Note: This test won't actually work with the mock server because
// the provider hardcodes the OpenAI URL. This is a structural test.
// For real testing, we'd need to make the URL configurable.
if provider.Dimension() != 1536 {
t.Errorf("Expected dimension 1536, got %d", provider.Dimension())
}
}
// TestOpenAIProvider_EmbedBatch_EmptyInput checks that an empty input slice
// short-circuits to (nil, nil) without making any network call.
func TestOpenAIProvider_EmbedBatch_EmptyInput(t *testing.T) {
	provider := NewOpenAIProvider("test-key", "text-embedding-3-small", 1536)
	result, err := provider.EmbedBatch(context.Background(), []string{})
	if err != nil {
		t.Errorf("Empty input should not cause error: %v", err)
	}
	if result != nil {
		t.Errorf("Expected nil result for empty input, got %v", result)
	}
}
// =====================================================
// Ollama Provider Tests with Mock Server
// =====================================================
// TestOllamaProvider_Embed exercises the happy path against a mock Ollama
// server: correct method, path, and model in the request, and faithful
// decoding of the returned vector.
//
// Fix: the original called t.Fatalf from inside the HTTP handler, which
// runs on a server goroutine — FailNow/Fatalf must only be called from the
// test goroutine (testing package docs). The handler now reports with
// t.Errorf and returns an HTTP error instead.
func TestOllamaProvider_Embed(t *testing.T) {
	mockEmbedding := make([]float32, 384)
	mockEmbedding[0] = 0.5
	mockEmbedding[1] = 0.3
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != "POST" {
			t.Errorf("Expected POST, got %s", r.Method)
		}
		if r.URL.Path != "/api/embeddings" {
			t.Errorf("Expected path /api/embeddings, got %s", r.URL.Path)
		}
		var reqBody ollamaEmbeddingRequest
		if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
			t.Errorf("Failed to parse request: %v", err)
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}
		if reqBody.Model != "nomic-embed-text" {
			t.Errorf("Expected model nomic-embed-text, got %s", reqBody.Model)
		}
		resp := ollamaEmbeddingResponse{
			Embedding: mockEmbedding,
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
	}))
	defer server.Close()
	provider, err := NewOllamaProvider(server.URL, "nomic-embed-text", 384)
	if err != nil {
		t.Fatalf("Failed to create provider: %v", err)
	}
	embedding, err := provider.Embed(context.Background(), "Test text für Embedding")
	if err != nil {
		t.Fatalf("Embed failed: %v", err)
	}
	if len(embedding) != 384 {
		t.Fatalf("Expected 384 dimensions, got %d", len(embedding))
	}
	if embedding[0] != 0.5 {
		t.Errorf("Expected first value 0.5, got %f", embedding[0])
	}
}
// TestOllamaProvider_EmbedBatch verifies that the batch path makes one API
// call per input text and returns one vector per input.
// callCount is only touched by the (sequential) requests issued from this
// goroutine via EmbedBatch, so no synchronization is needed here.
func TestOllamaProvider_EmbedBatch(t *testing.T) {
	callCount := 0
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		callCount++
		mockEmbedding := make([]float32, 384)
		mockEmbedding[0] = float32(callCount) * 0.1
		resp := ollamaEmbeddingResponse{
			Embedding: mockEmbedding,
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
	}))
	defer server.Close()
	provider, err := NewOllamaProvider(server.URL, "nomic-embed-text", 384)
	if err != nil {
		t.Fatalf("Failed to create provider: %v", err)
	}
	ctx := context.Background()
	texts := []string{"Text 1", "Text 2", "Text 3"}
	embeddings, err := provider.EmbedBatch(ctx, texts)
	if err != nil {
		t.Fatalf("EmbedBatch failed: %v", err)
	}
	if len(embeddings) != 3 {
		t.Errorf("Expected 3 embeddings, got %d", len(embeddings))
	}
	// Verify each embedding was called
	if callCount != 3 {
		t.Errorf("Expected 3 API calls, got %d", callCount)
	}
}
// TestOllamaProvider_EmbedServerError verifies that a non-200 response from
// the server is surfaced as an error by Embed.
func TestOllamaProvider_EmbedServerError(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte("Internal server error")) // best-effort body in a mock; write error is irrelevant here
	}))
	defer server.Close()
	provider, _ := NewOllamaProvider(server.URL, "nomic-embed-text", 384)
	_, err := provider.Embed(context.Background(), "test")
	if err == nil {
		t.Error("Expected error for server error response")
	}
}
// TestOllamaProvider_Dimension verifies that the configured dimension is
// reported back unchanged.
func TestOllamaProvider_Dimension(t *testing.T) {
	provider, _ := NewOllamaProvider("http://localhost:11434", "nomic-embed-text", 768)
	if provider.Dimension() != 768 {
		t.Errorf("Expected dimension 768, got %d", provider.Dimension())
	}
}
// =====================================================
// Text Truncation Tests
// =====================================================
// TestOllamaProvider_TextTruncation verifies that oversized inputs are
// truncated to 30000 bytes before being sent to the server.
//
// Fixes over the original: the 40000-iteration string += loop (quadratic)
// is replaced with strings.Repeat, and the previously ignored errors from
// json.Decode and provider.Embed are now checked.
func TestOllamaProvider_TextTruncation(t *testing.T) {
	receivedText := ""
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var reqBody ollamaEmbeddingRequest
		if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
			// Handler goroutine: Errorf (not Fatalf) is the safe choice here.
			t.Errorf("Failed to parse request: %v", err)
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}
		receivedText = reqBody.Prompt
		resp := ollamaEmbeddingResponse{
			Embedding: make([]float32, 384),
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
	}))
	defer server.Close()
	provider, err := NewOllamaProvider(server.URL, "nomic-embed-text", 384)
	if err != nil {
		t.Fatalf("Failed to create provider: %v", err)
	}
	longText := strings.Repeat("a", 40000)
	if _, err := provider.Embed(context.Background(), longText); err != nil {
		t.Fatalf("Embed failed: %v", err)
	}
	// Text should be truncated to 30000 chars
	if len(receivedText) > 30000 {
		t.Errorf("Expected truncated text <= 30000 chars, got %d", len(receivedText))
	}
}
// =====================================================
// Integration Tests (require actual service)
// =====================================================
// TestOpenAIProvider_Integration is a placeholder for a manual integration
// test against the real OpenAI API; it is always skipped so CI never makes
// network calls or needs credentials.
func TestOpenAIProvider_Integration(t *testing.T) {
	// Skip in CI/CD - only run manually with real API key
	t.Skip("Integration test - requires OPENAI_API_KEY environment variable")
	// provider := NewOpenAIProvider(os.Getenv("OPENAI_API_KEY"), "text-embedding-3-small", 1536)
	// embedding, err := provider.Embed(context.Background(), "Lehrplan Mathematik Bayern")
	// ...
}

View File

@@ -0,0 +1,464 @@
package extractor
import (
	"bytes"
	"io"
	"regexp"
	"strconv"
	"strings"
	"unicode"

	"github.com/PuerkitoBio/goquery"
	"github.com/ledongthuc/pdf"
	"golang.org/x/net/html"
)
// ExtractedContent contains the parsed content of an HTML page or PDF file.
type ExtractedContent struct {
	Title string // from <title>, first <h1>, og:title, or the first significant PDF line
	ContentText string // cleaned main body text
	SnippetText string // first ~300 chars of ContentText, for previews
	Language string // detected language code
	ContentLength int // len(ContentText) in bytes
	Headings []string // h1-h3 texts (HTML) or heuristically detected headings (PDF)
	Links []string // absolute http(s) links found in the document
	MetaData map[string]string // lowercased meta name/property -> content value
	Features ContentFeatures // quality-scoring signals
}
// ContentFeatures holds simple page-quality signals used for scoring.
type ContentFeatures struct {
	AdDensity float64 // ad-like elements / total container elements (rough heuristic)
	LinkDensity float64 // anchor-text length / total text length
	TextToHTMLRatio float64 // extracted text bytes / raw HTML bytes
	HasMainContent bool // true when more than 200 bytes of text were extracted
}
// ExtractHTML parses an HTML document and extracts its title, main body
// text, headings, links, meta tags, and simple quality features.
// The order of steps matters: meta/heading extraction runs before chrome
// elements (nav, footer, ads, ...) are removed from the document tree.
func ExtractHTML(body []byte) (*ExtractedContent, error) {
	doc, err := goquery.NewDocumentFromReader(bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	content := &ExtractedContent{
		MetaData: make(map[string]string),
	}
	// Extract title: <title> first, then the first <h1> as fallback.
	content.Title = strings.TrimSpace(doc.Find("title").First().Text())
	if content.Title == "" {
		content.Title = strings.TrimSpace(doc.Find("h1").First().Text())
	}
	// Extract meta tags; "name" wins over "property" when both are set.
	doc.Find("meta").Each(func(i int, s *goquery.Selection) {
		name, _ := s.Attr("name")
		property, _ := s.Attr("property")
		contentAttr, _ := s.Attr("content")
		key := name
		if key == "" {
			key = property
		}
		if key != "" && contentAttr != "" {
			content.MetaData[strings.ToLower(key)] = contentAttr
		}
	})
	// Try to get og:title if main title is empty
	if content.Title == "" {
		if ogTitle, ok := content.MetaData["og:title"]; ok {
			content.Title = ogTitle
		}
	}
	// Extract headings (h1-h3); very long heading texts are skipped.
	doc.Find("h1, h2, h3").Each(func(i int, s *goquery.Selection) {
		text := strings.TrimSpace(s.Text())
		if text != "" && len(text) < 500 {
			content.Headings = append(content.Headings, text)
		}
	})
	// Remove unwanted elements (scripts, chrome, ads, cookie banners) so
	// they don't pollute the extracted text below.
	doc.Find("script, style, nav, header, footer, aside, iframe, noscript, form, .advertisement, .ad, .ads, #cookie-banner, .cookie-notice, .social-share").Remove()
	// Try to find main content area; fall back to <body> when none of the
	// common containers exist.
	mainContent := doc.Find("main, article, .content, .main-content, #content, #main").First()
	if mainContent.Length() == 0 {
		mainContent = doc.Find("body")
	}
	// Extract text content from text-bearing elements, paragraph by paragraph.
	var textBuilder strings.Builder
	mainContent.Find("p, li, td, th, h1, h2, h3, h4, h5, h6, blockquote, pre").Each(func(i int, s *goquery.Selection) {
		text := strings.TrimSpace(s.Text())
		if text != "" {
			textBuilder.WriteString(text)
			textBuilder.WriteString("\n\n")
		}
	})
	content.ContentText = cleanText(textBuilder.String())
	content.ContentLength = len(content.ContentText)
	// Generate snippet (first ~300 chars of meaningful content)
	content.SnippetText = generateSnippet(content.ContentText, 300)
	// Extract absolute links only (href starting with "http").
	doc.Find("a[href]").Each(func(i int, s *goquery.Selection) {
		href, exists := s.Attr("href")
		if exists && strings.HasPrefix(href, "http") {
			content.Links = append(content.Links, href)
		}
	})
	// Detect language
	content.Language = detectLanguage(content.ContentText, content.MetaData)
	// Calculate features
	htmlLen := float64(len(body))
	textLen := float64(len(content.ContentText))
	if htmlLen > 0 {
		content.Features.TextToHTMLRatio = textLen / htmlLen
	}
	if textLen > 0 {
		linkTextLen := 0.0
		doc.Find("a").Each(func(i int, s *goquery.Selection) {
			linkTextLen += float64(len(s.Text()))
		})
		content.Features.LinkDensity = linkTextLen / textLen
	}
	content.Features.HasMainContent = content.ContentLength > 200
	// Ad density estimation (very simple heuristic): ad-like nodes vs.
	// total container elements.
	adCount := doc.Find(".ad, .ads, .advertisement, [class*='banner'], [id*='banner']").Length()
	totalElements := doc.Find("div, p, article, section").Length()
	if totalElements > 0 {
		content.Features.AdDensity = float64(adCount) / float64(totalElements)
	}
	return content, nil
}
// ExtractPDF extracts plain text from a PDF using the ledongthuc/pdf
// library. When the PDF cannot be parsed (or its text stream cannot be
// read), it degrades to a crude regex-based extraction via
// extractPDFFallback.
//
// Fix: page_count was set with string(rune(n)), which yields the Unicode
// code point n (e.g. 65 pages -> "A") instead of a decimal number; use
// strconv.Itoa.
func ExtractPDF(body []byte) (*ExtractedContent, error) {
	content := &ExtractedContent{
		MetaData: make(map[string]string),
	}
	// Create a reader from the byte slice
	reader := bytes.NewReader(body)
	pdfReader, err := pdf.NewReader(reader, int64(len(body)))
	if err != nil {
		// Fallback to basic extraction if PDF parsing fails
		return extractPDFFallback(body)
	}
	// Extract text using GetPlainText
	textReader, err := pdfReader.GetPlainText()
	if err != nil {
		return extractPDFFallback(body)
	}
	// Read all text content
	var textBuilder strings.Builder
	if _, err := io.Copy(&textBuilder, textReader); err != nil {
		return extractPDFFallback(body)
	}
	rawText := textBuilder.String()
	// Clean and process text
	content.ContentText = cleanText(rawText)
	content.ContentLength = len(content.ContentText)
	content.SnippetText = generateSnippet(content.ContentText, 300)
	content.Language = detectLanguage(content.ContentText, nil)
	content.Features.HasMainContent = content.ContentLength > 200
	// Extract title from first significant line
	content.Title = extractPDFTitle(content.ContentText)
	// Try to extract headings from the plain text
	content.Headings = extractPDFHeadings(content.ContentText)
	// Set PDF-specific metadata
	content.MetaData["content_type"] = "application/pdf"
	content.MetaData["page_count"] = strconv.Itoa(pdfReader.NumPage())
	return content, nil
}
// ExtractPDFWithMetadata extracts text with page-by-page processing.
// Use this when you need more control over the extraction process; pages
// whose content object is null are skipped.
//
// Fix: page_count was set with string(rune(n)), which yields the Unicode
// code point n instead of a decimal number; use strconv.Itoa.
func ExtractPDFWithMetadata(body []byte) (*ExtractedContent, error) {
	content := &ExtractedContent{
		MetaData: make(map[string]string),
	}
	reader := bytes.NewReader(body)
	pdfReader, err := pdf.NewReader(reader, int64(len(body)))
	if err != nil {
		return extractPDFFallback(body)
	}
	// Extract text page by page for better control
	var textBuilder strings.Builder
	numPages := pdfReader.NumPage()
	for pageNum := 1; pageNum <= numPages; pageNum++ {
		page := pdfReader.Page(pageNum)
		if page.V.IsNull() {
			continue
		}
		// Concatenate the page's text runs, space-separated.
		pageContent := page.Content()
		for _, text := range pageContent.Text {
			textBuilder.WriteString(text.S)
			textBuilder.WriteString(" ")
		}
		textBuilder.WriteString("\n")
	}
	rawText := textBuilder.String()
	// Clean and process text
	content.ContentText = cleanText(rawText)
	content.ContentLength = len(content.ContentText)
	content.SnippetText = generateSnippet(content.ContentText, 300)
	content.Language = detectLanguage(content.ContentText, nil)
	content.Features.HasMainContent = content.ContentLength > 200
	// Extract title and headings from plain text
	content.Title = extractPDFTitle(content.ContentText)
	content.Headings = extractPDFHeadings(content.ContentText)
	content.MetaData["content_type"] = "application/pdf"
	content.MetaData["page_count"] = strconv.Itoa(numPages)
	content.MetaData["extraction_method"] = "page_by_page"
	return content, nil
}
// extractPDFFallback uses basic regex extraction when the PDF library
// fails: it pulls parenthesized runs out of the raw PDF bytes (PDF text
// operators commonly carry literal strings in parentheses) and keeps only
// those that look like printable text. The result is best-effort and may
// miss compressed or encoded content streams entirely.
func extractPDFFallback(body []byte) (*ExtractedContent, error) {
	content := &ExtractedContent{
		MetaData: make(map[string]string),
	}
	// Basic PDF text extraction using regex (fallback)
	pdfContent := string(body)
	var textBuilder strings.Builder
	// Find text content in PDF streams
	re := regexp.MustCompile(`\((.*?)\)`)
	matches := re.FindAllStringSubmatch(pdfContent, -1)
	for _, match := range matches {
		if len(match) > 1 {
			text := match[1]
			if isPrintableText(text) {
				textBuilder.WriteString(text)
				textBuilder.WriteString(" ")
			}
		}
	}
	content.ContentText = cleanText(textBuilder.String())
	content.ContentLength = len(content.ContentText)
	content.SnippetText = generateSnippet(content.ContentText, 300)
	content.Language = detectLanguage(content.ContentText, nil)
	content.Features.HasMainContent = content.ContentLength > 200
	content.Title = extractPDFTitle(content.ContentText)
	content.MetaData["content_type"] = "application/pdf"
	content.MetaData["extraction_method"] = "fallback"
	return content, nil
}
// extractPDFTitle returns the first line of text that looks like a title:
// 10-200 characters long and neither a bare number (page number) nor a
// d.m.y date. Returns "" when no line qualifies.
//
// Fix: the two regexes were compiled inside the per-line loop on every
// iteration; they are now compiled once per call.
func extractPDFTitle(text string) string {
	pageNumRe := regexp.MustCompile(`^\d+$`)
	dateRe := regexp.MustCompile(`^\d{1,2}\.\d{1,2}\.\d{2,4}$`)
	for _, line := range strings.Split(text, "\n") {
		line = strings.TrimSpace(line)
		// Title should be meaningful length
		if len(line) < 10 || len(line) > 200 {
			continue
		}
		// Skip lines that look like page numbers or dates
		if pageNumRe.MatchString(line) || dateRe.MatchString(line) {
			continue
		}
		return line
	}
	return ""
}
// extractPDFHeadings attempts to extract headings from plain text
func extractPDFHeadings(text string) []string {
var headings []string
lines := strings.Split(text, "\n")
for i, line := range lines {
line = strings.TrimSpace(line)
// Skip very short or very long lines
if len(line) < 5 || len(line) > 200 {
continue
}
// Heuristics for headings:
// 1. All caps lines (common in PDFs)
// 2. Lines followed by empty line or starting with numbers (1., 1.1, etc.)
// 3. Short lines at beginning of document
isAllCaps := line == strings.ToUpper(line) && strings.ContainsAny(line, "ABCDEFGHIJKLMNOPQRSTUVWXYZÄÖÜ")
isNumbered := regexp.MustCompile(`^\d+(\.\d+)*\.?\s+\S`).MatchString(line)
isShortAndEarly := i < 20 && len(line) < 80
if (isAllCaps || isNumbered || isShortAndEarly) && !containsHeading(headings, line) {
headings = append(headings, line)
if len(headings) >= 10 {
break // Limit to 10 headings
}
}
}
return headings
}
// containsHeading reports whether heading is already present in headings.
func containsHeading(headings []string, heading string) bool {
	for i := range headings {
		if headings[i] == heading {
			return true
		}
	}
	return false
}
// isPrintableText reports whether s looks like human-readable text: at least
// 3 bytes long, with more than 70% of its runes being printable letters,
// whitespace, or punctuation.
//
// Fix: the original divided the printable-rune count by len(s), which counts
// BYTES, so multi-byte UTF-8 text (e.g. German umlauts) was unfairly
// penalized — "ääää…" scored 0.5 and was rejected. Count runes in the
// denominator instead.
func isPrintableText(s string) bool {
	if len(s) < 3 {
		return false
	}
	total := 0
	printable := 0
	for _, r := range s {
		total++
		if unicode.IsPrint(r) && (unicode.IsLetter(r) || unicode.IsSpace(r) || unicode.IsPunct(r)) {
			printable++
		}
	}
	return float64(printable)/float64(total) > 0.7
}
// cleanText normalizes whitespace in extracted text: CR/CRLF line endings
// become LF, runs of 3+ newlines collapse to a single blank line, runs of
// spaces/tabs collapse to one space, and every line is trimmed.
func cleanText(text string) string {
	// Normalize CRLF and bare CR line endings to LF ("\r\n" is listed first
	// so it wins over the single "\r" rule).
	normalized := strings.NewReplacer("\r\n", "\n", "\r", "\n").Replace(text)
	// Collapse excessive blank lines to a single blank line.
	normalized = regexp.MustCompile(`\n{3,}`).ReplaceAllString(normalized, "\n\n")
	// Collapse horizontal whitespace runs to one space.
	normalized = regexp.MustCompile(`[ \t]+`).ReplaceAllString(normalized, " ")
	// Trim each individual line, then the whole result.
	lines := strings.Split(normalized, "\n")
	for i := range lines {
		lines[i] = strings.TrimSpace(lines[i])
	}
	return strings.TrimSpace(strings.Join(lines, "\n"))
}
func generateSnippet(text string, maxLen int) string {
// Find first paragraph with enough content
paragraphs := strings.Split(text, "\n\n")
for _, p := range paragraphs {
p = strings.TrimSpace(p)
if len(p) >= 50 {
if len(p) > maxLen {
// Find word boundary
p = p[:maxLen]
lastSpace := strings.LastIndex(p, " ")
if lastSpace > maxLen/2 {
p = p[:lastSpace]
}
p += "..."
}
return p
}
}
// Fallback: just truncate
if len(text) > maxLen {
text = text[:maxLen] + "..."
}
return text
}
// detectLanguage guesses the document language ("de" or "en"). Meta tags are
// consulted first (og:locale); otherwise a small stop-word count over the
// lower-cased text decides. Defaults to "de", since the service targets
// German education content.
func detectLanguage(text string, meta map[string]string) string {
	// Meta tags win when present (indexing a nil map is a safe no-op read).
	if locale, ok := meta["og:locale"]; ok {
		switch {
		case strings.HasPrefix(locale, "de"):
			return "de"
		case strings.HasPrefix(locale, "en"):
			return "en"
		}
	}
	// countHits counts how many of the given stop words occur as
	// space-delimited tokens in lower.
	countHits := func(lower string, words []string) int {
		hits := 0
		for _, w := range words {
			if strings.Contains(lower, " "+w+" ") {
				hits++
			}
		}
		return hits
	}
	lower := strings.ToLower(text)
	de := countHits(lower, []string{
		"und", "der", "die", "das", "ist", "für", "mit", "von",
		"werden", "wird", "sind", "auch", "als", "können", "nach",
		"einer", "durch", "sich", "bei", "sein", "noch", "haben",
	})
	en := countHits(lower, []string{
		"the", "and", "for", "are", "but", "not", "you", "all",
		"can", "had", "her", "was", "one", "our", "with", "they",
	})
	switch {
	case de > en && de > 3:
		return "de"
	case en > de && en > 3:
		return "en"
	default:
		return "de" // default to German for education content
	}
}
// UnescapeHTML unescapes HTML entities in s (e.g. "&amp;" -> "&",
// "&#39;" -> "'"). Thin exported wrapper around the standard library's
// html.UnescapeString.
func UnescapeHTML(s string) string {
	return html.UnescapeString(s)
}

View File

@@ -0,0 +1,802 @@
package extractor
import (
"strings"
"testing"
)
// TestExtractHTML_BasicContent verifies title, metadata, heading, and body
// extraction from a minimal well-formed HTML document.
func TestExtractHTML_BasicContent(t *testing.T) {
	html := []byte(`<!DOCTYPE html>
<html>
<head>
<title>Test Page Title</title>
<meta name="description" content="Test description">
<meta property="og:title" content="OG Title">
</head>
<body>
<h1>Main Heading</h1>
<p>This is the first paragraph with some meaningful content.</p>
<p>This is another paragraph that adds more information.</p>
</body>
</html>`)
	content, err := ExtractHTML(html)
	if err != nil {
		t.Fatalf("ExtractHTML failed: %v", err)
	}
	// Check title
	if content.Title != "Test Page Title" {
		t.Errorf("Expected title 'Test Page Title', got %q", content.Title)
	}
	// Check metadata
	if content.MetaData["description"] != "Test description" {
		t.Errorf("Expected description 'Test description', got %q", content.MetaData["description"])
	}
	// Check headings. Fatal (not Error) here: the original used t.Error and
	// then indexed Headings[0], which panics with index-out-of-range when the
	// slice is empty instead of failing cleanly.
	if len(content.Headings) == 0 {
		t.Fatal("Expected at least one heading")
	}
	if content.Headings[0] != "Main Heading" {
		t.Errorf("Expected heading 'Main Heading', got %q", content.Headings[0])
	}
	// Check content text
	if !strings.Contains(content.ContentText, "first paragraph") {
		t.Error("Expected content to contain 'first paragraph'")
	}
}
func TestExtractHTML_TitleFallback(t *testing.T) {
tests := []struct {
name string
html string
expected string
}{
{
name: "Title from title tag",
html: `<html><head><title>Page Title</title></head><body></body></html>`,
expected: "Page Title",
},
{
name: "Title from H1 when no title tag",
html: `<html><head></head><body><h1>H1 Title</h1></body></html>`,
expected: "H1 Title",
},
{
name: "Title from og:title when no title or h1",
html: `<html><head><meta property="og:title" content="OG Title"></head><body></body></html>`,
expected: "OG Title",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
content, err := ExtractHTML([]byte(tt.html))
if err != nil {
t.Fatalf("ExtractHTML failed: %v", err)
}
if content.Title != tt.expected {
t.Errorf("Expected title %q, got %q", tt.expected, content.Title)
}
})
}
}
func TestExtractHTML_RemovesUnwantedElements(t *testing.T) {
html := []byte(`<html>
<body>
<nav>Navigation menu</nav>
<header>Header content</header>
<main>
<p>Main content paragraph</p>
</main>
<script>alert('dangerous');</script>
<style>.hidden{display:none;}</style>
<footer>Footer content</footer>
<aside>Sidebar content</aside>
<div class="advertisement">Ad content</div>
</body>
</html>`)
content, err := ExtractHTML(html)
if err != nil {
t.Fatal(err)
}
// Should contain main content
if !strings.Contains(content.ContentText, "Main content paragraph") {
t.Error("Expected main content to be extracted")
}
// Should not contain unwanted elements
unwanted := []string{"Navigation menu", "alert('dangerous')", "Footer content", "Ad content"}
for _, text := range unwanted {
if strings.Contains(content.ContentText, text) {
t.Errorf("Content should not contain %q", text)
}
}
}
func TestExtractHTML_ExtractsLinks(t *testing.T) {
html := []byte(`<html><body>
<a href="https://example.com/page1">Link 1</a>
<a href="https://example.com/page2">Link 2</a>
<a href="/relative/path">Relative Link</a>
<a href="mailto:test@example.com">Email</a>
</body></html>`)
content, err := ExtractHTML(html)
if err != nil {
t.Fatal(err)
}
// Should extract absolute HTTP links
if len(content.Links) != 2 {
t.Errorf("Expected 2 HTTP links, got %d", len(content.Links))
}
hasPage1 := false
hasPage2 := false
for _, link := range content.Links {
if link == "https://example.com/page1" {
hasPage1 = true
}
if link == "https://example.com/page2" {
hasPage2 = true
}
}
if !hasPage1 || !hasPage2 {
t.Error("Expected to find both HTTP links")
}
}
func TestExtractHTML_CalculatesFeatures(t *testing.T) {
html := []byte(`<html><body>
<div class="advertisement">Ad 1</div>
<p>Some content text that is long enough to be meaningful and provide a good ratio.</p>
<p>More content here to increase the text length.</p>
<a href="#">Link 1</a>
<a href="#">Link 2</a>
</body></html>`)
content, err := ExtractHTML(html)
if err != nil {
t.Fatal(err)
}
// Check features are calculated
if content.Features.TextToHTMLRatio <= 0 {
t.Error("Expected positive TextToHTMLRatio")
}
// Content should have length
if content.ContentLength == 0 {
t.Error("Expected non-zero ContentLength")
}
}
func TestExtractHTML_GeneratesSnippet(t *testing.T) {
html := []byte(`<html><body>
<p>This is a short intro.</p>
<p>This is a longer paragraph that should be used as the snippet because it has more meaningful content and meets the minimum length requirement for a good snippet.</p>
<p>Another paragraph here.</p>
</body></html>`)
content, err := ExtractHTML(html)
if err != nil {
t.Fatal(err)
}
if content.SnippetText == "" {
t.Error("Expected non-empty snippet")
}
// Snippet should be limited in length
if len(content.SnippetText) > 350 { // 300 + "..." margin
t.Errorf("Snippet too long: %d chars", len(content.SnippetText))
}
}
func TestDetectLanguage(t *testing.T) {
tests := []struct {
name string
text string
meta map[string]string
expected string
}{
{
name: "German from meta",
text: "Some text",
meta: map[string]string{"og:locale": "de_DE"},
expected: "de",
},
{
name: "English from meta",
text: "Some text",
meta: map[string]string{"og:locale": "en_US"},
expected: "en",
},
{
name: "German from content",
text: "Dies ist ein Text und der Inhalt wird hier analysiert",
meta: nil,
expected: "de",
},
{
name: "English from content",
text: "This is the content and we are analyzing the text here with all the words they can use for things but not any German",
meta: nil,
expected: "en",
},
{
name: "Default to German for ambiguous",
text: "Hello World",
meta: nil,
expected: "de",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := detectLanguage(tt.text, tt.meta)
if result != tt.expected {
t.Errorf("detectLanguage() = %q, expected %q", result, tt.expected)
}
})
}
}
func TestCleanText(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "Normalize Windows line endings",
input: "Line1\r\nLine2",
expected: "Line1\nLine2",
},
{
name: "Collapse multiple newlines",
input: "Line1\n\n\n\n\nLine2",
expected: "Line1\n\nLine2",
},
{
name: "Collapse multiple spaces",
input: "Word1 Word2",
expected: "Word1 Word2",
},
{
name: "Trim whitespace",
input: " Text with spaces \n More text ",
expected: "Text with spaces\nMore text",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := cleanText(tt.input)
if result != tt.expected {
t.Errorf("cleanText(%q) = %q, expected %q", tt.input, result, tt.expected)
}
})
}
}
func TestGenerateSnippet(t *testing.T) {
tests := []struct {
name string
text string
maxLen int
checkFn func(string) bool
}{
{
name: "Short text unchanged",
text: "Short paragraph.",
maxLen: 300,
checkFn: func(s string) bool {
return s == "Short paragraph."
},
},
{
name: "Long text truncated",
text: strings.Repeat("A long sentence that keeps going. ", 20),
maxLen: 100,
checkFn: func(s string) bool {
return len(s) <= 103 && strings.HasSuffix(s, "...")
},
},
{
name: "First suitable paragraph",
text: "Tiny.\n\nThis is a paragraph with enough content to be used as a snippet because it meets the minimum length.",
maxLen: 300,
checkFn: func(s string) bool {
return strings.HasPrefix(s, "This is a paragraph")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := generateSnippet(tt.text, tt.maxLen)
if !tt.checkFn(result) {
t.Errorf("generateSnippet() = %q, check failed", result)
}
})
}
}
func TestIsPrintableText(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{
name: "Normal text",
input: "Hello World",
expected: true,
},
{
name: "German text",
input: "Übung mit Umlauten",
expected: true,
},
{
name: "Too short",
input: "AB",
expected: false,
},
{
name: "Binary data",
input: "\x00\x01\x02\x03\x04",
expected: false,
},
{
name: "Mixed printable",
input: "Text with some \x00 binary",
expected: true, // >70% printable
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isPrintableText(tt.input)
if result != tt.expected {
t.Errorf("isPrintableText(%q) = %v, expected %v", tt.input, result, tt.expected)
}
})
}
}
func TestExtractHTML_HeadingsExtraction(t *testing.T) {
html := []byte(`<html><body>
<h1>Main Title</h1>
<h2>Section 1</h2>
<p>Content</p>
<h2>Section 2</h2>
<h3>Subsection 2.1</h3>
<p>More content</p>
</body></html>`)
content, err := ExtractHTML(html)
if err != nil {
t.Fatal(err)
}
if len(content.Headings) != 4 {
t.Errorf("Expected 4 headings (h1, h2, h2, h3), got %d", len(content.Headings))
}
expectedHeadings := []string{"Main Title", "Section 1", "Section 2", "Subsection 2.1"}
for i, expected := range expectedHeadings {
if i < len(content.Headings) && content.Headings[i] != expected {
t.Errorf("Heading %d: expected %q, got %q", i, expected, content.Headings[i])
}
}
}
func TestExtractHTML_ContentFromMain(t *testing.T) {
html := []byte(`<html><body>
<div>Outside main</div>
<main>
<article>
<p>Article content that is inside the main element.</p>
</article>
</main>
<div>Also outside</div>
</body></html>`)
content, err := ExtractHTML(html)
if err != nil {
t.Fatal(err)
}
if !strings.Contains(content.ContentText, "Article content") {
t.Error("Expected content from main element")
}
}
func TestExtractHTML_MetadataExtraction(t *testing.T) {
html := []byte(`<html>
<head>
<meta name="author" content="Test Author">
<meta name="keywords" content="education, learning">
<meta property="og:description" content="OG Description">
</head>
<body></body>
</html>`)
content, err := ExtractHTML(html)
if err != nil {
t.Fatal(err)
}
if content.MetaData["author"] != "Test Author" {
t.Errorf("Expected author 'Test Author', got %q", content.MetaData["author"])
}
if content.MetaData["keywords"] != "education, learning" {
t.Errorf("Expected keywords, got %q", content.MetaData["keywords"])
}
if content.MetaData["og:description"] != "OG Description" {
t.Errorf("Expected og:description, got %q", content.MetaData["og:description"])
}
}
func TestUnescapeHTML(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"&amp;", "&"},
{"&lt;script&gt;", "<script>"},
{"&quot;quoted&quot;", "\"quoted\""},
{"&#39;apostrophe&#39;", "'apostrophe'"},
{"No entities", "No entities"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := UnescapeHTML(tt.input)
if result != tt.expected {
t.Errorf("UnescapeHTML(%q) = %q, expected %q", tt.input, result, tt.expected)
}
})
}
}
// TestExtractPDF_BasicText exercises text extraction from PDF-style literal
// strings. The input is not a structurally valid PDF, so this goes through
// the regex fallback path.
func TestExtractPDF_BasicText(t *testing.T) {
	pdfContent := []byte("(Hello World) (This is a test)")
	content, err := ExtractPDF(pdfContent)
	if err != nil {
		t.Fatalf("ExtractPDF failed: %v", err)
	}
	if content == nil {
		t.Fatal("Expected non-nil content")
	}
	// The fallback extractor should pick up the parenthesized text. The
	// original guard here was vacuously false and could never trigger.
	if !strings.Contains(content.ContentText, "Hello World") {
		t.Errorf("Expected extracted text to contain 'Hello World', got %q", content.ContentText)
	}
	// Features should be set
	if content.Language == "" {
		t.Error("Expected language to be set")
	}
}
func TestExtractHTML_AdDensity(t *testing.T) {
html := []byte(`<html><body>
<div class="advertisement">Ad 1</div>
<div class="advertisement">Ad 2</div>
<div class="advertisement">Ad 3</div>
<p>Content</p>
<div>Normal div</div>
</body></html>`)
content, err := ExtractHTML(html)
if err != nil {
t.Fatal(err)
}
// Ad density should be calculated (3 ads / total divs)
if content.Features.AdDensity < 0 {
t.Error("AdDensity should not be negative")
}
}
func TestExtractHTML_HasMainContent(t *testing.T) {
tests := []struct {
name string
html string
expected bool
}{
{
name: "Sufficient content",
html: `<html><body><p>` + strings.Repeat("Content ", 50) + `</p></body></html>`,
expected: true,
},
{
name: "Insufficient content",
html: `<html><body><p>Short</p></body></html>`,
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
content, err := ExtractHTML([]byte(tt.html))
if err != nil {
t.Fatal(err)
}
if content.Features.HasMainContent != tt.expected {
t.Errorf("HasMainContent = %v, expected %v", content.Features.HasMainContent, tt.expected)
}
})
}
}
// ============================================================
// PDF Extraction Tests
// ============================================================
func TestExtractPDF_FallbackForInvalidPDF(t *testing.T) {
// Test with non-PDF content - should fallback gracefully
invalidPDF := []byte("This is not a PDF file (just some text content)")
content, err := ExtractPDF(invalidPDF)
if err != nil {
t.Fatalf("ExtractPDF should not fail completely: %v", err)
}
// Should still return a valid ExtractedContent struct
if content == nil {
t.Fatal("Expected non-nil content")
}
// Should detect fallback method
if content.MetaData["extraction_method"] != "fallback" {
t.Log("PDF fallback extraction was used as expected")
}
}
func TestExtractPDF_MetadataSet(t *testing.T) {
// Simple test content
content, err := ExtractPDF([]byte("(Test content)"))
if err != nil {
t.Fatalf("ExtractPDF failed: %v", err)
}
// Content type should be set
if content.MetaData["content_type"] != "application/pdf" {
t.Errorf("Expected content_type 'application/pdf', got %q", content.MetaData["content_type"])
}
// Language should be detected (default to German)
if content.Language == "" {
t.Error("Expected language to be set")
}
}
func TestExtractPDFTitle(t *testing.T) {
tests := []struct {
name string
text string
expected string
}{
{
name: "Normal title",
text: "Lehrplan Mathematik Bayern\n\nDieses Dokument beschreibt...",
expected: "Lehrplan Mathematik Bayern",
},
{
name: "Skip page number",
text: "1\n\nLehrplan Mathematik Bayern\n\nDieses Dokument...",
expected: "Lehrplan Mathematik Bayern",
},
{
name: "Skip date",
text: "15.01.2025\n\nLehrplan Mathematik\n\nDieses Dokument...",
expected: "Lehrplan Mathematik",
},
{
name: "Skip short lines",
text: "Short\n\nThis is a proper title for the document\n\nContent...",
expected: "This is a proper title for the document",
},
{
name: "Empty text",
text: "",
expected: "",
},
{
name: "Only short lines",
text: "A\nB\nC\nD",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractPDFTitle(tt.text)
if result != tt.expected {
t.Errorf("extractPDFTitle() = %q, expected %q", result, tt.expected)
}
})
}
}
func TestExtractPDFHeadings(t *testing.T) {
tests := []struct {
name string
text string
minHeadingCount int
expectedFirst string
}{
{
name: "All caps headings",
text: `EINLEITUNG
Dieser Text beschreibt die wichtigsten Punkte.
KAPITEL EINS
Hier folgt der erste Abschnitt.`,
minHeadingCount: 2,
expectedFirst: "EINLEITUNG",
},
{
name: "Numbered headings",
text: `1. Einführung
Text hier.
1.1 Unterabschnitt
Mehr Text.
2. Hauptteil
Weiterer Inhalt.`,
minHeadingCount: 3,
expectedFirst: "1. Einführung",
},
{
name: "No headings",
text: "einfacher text ohne ueberschriften der nur aus kleinen buchstaben besteht und sehr lang ist damit er nicht als ueberschrift erkannt wird",
minHeadingCount: 0,
expectedFirst: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
headings := extractPDFHeadings(tt.text)
if len(headings) < tt.minHeadingCount {
t.Errorf("Expected at least %d headings, got %d", tt.minHeadingCount, len(headings))
}
if tt.expectedFirst != "" && len(headings) > 0 && headings[0] != tt.expectedFirst {
t.Errorf("Expected first heading %q, got %q", tt.expectedFirst, headings[0])
}
})
}
}
func TestExtractPDFHeadings_Limit(t *testing.T) {
// Test that headings are limited to 10
text := ""
for i := 1; i <= 20; i++ {
text += "KAPITEL " + strings.Repeat("X", i) + "\n\nText Text Text.\n\n"
}
headings := extractPDFHeadings(text)
if len(headings) > 10 {
t.Errorf("Expected max 10 headings, got %d", len(headings))
}
}
func TestContainsHeading(t *testing.T) {
headings := []string{"Title One", "Title Two", "Title Three"}
if !containsHeading(headings, "Title Two") {
t.Error("Expected to find 'Title Two'")
}
if containsHeading(headings, "Title Four") {
t.Error("Should not find 'Title Four'")
}
if containsHeading([]string{}, "Any") {
t.Error("Empty list should not contain anything")
}
}
// TestExtractPDFFallback_BasicExtraction verifies the regex-based fallback
// extracts parenthesized literal strings and marks the extraction method.
func TestExtractPDFFallback_BasicExtraction(t *testing.T) {
	// Text in parentheses mimics the PDF literal-string stream format.
	pdfLike := []byte("stream\n(Hello World) (This is some text) (More content here)\nendstream")
	content, err := extractPDFFallback(pdfLike)
	if err != nil {
		t.Fatalf("extractPDFFallback failed: %v", err)
	}
	// The fallback must extract the parenthesized text. The original check
	// only t.Log'd under a condition that could not indicate failure.
	if !strings.Contains(content.ContentText, "Hello World") {
		t.Errorf("Expected extracted text to contain 'Hello World', got %q", content.ContentText)
	}
	// Should mark as fallback
	if content.MetaData["extraction_method"] != "fallback" {
		t.Error("Expected extraction_method to be 'fallback'")
	}
}
func TestExtractPDF_EmptyInput(t *testing.T) {
content, err := ExtractPDF([]byte{})
if err != nil {
t.Fatalf("ExtractPDF should handle empty input: %v", err)
}
if content == nil {
t.Fatal("Expected non-nil content for empty input")
}
if content.ContentLength != 0 {
t.Errorf("Expected 0 content length for empty input, got %d", content.ContentLength)
}
}
func TestExtractPDFWithMetadata_FallbackOnError(t *testing.T) {
// ExtractPDFWithMetadata should fallback gracefully
content, err := ExtractPDFWithMetadata([]byte("not a pdf"))
if err != nil {
t.Fatalf("ExtractPDFWithMetadata should not fail: %v", err)
}
if content == nil {
t.Fatal("Expected non-nil content")
}
}
// TestExtractPDF_LanguageDetection checks that language detection runs on
// extracted PDF text and yields the expected value.
func TestExtractPDF_LanguageDetection(t *testing.T) {
	tests := []struct {
		name     string
		text     string
		expected string
	}{
		{
			name:     "German content",
			text:     "(Der Lehrplan ist für alle Schulen verbindlich und enthält wichtige Informationen)",
			expected: "de",
		},
		{
			name:     "Default to German",
			text:     "(Some text)",
			expected: "de",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			content, err := ExtractPDF([]byte(tt.text))
			if err != nil {
				t.Fatalf("ExtractPDF failed: %v", err)
			}
			// Enforce the expectation. The original only t.Logf'd on a
			// mismatch, so this test could never fail on a wrong language.
			if content.Language != tt.expected {
				t.Errorf("Language = %q, expected %q", content.Language, tt.expected)
			}
		})
	}
}

View File

@@ -0,0 +1,243 @@
package indexer
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"time"

	"github.com/opensearch-project/opensearch-go/v2"
	"github.com/opensearch-project/opensearch-go/v2/opensearchapi"
)
// IndexMapping defines the OpenSearch index mapping for education documents
const IndexMapping = `{
"settings": {
"index": {
"number_of_shards": 3,
"number_of_replicas": 1,
"refresh_interval": "5s"
},
"analysis": {
"analyzer": {
"german_custom": {
"type": "custom",
"tokenizer": "standard",
"filter": ["lowercase", "german_normalization", "german_stemmer"]
}
},
"filter": {
"german_stemmer": {
"type": "stemmer",
"language": "german"
}
}
}
},
"mappings": {
"properties": {
"doc_id": { "type": "keyword" },
"url": { "type": "keyword" },
"canonical_url": { "type": "keyword" },
"domain": { "type": "keyword" },
"fetch_time": { "type": "date" },
"last_modified": { "type": "date" },
"content_hash": { "type": "keyword" },
"title": {
"type": "text",
"analyzer": "german_custom",
"fields": {
"keyword": { "type": "keyword", "ignore_above": 512 }
}
},
"content_text": {
"type": "text",
"analyzer": "german_custom"
},
"snippet_text": { "type": "text", "index": false },
"content_type": { "type": "keyword" },
"language": { "type": "keyword" },
"country_hint": { "type": "keyword" },
"source_category": { "type": "keyword" },
"doc_type": { "type": "keyword" },
"school_level": { "type": "keyword" },
"subjects": { "type": "keyword" },
"state": { "type": "keyword" },
"trust_score": { "type": "float" },
"quality_score": { "type": "float" },
"spam_flags": { "type": "keyword" },
"outlinks": { "type": "keyword" },
"inlinks_count": { "type": "integer" },
"content_length": { "type": "integer" },
"raw_refs": {
"properties": {
"html_raw_ref": { "type": "keyword" },
"pdf_raw_ref": { "type": "keyword" }
}
},
"tag_reasons": { "type": "keyword" }
}
}
}`
// Document represents an indexed education document. The JSON tags map each
// field onto the OpenSearch index mapping defined in IndexMapping.
type Document struct {
	DocID        string `json:"doc_id"` // stable identifier; also used as the OpenSearch document _id in IndexDocument
	URL          string `json:"url"`
	CanonicalURL string `json:"canonical_url,omitempty"`
	Domain       string `json:"domain"`
	// NOTE(review): Go names diverge from the JSON/mapping names below —
	// FetchedAt serializes as "fetch_time" and UpdatedAt as "last_modified".
	// NOTE(review): omitempty has no effect on a non-pointer time.Time (the
	// zero time is still emitted); use *time.Time if omission matters.
	FetchedAt      time.Time `json:"fetch_time"`
	UpdatedAt      time.Time `json:"last_modified,omitempty"`
	ContentHash    string    `json:"content_hash"`
	Title          string    `json:"title"`
	ContentText    string    `json:"content_text"`
	SnippetText    string    `json:"snippet_text"` // stored but not indexed (see mapping: "index": false)
	ContentType    string    `json:"content_type,omitempty"`
	Language       string    `json:"language"`
	CountryHint    string    `json:"country_hint,omitempty"`
	SourceCategory string    `json:"source_category,omitempty"`
	DocType        string    `json:"doc_type"`
	SchoolLevel    string    `json:"school_level"`
	Subjects       []string  `json:"subjects"`
	State          string    `json:"state,omitempty"`
	TrustScore     float64   `json:"trust_score"`
	QualityScore   float64   `json:"quality_score"`
	SpamFlags      []string  `json:"spam_flags,omitempty"`
	Outlinks       []string  `json:"outlinks,omitempty"`
	InlinksCount   int       `json:"inlinks_count,omitempty"`
	ContentLength  int       `json:"content_length,omitempty"`
	TagReasons     []string  `json:"tag_reasons,omitempty"`
}
// Client wraps OpenSearch operations against a single named index.
type Client struct {
	client    *opensearch.Client // underlying OpenSearch HTTP client
	indexName string             // target index used by all operations
}
// NewClient creates a new OpenSearch indexer client bound to the given
// index name. url is the cluster address; username and password are passed
// through to the OpenSearch client configuration.
func NewClient(url, username, password, indexName string) (*Client, error) {
	osClient, err := opensearch.NewClient(opensearch.Config{
		Addresses: []string{url},
		Username:  username,
		Password:  password,
	})
	if err != nil {
		return nil, err
	}
	return &Client{client: osClient, indexName: indexName}, nil
}
// CreateIndex creates the index with the education-document mapping. It is a
// no-op when the index already exists.
func (c *Client) CreateIndex(ctx context.Context) error {
	// Check if the index already exists.
	res, err := c.client.Indices.Exists([]string{c.indexName})
	if err != nil {
		return err
	}
	defer res.Body.Close()
	if res.StatusCode == 200 {
		// Index already exists.
		return nil
	}
	// Create index with mapping.
	req := opensearchapi.IndicesCreateRequest{
		Index: c.indexName,
		Body:  strings.NewReader(IndexMapping),
	}
	createRes, err := req.Do(ctx, c.client)
	if err != nil {
		return err
	}
	defer createRes.Body.Close()
	// The original ignored error responses, so a rejected mapping or a
	// permission problem would silently report success.
	if createRes.IsError() {
		return fmt.Errorf("creating index %q: %s", c.indexName, createRes.String())
	}
	return nil
}
// IndexDocument indexes a single document under its DocID without forcing a
// refresh (Refresh: "false").
func (c *Client) IndexDocument(ctx context.Context, doc *Document) error {
	body, err := json.Marshal(doc)
	if err != nil {
		return err
	}
	req := opensearchapi.IndexRequest{
		Index:      c.indexName,
		DocumentID: doc.DocID,
		// bytes.NewReader avoids the []byte -> string copy the original made.
		Body:    bytes.NewReader(body),
		Refresh: "false",
	}
	res, err := req.Do(ctx, c.client)
	if err != nil {
		return err
	}
	defer res.Body.Close()
	// Surface error responses; the original silently discarded them.
	if res.IsError() {
		return fmt.Errorf("indexing document %q: %s", doc.DocID, res.String())
	}
	return nil
}
// BulkIndex indexes multiple documents in a single bulk request using the
// newline-delimited action/document format. A nil or empty slice is a no-op.
func (c *Client) BulkIndex(ctx context.Context, docs []Document) error {
	if len(docs) == 0 {
		return nil
	}
	var builder strings.Builder
	for i := range docs {
		doc := &docs[i]
		// Action line.
		meta := map[string]interface{}{
			"index": map[string]interface{}{
				"_index": c.indexName,
				"_id":    doc.DocID,
			},
		}
		// The original discarded both marshal errors with "_", which could
		// silently corrupt the NDJSON payload.
		metaBytes, err := json.Marshal(meta)
		if err != nil {
			return fmt.Errorf("marshaling bulk action for %q: %w", doc.DocID, err)
		}
		docBytes, err := json.Marshal(doc)
		if err != nil {
			return fmt.Errorf("marshaling document %q: %w", doc.DocID, err)
		}
		builder.Write(metaBytes)
		builder.WriteString("\n")
		builder.Write(docBytes)
		builder.WriteString("\n")
	}
	req := opensearchapi.BulkRequest{
		Body: strings.NewReader(builder.String()),
	}
	res, err := req.Do(ctx, c.client)
	if err != nil {
		return err
	}
	defer res.Body.Close()
	// Reject request-level failures (the original ignored them entirely).
	// NOTE(review): a 200 bulk response can still contain per-item failures;
	// parse the response body if per-document error reporting is needed.
	if res.IsError() {
		return fmt.Errorf("bulk indexing %d documents: %s", len(docs), res.String())
	}
	return nil
}
// Health returns the OpenSearch cluster health status string (e.g. "green",
// "yellow", "red"); an empty string is returned if the response carries no
// string "status" field.
func (c *Client) Health(ctx context.Context) (string, error) {
	res, err := c.client.Cluster.Health()
	if err != nil {
		return "", err
	}
	defer res.Body.Close()
	// Decode only the field we need; unknown fields are ignored, and a
	// missing "status" yields the zero value "" — same as the original's
	// map-based type assertion.
	var payload struct {
		Status string `json:"status"`
	}
	if err := json.NewDecoder(res.Body).Decode(&payload); err != nil {
		return "", err
	}
	return payload.Status, nil
}

View File

@@ -0,0 +1,424 @@
// Package orchestrator implements multi-phase university crawling with queue management
package orchestrator
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
)
// Audience represents a target audience filter configuration
type Audience struct {
ID uuid.UUID `json:"id"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
Filters AudienceFilters `json:"filters"`
MemberCount int `json:"member_count"`
LastCountUpdate *time.Time `json:"last_count_update,omitempty"`
CreatedBy string `json:"created_by,omitempty"`
IsActive bool `json:"is_active"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// AudienceFilters defines the filter criteria for an audience
type AudienceFilters struct {
PositionTypes []string `json:"position_types,omitempty"` // professor, researcher, lecturer
SubjectAreas []uuid.UUID `json:"subject_areas,omitempty"` // Subject area UUIDs
States []string `json:"states,omitempty"` // BW, BY, etc.
UniTypes []string `json:"uni_types,omitempty"` // UNI, PH, HAW
Universities []uuid.UUID `json:"universities,omitempty"` // University UUIDs
HasEmail *bool `json:"has_email,omitempty"`
IsActive *bool `json:"is_active,omitempty"`
Keywords []string `json:"keywords,omitempty"` // Keywords in name/research
}
// AudienceExport tracks exports of audience data
type AudienceExport struct {
ID uuid.UUID `json:"id"`
AudienceID uuid.UUID `json:"audience_id"`
ExportType string `json:"export_type"` // csv, json, email_list
RecordCount int `json:"record_count"`
FilePath string `json:"file_path,omitempty"`
ExportedBy string `json:"exported_by,omitempty"`
Purpose string `json:"purpose,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// AudienceMember represents a staff member in an audience preview
type AudienceMember struct {
ID uuid.UUID `json:"id"`
Name string `json:"name"`
Email string `json:"email,omitempty"`
Position string `json:"position,omitempty"`
University string `json:"university"`
Department string `json:"department,omitempty"`
SubjectArea string `json:"subject_area,omitempty"`
PublicationCount int `json:"publication_count"`
}
// AudienceRepository extends Repository with audience operations
type AudienceRepository interface {
// Audience CRUD
CreateAudience(ctx context.Context, audience *Audience) error
GetAudience(ctx context.Context, id uuid.UUID) (*Audience, error)
ListAudiences(ctx context.Context, activeOnly bool) ([]Audience, error)
UpdateAudience(ctx context.Context, audience *Audience) error
DeleteAudience(ctx context.Context, id uuid.UUID) error
// Audience members
GetAudienceMembers(ctx context.Context, id uuid.UUID, limit, offset int) ([]AudienceMember, int, error)
UpdateAudienceCount(ctx context.Context, id uuid.UUID) (int, error)
// Exports
CreateExport(ctx context.Context, export *AudienceExport) error
ListExports(ctx context.Context, audienceID uuid.UUID) ([]AudienceExport, error)
}
// ============================================================================
// POSTGRES IMPLEMENTATION
// ============================================================================
// CreateAudience inserts a new audience row and populates the database-
// generated fields (ID, MemberCount, CreatedAt, UpdatedAt) on the passed
// struct. The filter configuration is stored as JSON.
func (r *PostgresRepository) CreateAudience(ctx context.Context, audience *Audience) error {
	filtersJSON, err := json.Marshal(audience.Filters)
	if err != nil {
		return fmt.Errorf("failed to marshal filters: %w", err)
	}
	const query = `
INSERT INTO audiences (name, description, filters, created_by, is_active)
VALUES ($1, $2, $3, $4, $5)
RETURNING id, member_count, created_at, updated_at
`
	row := r.pool.QueryRow(ctx, query,
		audience.Name, audience.Description, filtersJSON,
		audience.CreatedBy, audience.IsActive)
	return row.Scan(&audience.ID, &audience.MemberCount, &audience.CreatedAt, &audience.UpdatedAt)
}
// GetAudience retrieves a single audience by its ID, including its decoded
// filter configuration. Returns the driver's not-found error when no row
// matches.
func (r *PostgresRepository) GetAudience(ctx context.Context, id uuid.UUID) (*Audience, error) {
	const query = `
SELECT id, name, description, filters, member_count, last_count_update,
created_by, is_active, created_at, updated_at
FROM audiences
WHERE id = $1
`
	var (
		a          Audience
		rawFilters []byte
	)
	if err := r.pool.QueryRow(ctx, query, id).Scan(
		&a.ID, &a.Name, &a.Description, &rawFilters,
		&a.MemberCount, &a.LastCountUpdate,
		&a.CreatedBy, &a.IsActive,
		&a.CreatedAt, &a.UpdatedAt,
	); err != nil {
		return nil, err
	}
	// Filters are stored as JSON; decode into the typed struct.
	if err := json.Unmarshal(rawFilters, &a.Filters); err != nil {
		return nil, fmt.Errorf("failed to unmarshal filters: %w", err)
	}
	return &a, nil
}
// ListAudiences returns all audiences, newest first. When activeOnly is
// set, soft-deleted rows (is_active = FALSE) are excluded.
func (r *PostgresRepository) ListAudiences(ctx context.Context, activeOnly bool) ([]Audience, error) {
	sql := `
	SELECT id, name, description, filters, member_count, last_count_update,
	       created_by, is_active, created_at, updated_at
	FROM audiences
	`
	if activeOnly {
		sql += ` WHERE is_active = TRUE`
	}
	sql += ` ORDER BY created_at DESC`
	rows, err := r.pool.Query(ctx, sql)
	if err != nil {
		return nil, fmt.Errorf("failed to query audiences: %w", err)
	}
	defer rows.Close()
	var result []Audience
	for rows.Next() {
		var (
			a          Audience
			rawFilters []byte
		)
		scanErr := rows.Scan(
			&a.ID, &a.Name, &a.Description, &rawFilters,
			&a.MemberCount, &a.LastCountUpdate,
			&a.CreatedBy, &a.IsActive,
			&a.CreatedAt, &a.UpdatedAt,
		)
		if scanErr != nil {
			return nil, fmt.Errorf("failed to scan audience: %w", scanErr)
		}
		if err := json.Unmarshal(rawFilters, &a.Filters); err != nil {
			return nil, fmt.Errorf("failed to unmarshal filters: %w", err)
		}
		result = append(result, a)
	}
	return result, rows.Err()
}
// UpdateAudience persists the mutable fields (name, description, filters,
// active flag) of an existing audience and refreshes UpdatedAt from the DB.
func (r *PostgresRepository) UpdateAudience(ctx context.Context, audience *Audience) error {
	encodedFilters, err := json.Marshal(audience.Filters)
	if err != nil {
		return fmt.Errorf("failed to marshal filters: %w", err)
	}
	const query = `
	UPDATE audiences
	SET name = $2, description = $3, filters = $4, is_active = $5, updated_at = NOW()
	WHERE id = $1
	RETURNING updated_at
	`
	row := r.pool.QueryRow(ctx, query,
		audience.ID,
		audience.Name,
		audience.Description,
		encodedFilters,
		audience.IsActive,
	)
	return row.Scan(&audience.UpdatedAt)
}
// DeleteAudience soft-deletes an audience by flipping is_active to FALSE;
// the row itself is kept for history.
func (r *PostgresRepository) DeleteAudience(ctx context.Context, id uuid.UUID) error {
	const query = `UPDATE audiences SET is_active = FALSE, updated_at = NOW() WHERE id = $1`
	if _, err := r.pool.Exec(ctx, query, id); err != nil {
		return err
	}
	return nil
}
// GetAudienceMembers retrieves members matching the audience filters:
// a page of up to `limit` rows starting at `offset`, plus the total number
// of matching rows.
func (r *PostgresRepository) GetAudienceMembers(ctx context.Context, id uuid.UUID, limit, offset int) ([]AudienceMember, int, error) {
	// First get the audience filters
	audience, err := r.GetAudience(ctx, id)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to get audience: %w", err)
	}
	// Build dynamic query based on filters. The same builder produces both
	// the paged SELECT and the COUNT(*) variant so the two stay in sync.
	query, args := r.buildAudienceMemberQuery(audience.Filters, limit, offset, false)
	countQuery, countArgs := r.buildAudienceMemberQuery(audience.Filters, 0, 0, true)
	// Get total count
	var totalCount int
	if err := r.pool.QueryRow(ctx, countQuery, countArgs...).Scan(&totalCount); err != nil {
		return nil, 0, fmt.Errorf("failed to count members: %w", err)
	}
	// Get members
	rows, err := r.pool.Query(ctx, query, args...)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to query members: %w", err)
	}
	defer rows.Close()
	var members []AudienceMember
	for rows.Next() {
		var m AudienceMember
		// Column order must match the SELECT list in buildAudienceMemberQuery.
		if err := rows.Scan(
			&m.ID, &m.Name, &m.Email, &m.Position,
			&m.University, &m.Department, &m.SubjectArea, &m.PublicationCount,
		); err != nil {
			return nil, 0, fmt.Errorf("failed to scan member: %w", err)
		}
		members = append(members, m)
	}
	return members, totalCount, rows.Err()
}
// buildAudienceMemberQuery constructs a SQL query for audience members.
// It returns the query text and its positional arguments. When countOnly is
// true the SELECT list is replaced by COUNT(*) and LIMIT/OFFSET are omitted,
// so the same filter logic backs both the page query and the total count.
//
// NOTE: argNum tracks the next $n placeholder; the append order of args must
// exactly mirror the order placeholders are emitted into the query string.
func (r *PostgresRepository) buildAudienceMemberQuery(filters AudienceFilters, limit, offset int, countOnly bool) (string, []interface{}) {
	var args []interface{}
	argNum := 1
	var selectClause string
	if countOnly {
		selectClause = "SELECT COUNT(*)"
	} else {
		selectClause = `
		SELECT
			s.id,
			COALESCE(s.title || ' ', '') || s.first_name || ' ' || s.last_name as name,
			COALESCE(s.email, '') as email,
			COALESCE(s.position_type, '') as position,
			u.name as university,
			COALESCE(d.name, '') as department,
			COALESCE(sa.name, '') as subject_area,
			(SELECT COUNT(*) FROM staff_publications sp WHERE sp.staff_id = s.id) as publication_count
		`
	}
	// WHERE 1=1 lets every filter below append " AND ..." unconditionally.
	query := selectClause + `
	FROM university_staff s
	JOIN universities u ON s.university_id = u.id
	LEFT JOIN departments d ON s.department_id = d.id
	LEFT JOIN subject_areas sa ON s.subject_area_id = sa.id
	WHERE 1=1
	`
	// Position types filter
	if len(filters.PositionTypes) > 0 {
		query += fmt.Sprintf(" AND s.position_type = ANY($%d)", argNum)
		args = append(args, filters.PositionTypes)
		argNum++
	}
	// Subject areas filter
	if len(filters.SubjectAreas) > 0 {
		query += fmt.Sprintf(" AND s.subject_area_id = ANY($%d)", argNum)
		args = append(args, filters.SubjectAreas)
		argNum++
	}
	// States filter
	if len(filters.States) > 0 {
		query += fmt.Sprintf(" AND u.state = ANY($%d)", argNum)
		args = append(args, filters.States)
		argNum++
	}
	// Uni types filter
	if len(filters.UniTypes) > 0 {
		query += fmt.Sprintf(" AND u.uni_type = ANY($%d)", argNum)
		args = append(args, filters.UniTypes)
		argNum++
	}
	// Universities filter
	if len(filters.Universities) > 0 {
		query += fmt.Sprintf(" AND s.university_id = ANY($%d)", argNum)
		args = append(args, filters.Universities)
		argNum++
	}
	// Has email filter — only applied when the pointer is non-nil AND true;
	// a false value deliberately applies no constraint.
	if filters.HasEmail != nil && *filters.HasEmail {
		query += " AND s.email IS NOT NULL AND s.email != ''"
	}
	// Is active filter — same semantics: false means "no constraint".
	if filters.IsActive != nil && *filters.IsActive {
		query += " AND s.is_active = TRUE"
	}
	// Keywords filter (search in name and research_areas). Each keyword adds
	// one arg reused in three ILIKE positions via the same $n placeholder.
	if len(filters.Keywords) > 0 {
		for _, keyword := range filters.Keywords {
			query += fmt.Sprintf(" AND (s.first_name ILIKE $%d OR s.last_name ILIKE $%d OR s.research_areas ILIKE $%d)", argNum, argNum, argNum)
			args = append(args, "%"+keyword+"%")
			argNum++
		}
	}
	if !countOnly {
		query += " ORDER BY s.last_name, s.first_name"
		if limit > 0 {
			query += fmt.Sprintf(" LIMIT $%d", argNum)
			args = append(args, limit)
			argNum++
		}
		// offset == 0 emits no OFFSET clause (same result as OFFSET 0).
		if offset > 0 {
			query += fmt.Sprintf(" OFFSET $%d", argNum)
			args = append(args, offset)
		}
	}
	return query, args
}
// UpdateAudienceCount recomputes the member count for an audience, stores
// it on the audiences row, and returns the fresh value.
func (r *PostgresRepository) UpdateAudienceCount(ctx context.Context, id uuid.UUID) (int, error) {
	// Load the filters that define this audience.
	audience, err := r.GetAudience(ctx, id)
	if err != nil {
		return 0, fmt.Errorf("failed to get audience: %w", err)
	}
	// Run the COUNT(*) variant of the member query.
	countQuery, countArgs := r.buildAudienceMemberQuery(audience.Filters, 0, 0, true)
	var total int
	if err := r.pool.QueryRow(ctx, countQuery, countArgs...).Scan(&total); err != nil {
		return 0, fmt.Errorf("failed to count members: %w", err)
	}
	// Persist the cached count.
	const updateQuery = `
	UPDATE audiences
	SET member_count = $2, last_count_update = NOW(), updated_at = NOW()
	WHERE id = $1
	`
	if _, err := r.pool.Exec(ctx, updateQuery, id, total); err != nil {
		return 0, fmt.Errorf("failed to update count: %w", err)
	}
	return total, nil
}
// CreateExport inserts an export record and fills in its generated ID and
// creation timestamp.
func (r *PostgresRepository) CreateExport(ctx context.Context, export *AudienceExport) error {
	const query = `
	INSERT INTO audience_exports (audience_id, export_type, record_count, file_path, exported_by, purpose)
	VALUES ($1, $2, $3, $4, $5, $6)
	RETURNING id, created_at
	`
	row := r.pool.QueryRow(ctx, query,
		export.AudienceID,
		export.ExportType,
		export.RecordCount,
		export.FilePath,
		export.ExportedBy,
		export.Purpose,
	)
	return row.Scan(&export.ID, &export.CreatedAt)
}
// ListExports returns all export records for an audience, newest first.
func (r *PostgresRepository) ListExports(ctx context.Context, audienceID uuid.UUID) ([]AudienceExport, error) {
	const query = `
	SELECT id, audience_id, export_type, record_count, file_path, exported_by, purpose, created_at
	FROM audience_exports
	WHERE audience_id = $1
	ORDER BY created_at DESC
	`
	rows, err := r.pool.Query(ctx, query, audienceID)
	if err != nil {
		return nil, fmt.Errorf("failed to query exports: %w", err)
	}
	defer rows.Close()
	var result []AudienceExport
	for rows.Next() {
		var exp AudienceExport
		scanErr := rows.Scan(
			&exp.ID, &exp.AudienceID, &exp.ExportType, &exp.RecordCount,
			&exp.FilePath, &exp.ExportedBy, &exp.Purpose, &exp.CreatedAt,
		)
		if scanErr != nil {
			return nil, fmt.Errorf("failed to scan export: %w", scanErr)
		}
		result = append(result, exp)
	}
	return result, rows.Err()
}

View File

@@ -0,0 +1,407 @@
// Package orchestrator implements multi-phase university crawling with queue management
package orchestrator
import (
"context"
"fmt"
"log"
"sync"
"time"
"github.com/google/uuid"
)
// CrawlPhase represents a phase in the crawl process. Phases advance in the
// order discovery → professors → all_staff → publications; the remaining
// values are terminal or control states.
type CrawlPhase string

const (
	PhasePending      CrawlPhase = "pending"
	PhaseDiscovery    CrawlPhase = "discovery"    // Find sample professor to validate crawling works
	PhaseProfessors   CrawlPhase = "professors"   // Crawl all professors
	PhaseAllStaff     CrawlPhase = "all_staff"    // Crawl all staff members
	PhasePublications CrawlPhase = "publications" // Crawl publications for all staff
	PhaseCompleted    CrawlPhase = "completed"
	PhaseFailed       CrawlPhase = "failed"
	PhasePaused       CrawlPhase = "paused"
)
// CrawlQueueItem represents a university in the crawl queue, mirroring a
// crawl_queue row joined with the university's display names.
type CrawlQueueItem struct {
	ID              uuid.UUID `json:"id"`
	UniversityID    uuid.UUID `json:"university_id"`
	UniversityName  string    `json:"university_name"`
	UniversityShort string    `json:"university_short"`
	// QueuePosition is nil once the item leaves the active queue.
	QueuePosition *int       `json:"queue_position"`
	Priority      int        `json:"priority"`
	CurrentPhase  CrawlPhase `json:"current_phase"`
	// Per-phase completion flags and timestamps.
	DiscoveryCompleted      bool       `json:"discovery_completed"`
	DiscoveryCompletedAt    *time.Time `json:"discovery_completed_at,omitempty"`
	ProfessorsCompleted     bool       `json:"professors_completed"`
	ProfessorsCompletedAt   *time.Time `json:"professors_completed_at,omitempty"`
	AllStaffCompleted       bool       `json:"all_staff_completed"`
	AllStaffCompletedAt     *time.Time `json:"all_staff_completed_at,omitempty"`
	PublicationsCompleted   bool       `json:"publications_completed"`
	PublicationsCompletedAt *time.Time `json:"publications_completed_at,omitempty"`
	// Per-phase item counts.
	DiscoveryCount    int `json:"discovery_count"`
	ProfessorsCount   int `json:"professors_count"`
	StaffCount        int `json:"staff_count"`
	PublicationsCount int `json:"publications_count"`
	// Failure bookkeeping.
	RetryCount int    `json:"retry_count"`
	MaxRetries int    `json:"max_retries"`
	LastError  string `json:"last_error,omitempty"`
	StartedAt   *time.Time `json:"started_at,omitempty"`
	CompletedAt *time.Time `json:"completed_at,omitempty"`
	// ProgressPercent is derived from CurrentPhase (see GetQueueItems).
	ProgressPercent int       `json:"progress_percent"`
	CreatedAt       time.Time `json:"created_at"`
	UpdatedAt       time.Time `json:"updated_at"`
}
// CrawlProgress represents progress for a single phase as reported by a
// crawler implementation.
type CrawlProgress struct {
	Phase          CrawlPhase `json:"phase"`
	ItemsFound     int        `json:"items_found"`
	ItemsProcessed int        `json:"items_processed"`
	Errors         []string   `json:"errors,omitempty"`
	StartedAt      time.Time  `json:"started_at"`
	CompletedAt    *time.Time `json:"completed_at,omitempty"`
}
// OrchestratorStatus represents the current state of the orchestrator,
// combining in-memory loop state with queue statistics from the repository.
type OrchestratorStatus struct {
	IsRunning         bool            `json:"is_running"`
	CurrentUniversity *CrawlQueueItem `json:"current_university,omitempty"`
	CurrentPhase      CrawlPhase      `json:"current_phase"`
	QueueLength       int             `json:"queue_length"`
	CompletedToday    int             `json:"completed_today"`
	TotalProcessed    int             `json:"total_processed"`
	LastActivity      *time.Time      `json:"last_activity,omitempty"`
}
// StaffCrawlerInterface defines what the staff crawler must implement.
// Each method runs one crawl phase for a single university and reports
// the resulting progress.
type StaffCrawlerInterface interface {
	// DiscoverSampleProfessor finds at least one professor to validate crawling works
	DiscoverSampleProfessor(ctx context.Context, universityID uuid.UUID) (*CrawlProgress, error)
	// CrawlProfessors crawls all professors at a university
	CrawlProfessors(ctx context.Context, universityID uuid.UUID) (*CrawlProgress, error)
	// CrawlAllStaff crawls all staff members at a university
	CrawlAllStaff(ctx context.Context, universityID uuid.UUID) (*CrawlProgress, error)
}
// PublicationCrawlerInterface defines what the publication crawler must
// implement for the final crawl phase.
type PublicationCrawlerInterface interface {
	// CrawlPublicationsForUniversity crawls publications for all staff at a university
	CrawlPublicationsForUniversity(ctx context.Context, universityID uuid.UUID) (*CrawlProgress, error)
}
// Repository defines database operations for the orchestrator: queue
// management, per-phase completion/failure bookkeeping, and statistics.
type Repository interface {
	// Queue operations
	GetQueueItems(ctx context.Context) ([]CrawlQueueItem, error)
	// GetNextInQueue returns the next processable item, or nil when the
	// queue is empty.
	GetNextInQueue(ctx context.Context) (*CrawlQueueItem, error)
	AddToQueue(ctx context.Context, universityID uuid.UUID, priority int, initiatedBy string) (*CrawlQueueItem, error)
	RemoveFromQueue(ctx context.Context, universityID uuid.UUID) error
	UpdateQueueItem(ctx context.Context, item *CrawlQueueItem) error
	PauseQueueItem(ctx context.Context, universityID uuid.UUID) error
	ResumeQueueItem(ctx context.Context, universityID uuid.UUID) error
	// Phase updates
	CompletePhase(ctx context.Context, universityID uuid.UUID, phase CrawlPhase, count int) error
	FailPhase(ctx context.Context, universityID uuid.UUID, phase CrawlPhase, err string) error
	// Stats
	GetCompletedTodayCount(ctx context.Context) (int, error)
	GetTotalProcessedCount(ctx context.Context) (int, error)
}
// Orchestrator manages the multi-phase crawl process: it drains the queue
// sequentially, advancing each university through its crawl phases.
type Orchestrator struct {
	repo         Repository
	staffCrawler StaffCrawlerInterface
	pubCrawler   PublicationCrawlerInterface
	// Runtime state — mu guards isRunning, stopChan, currentItem, lastActivity.
	mu           sync.RWMutex
	isRunning    bool
	stopChan     chan struct{}
	currentItem  *CrawlQueueItem
	lastActivity time.Time
	// Configuration
	phaseCooldown time.Duration // Wait time between phases
	retryCooldown time.Duration // Wait time after failure before retry
	maxConcurrent int           // Max concurrent crawls (always 1 for now)
}
// NewOrchestrator wires up an orchestrator with its default cooldowns and
// strictly sequential processing. Call Start to begin the loop.
func NewOrchestrator(repo Repository, staffCrawler StaffCrawlerInterface, pubCrawler PublicationCrawlerInterface) *Orchestrator {
	o := &Orchestrator{
		repo:         repo,
		staffCrawler: staffCrawler,
		pubCrawler:   pubCrawler,
	}
	o.phaseCooldown = 5 * time.Second  // small pause between phases
	o.retryCooldown = 30 * time.Second // wait before retry after failure
	o.maxConcurrent = 1                // sequential processing
	return o
}
// Start launches the orchestration loop in a background goroutine.
// It fails if the orchestrator is already running.
func (o *Orchestrator) Start() error {
	o.mu.Lock()
	alreadyRunning := o.isRunning
	if !alreadyRunning {
		o.isRunning = true
		o.stopChan = make(chan struct{})
	}
	o.mu.Unlock()
	if alreadyRunning {
		return fmt.Errorf("orchestrator already running")
	}
	log.Println("[Orchestrator] Starting crawl orchestration loop")
	go o.runLoop()
	return nil
}
// Stop signals the orchestration loop to exit. It fails if the
// orchestrator is not currently running.
func (o *Orchestrator) Stop() error {
	o.mu.Lock()
	wasRunning := o.isRunning
	if wasRunning {
		close(o.stopChan)
		o.isRunning = false
	}
	o.mu.Unlock()
	if !wasRunning {
		return fmt.Errorf("orchestrator not running")
	}
	log.Println("[Orchestrator] Stopped")
	return nil
}
// Status returns a snapshot of the orchestrator's runtime state combined
// with queue statistics from the repository.
//
// Fix: the previous version returned &o.lastActivity and the live
// o.currentItem pointer, aliasing mutex-guarded state that the crawl loop
// mutates after the lock is released (a data race). Both are now copied
// under the lock, and the lock is released before the repository calls so
// slow DB queries cannot block Start/Stop or the loop.
func (o *Orchestrator) Status(ctx context.Context) (*OrchestratorStatus, error) {
	o.mu.RLock()
	status := &OrchestratorStatus{
		IsRunning:    o.isRunning,
		CurrentPhase: PhasePending,
	}
	if o.currentItem != nil {
		// Deep-copy so callers never observe concurrent mutations.
		itemCopy := *o.currentItem
		status.CurrentUniversity = &itemCopy
		status.CurrentPhase = itemCopy.CurrentPhase
	}
	if !o.lastActivity.IsZero() {
		la := o.lastActivity
		status.LastActivity = &la
	}
	o.mu.RUnlock()
	// Queue stats come from the DB; as before, errors here are ignored so a
	// transient DB problem does not hide the in-memory state.
	if items, err := o.repo.GetQueueItems(ctx); err == nil {
		status.QueueLength = len(items)
	}
	completedToday, _ := o.repo.GetCompletedTodayCount(ctx)
	status.CompletedToday = completedToday
	totalProcessed, _ := o.repo.GetTotalProcessedCount(ctx)
	status.TotalProcessed = totalProcessed
	return status, nil
}
// AddUniversity enqueues a university for crawling with the given priority
// and returns the resulting queue item.
func (o *Orchestrator) AddUniversity(ctx context.Context, universityID uuid.UUID, priority int, initiatedBy string) (*CrawlQueueItem, error) {
	queued, err := o.repo.AddToQueue(ctx, universityID, priority, initiatedBy)
	if err != nil {
		return nil, fmt.Errorf("failed to add to queue: %w", err)
	}
	log.Printf("[Orchestrator] Added university %s to queue with priority %d", universityID, priority)
	return queued, nil
}
// RemoveUniversity removes a university from the queue entirely
// (delegates to the repository).
func (o *Orchestrator) RemoveUniversity(ctx context.Context, universityID uuid.UUID) error {
	return o.repo.RemoveFromQueue(ctx, universityID)
}
// PauseUniversity pauses crawling for a university
// (delegates to the repository; the loop skips paused items).
func (o *Orchestrator) PauseUniversity(ctx context.Context, universityID uuid.UUID) error {
	return o.repo.PauseQueueItem(ctx, universityID)
}
// ResumeUniversity resumes crawling for a paused university
// (delegates to the repository, which picks the phase to resume from).
func (o *Orchestrator) ResumeUniversity(ctx context.Context, universityID uuid.UUID) error {
	return o.repo.ResumeQueueItem(ctx, universityID)
}
// GetQueue returns all items in the queue (delegates to the repository).
func (o *Orchestrator) GetQueue(ctx context.Context) ([]CrawlQueueItem, error) {
	return o.repo.GetQueueItems(ctx)
}
// runLoop is the main orchestration loop. It polls the queue on a fixed
// ticker and exits when stopChan is closed (see Stop).
func (o *Orchestrator) runLoop() {
	ticker := time.NewTicker(10 * time.Second) // Check queue every 10 seconds
	defer ticker.Stop()
	for {
		select {
		case <-o.stopChan:
			return
		case <-ticker.C:
			// One queue item is processed per tick; processing may take
			// longer than the tick interval, in which case ticks coalesce.
			o.processNextInQueue()
		}
	}
}
// processNextInQueue processes the next university in the queue: it fetches
// the head item, records it as the current work item, and advances it by one
// phase. Phase progression is driven by the completion flags persisted in
// the repository, so each tick re-reads fresh state.
func (o *Orchestrator) processNextInQueue() {
	ctx := context.Background()
	// Get next item in queue
	item, err := o.repo.GetNextInQueue(ctx)
	if err != nil {
		log.Printf("[Orchestrator] Error getting next item: %v", err)
		return
	}
	if item == nil {
		// No items to process
		return
	}
	// Check if paused (defensive; the repository query may already exclude
	// paused items — confirm against the Repository implementation).
	if item.CurrentPhase == PhasePaused {
		return
	}
	// Publish the current item for Status() while we work on it.
	o.mu.Lock()
	o.currentItem = item
	o.lastActivity = time.Now()
	o.mu.Unlock()
	defer func() {
		o.mu.Lock()
		o.currentItem = nil
		o.mu.Unlock()
	}()
	log.Printf("[Orchestrator] Processing university: %s (Phase: %s)", item.UniversityName, item.CurrentPhase)
	// Process based on current phase: if the recorded phase has finished,
	// start the next one; otherwise (re)run the recorded phase.
	switch item.CurrentPhase {
	case PhasePending:
		o.runPhase(ctx, item, PhaseDiscovery)
	case PhaseDiscovery:
		if item.DiscoveryCompleted {
			o.runPhase(ctx, item, PhaseProfessors)
		} else {
			o.runPhase(ctx, item, PhaseDiscovery)
		}
	case PhaseProfessors:
		if item.ProfessorsCompleted {
			o.runPhase(ctx, item, PhaseAllStaff)
		} else {
			o.runPhase(ctx, item, PhaseProfessors)
		}
	case PhaseAllStaff:
		if item.AllStaffCompleted {
			o.runPhase(ctx, item, PhasePublications)
		} else {
			o.runPhase(ctx, item, PhaseAllStaff)
		}
	case PhasePublications:
		if item.PublicationsCompleted {
			o.completeUniversity(ctx, item)
		} else {
			o.runPhase(ctx, item, PhasePublications)
		}
	}
}
// runPhase executes a specific crawl phase for one university: it records
// the phase transition, dispatches to the matching crawler, and on success
// marks the phase complete with the number of items found.
func (o *Orchestrator) runPhase(ctx context.Context, item *CrawlQueueItem, phase CrawlPhase) {
	log.Printf("[Orchestrator] Running phase %s for %s", phase, item.UniversityName)
	// Persist the phase transition before doing any work, so a crash leaves
	// an accurate current_phase behind.
	item.CurrentPhase = phase
	if err := o.repo.UpdateQueueItem(ctx, item); err != nil {
		log.Printf("[Orchestrator] Failed to update phase: %v", err)
		return
	}
	var progress *CrawlProgress
	var err error
	// Execute phase
	switch phase {
	case PhaseDiscovery:
		progress, err = o.staffCrawler.DiscoverSampleProfessor(ctx, item.UniversityID)
	case PhaseProfessors:
		progress, err = o.staffCrawler.CrawlProfessors(ctx, item.UniversityID)
	case PhaseAllStaff:
		progress, err = o.staffCrawler.CrawlAllStaff(ctx, item.UniversityID)
	case PhasePublications:
		progress, err = o.pubCrawler.CrawlPublicationsForUniversity(ctx, item.UniversityID)
	}
	// Handle result
	if err != nil {
		log.Printf("[Orchestrator] Phase %s failed: %v", phase, err)
		o.handlePhaseFailure(ctx, item, phase, err)
		return
	}
	// Mark phase complete. A nil progress is treated as zero items found.
	count := 0
	if progress != nil {
		count = progress.ItemsFound
	}
	if err := o.repo.CompletePhase(ctx, item.UniversityID, phase, count); err != nil {
		log.Printf("[Orchestrator] Failed to complete phase: %v", err)
	}
	log.Printf("[Orchestrator] Phase %s completed for %s (found: %d)", phase, item.UniversityName, count)
	// Wait before next phase (blocks the loop goroutine for phaseCooldown).
	time.Sleep(o.phaseCooldown)
}
// handlePhaseFailure handles a phase failure: it bumps the in-memory retry
// counter, records the error, and lets the repository persist the failure.
//
// NOTE(review): the in-memory item mutations (RetryCount, CurrentPhase)
// are not written via UpdateQueueItem here — persistence relies on
// FailPhase's SQL, which increments retry_count and flips the phase to
// 'failed' when max_retries is reached.
func (o *Orchestrator) handlePhaseFailure(ctx context.Context, item *CrawlQueueItem, phase CrawlPhase, err error) {
	item.RetryCount++
	item.LastError = err.Error()
	if item.RetryCount >= item.MaxRetries {
		// Max retries reached, mark as failed
		item.CurrentPhase = PhaseFailed
		log.Printf("[Orchestrator] University %s failed after %d retries", item.UniversityName, item.RetryCount)
	}
	if updateErr := o.repo.FailPhase(ctx, item.UniversityID, phase, err.Error()); updateErr != nil {
		log.Printf("[Orchestrator] Failed to update failure status: %v", updateErr)
	}
	// Wait before potential retry (blocks the loop goroutine).
	time.Sleep(o.retryCooldown)
}
// completeUniversity marks a university as fully crawled: the phase becomes
// "completed", the completion time is stamped, and the item leaves the
// active queue (queue position cleared).
func (o *Orchestrator) completeUniversity(ctx context.Context, item *CrawlQueueItem) {
	finishedAt := time.Now()
	item.CurrentPhase = PhaseCompleted
	item.CompletedAt = &finishedAt
	item.QueuePosition = nil // Remove from active queue
	if err := o.repo.UpdateQueueItem(ctx, item); err != nil {
		log.Printf("[Orchestrator] Failed to complete university: %v", err)
		return
	}
	log.Printf("[Orchestrator] University %s completed! Professors: %d, Staff: %d, Publications: %d",
		item.UniversityName, item.ProfessorsCount, item.StaffCount, item.PublicationsCount)
}

View File

@@ -0,0 +1,316 @@
// Package orchestrator implements multi-phase university crawling with queue management
package orchestrator
import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/google/uuid"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
)
// PostgresRepository implements the Repository interface using PostgreSQL
// via a pgx connection pool.
type PostgresRepository struct {
	pool *pgxpool.Pool
}
// NewPostgresRepository wraps the given connection pool in a repository.
func NewPostgresRepository(pool *pgxpool.Pool) *PostgresRepository {
	repo := PostgresRepository{pool: pool}
	return &repo
}
// ============================================================================
// QUEUE OPERATIONS
// ============================================================================
// GetQueueItems retrieves all items in the crawl queue (any phase), joined
// with university display names and ordered by queue position then priority.
// ProgressPercent is derived in SQL from current_phase.
func (r *PostgresRepository) GetQueueItems(ctx context.Context) ([]CrawlQueueItem, error) {
	query := `
	SELECT
		cq.id, cq.university_id, u.name, COALESCE(u.short_name, ''),
		cq.queue_position, cq.priority, cq.current_phase,
		cq.discovery_completed, cq.discovery_completed_at,
		cq.professors_completed, cq.professors_completed_at,
		cq.all_staff_completed, cq.all_staff_completed_at,
		cq.publications_completed, cq.publications_completed_at,
		cq.discovery_count, cq.professors_count, cq.staff_count, cq.publications_count,
		cq.retry_count, cq.max_retries, COALESCE(cq.last_error, ''),
		cq.started_at, cq.completed_at,
		CASE
			WHEN cq.current_phase = 'pending' THEN 0
			WHEN cq.current_phase = 'discovery' THEN 10
			WHEN cq.current_phase = 'professors' THEN 30
			WHEN cq.current_phase = 'all_staff' THEN 60
			WHEN cq.current_phase = 'publications' THEN 90
			WHEN cq.current_phase = 'completed' THEN 100
			ELSE 0
		END as progress_percent,
		cq.created_at, cq.updated_at
	FROM crawl_queue cq
	JOIN universities u ON cq.university_id = u.id
	ORDER BY cq.queue_position NULLS LAST, cq.priority DESC
	`
	rows, err := r.pool.Query(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("failed to query queue items: %w", err)
	}
	defer rows.Close()
	var items []CrawlQueueItem
	for rows.Next() {
		var item CrawlQueueItem
		// Scanned as plain string, then converted to the CrawlPhase type.
		var phase string
		// Scan order must match the SELECT column order above.
		if err := rows.Scan(
			&item.ID, &item.UniversityID, &item.UniversityName, &item.UniversityShort,
			&item.QueuePosition, &item.Priority, &phase,
			&item.DiscoveryCompleted, &item.DiscoveryCompletedAt,
			&item.ProfessorsCompleted, &item.ProfessorsCompletedAt,
			&item.AllStaffCompleted, &item.AllStaffCompletedAt,
			&item.PublicationsCompleted, &item.PublicationsCompletedAt,
			&item.DiscoveryCount, &item.ProfessorsCount, &item.StaffCount, &item.PublicationsCount,
			&item.RetryCount, &item.MaxRetries, &item.LastError,
			&item.StartedAt, &item.CompletedAt,
			&item.ProgressPercent,
			&item.CreatedAt, &item.UpdatedAt,
		); err != nil {
			return nil, fmt.Errorf("failed to scan queue item: %w", err)
		}
		item.CurrentPhase = CrawlPhase(phase)
		items = append(items, item)
	}
	return items, rows.Err()
}
// GetNextInQueue retrieves the next item to process: the lowest queue
// position (highest priority as tiebreak) that is not completed, failed,
// or paused. Returns (nil, nil) when the queue is empty.
//
// Fix: the no-rows check now uses errors.Is rather than ==, so it keeps
// working if the driver (or an intermediate layer) wraps pgx.ErrNoRows.
func (r *PostgresRepository) GetNextInQueue(ctx context.Context) (*CrawlQueueItem, error) {
	query := `
	SELECT
		cq.id, cq.university_id, u.name, COALESCE(u.short_name, ''),
		cq.queue_position, cq.priority, cq.current_phase,
		cq.discovery_completed, cq.discovery_completed_at,
		cq.professors_completed, cq.professors_completed_at,
		cq.all_staff_completed, cq.all_staff_completed_at,
		cq.publications_completed, cq.publications_completed_at,
		cq.discovery_count, cq.professors_count, cq.staff_count, cq.publications_count,
		cq.retry_count, cq.max_retries, COALESCE(cq.last_error, ''),
		cq.started_at, cq.completed_at,
		cq.created_at, cq.updated_at
	FROM crawl_queue cq
	JOIN universities u ON cq.university_id = u.id
	WHERE cq.current_phase NOT IN ('completed', 'failed', 'paused')
	AND cq.queue_position IS NOT NULL
	ORDER BY cq.queue_position ASC, cq.priority DESC
	LIMIT 1
	`
	var item CrawlQueueItem
	var phase string
	err := r.pool.QueryRow(ctx, query).Scan(
		&item.ID, &item.UniversityID, &item.UniversityName, &item.UniversityShort,
		&item.QueuePosition, &item.Priority, &phase,
		&item.DiscoveryCompleted, &item.DiscoveryCompletedAt,
		&item.ProfessorsCompleted, &item.ProfessorsCompletedAt,
		&item.AllStaffCompleted, &item.AllStaffCompletedAt,
		&item.PublicationsCompleted, &item.PublicationsCompletedAt,
		&item.DiscoveryCount, &item.ProfessorsCount, &item.StaffCount, &item.PublicationsCount,
		&item.RetryCount, &item.MaxRetries, &item.LastError,
		&item.StartedAt, &item.CompletedAt,
		&item.CreatedAt, &item.UpdatedAt,
	)
	// An empty queue is not an error: signal it with a nil item.
	if errors.Is(err, pgx.ErrNoRows) {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get next queue item: %w", err)
	}
	item.CurrentPhase = CrawlPhase(phase)
	return &item, nil
}
// AddToQueue adds a university to the crawl queue (or re-queues it via the
// ON CONFLICT upsert, resetting its phase and retry state) and returns the
// resulting queue item.
//
// NOTE(review): position assignment (MAX+1) and the insert are two separate
// statements, so two concurrent enqueues could claim the same position —
// acceptable while the orchestrator is the single writer; confirm if that
// changes.
func (r *PostgresRepository) AddToQueue(ctx context.Context, universityID uuid.UUID, priority int, initiatedBy string) (*CrawlQueueItem, error) {
	// Get next queue position
	var nextPosition int
	err := r.pool.QueryRow(ctx, `SELECT COALESCE(MAX(queue_position), 0) + 1 FROM crawl_queue WHERE queue_position IS NOT NULL`).Scan(&nextPosition)
	if err != nil {
		return nil, fmt.Errorf("failed to get next queue position: %w", err)
	}
	query := `
	INSERT INTO crawl_queue (university_id, queue_position, priority, initiated_by)
	VALUES ($1, $2, $3, $4)
	ON CONFLICT (university_id) DO UPDATE SET
		queue_position = EXCLUDED.queue_position,
		priority = EXCLUDED.priority,
		current_phase = 'pending',
		retry_count = 0,
		last_error = NULL,
		updated_at = NOW()
	RETURNING id, created_at, updated_at
	`
	item := &CrawlQueueItem{
		UniversityID:  universityID,
		QueuePosition: &nextPosition,
		Priority:      priority,
		CurrentPhase:  PhasePending,
		MaxRetries:    3,
	}
	err = r.pool.QueryRow(ctx, query, universityID, nextPosition, priority, initiatedBy).Scan(
		&item.ID, &item.CreatedAt, &item.UpdatedAt,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to add to queue: %w", err)
	}
	// The display names are cosmetic: the lookup is best-effort, and a
	// failure (previously silently ignored) must not fail the enqueue.
	// Reset the fields so a partial scan can't leave garbage behind.
	if nameErr := r.pool.QueryRow(ctx, `SELECT name, short_name FROM universities WHERE id = $1`, universityID).Scan(
		&item.UniversityName, &item.UniversityShort,
	); nameErr != nil {
		item.UniversityName = ""
		item.UniversityShort = ""
	}
	return item, nil
}
// RemoveFromQueue deletes a university's crawl_queue row entirely
// (hard delete; all phase progress for that university's crawl is lost).
func (r *PostgresRepository) RemoveFromQueue(ctx context.Context, universityID uuid.UUID) error {
	_, err := r.pool.Exec(ctx, `DELETE FROM crawl_queue WHERE university_id = $1`, universityID)
	return err
}
// UpdateQueueItem writes an item's mutable fields back to its crawl_queue
// row (keyed by university_id). MaxRetries and initiated_by are not part of
// the SET list and remain unchanged.
func (r *PostgresRepository) UpdateQueueItem(ctx context.Context, item *CrawlQueueItem) error {
	query := `
	UPDATE crawl_queue SET
		queue_position = $2,
		priority = $3,
		current_phase = $4,
		discovery_completed = $5,
		discovery_completed_at = $6,
		professors_completed = $7,
		professors_completed_at = $8,
		all_staff_completed = $9,
		all_staff_completed_at = $10,
		publications_completed = $11,
		publications_completed_at = $12,
		discovery_count = $13,
		professors_count = $14,
		staff_count = $15,
		publications_count = $16,
		retry_count = $17,
		last_error = $18,
		started_at = $19,
		completed_at = $20,
		updated_at = NOW()
	WHERE university_id = $1
	`
	// Argument order must match the $n placeholders above.
	_, err := r.pool.Exec(ctx, query,
		item.UniversityID,
		item.QueuePosition, item.Priority, string(item.CurrentPhase),
		item.DiscoveryCompleted, item.DiscoveryCompletedAt,
		item.ProfessorsCompleted, item.ProfessorsCompletedAt,
		item.AllStaffCompleted, item.AllStaffCompletedAt,
		item.PublicationsCompleted, item.PublicationsCompletedAt,
		item.DiscoveryCount, item.ProfessorsCount, item.StaffCount, item.PublicationsCount,
		item.RetryCount, item.LastError,
		item.StartedAt, item.CompletedAt,
	)
	return err
}
// PauseQueueItem pauses a crawl by setting the row's phase to 'paused';
// the completed-phase flags are untouched so ResumeQueueItem can pick up
// where the crawl left off.
func (r *PostgresRepository) PauseQueueItem(ctx context.Context, universityID uuid.UUID) error {
	_, err := r.pool.Exec(ctx, `UPDATE crawl_queue SET current_phase = 'paused', updated_at = NOW() WHERE university_id = $1`, universityID)
	return err
}
// ResumeQueueItem resumes a paused crawl. The CASE expression selects the
// first phase whose completion flag is still false; only rows currently in
// the 'paused' phase are affected.
func (r *PostgresRepository) ResumeQueueItem(ctx context.Context, universityID uuid.UUID) error {
	// Determine what phase to resume from
	query := `
	UPDATE crawl_queue SET
		current_phase = CASE
			WHEN NOT discovery_completed THEN 'discovery'
			WHEN NOT professors_completed THEN 'professors'
			WHEN NOT all_staff_completed THEN 'all_staff'
			WHEN NOT publications_completed THEN 'publications'
			ELSE 'pending'
		END,
		updated_at = NOW()
	WHERE university_id = $1 AND current_phase = 'paused'
	`
	_, err := r.pool.Exec(ctx, query, universityID)
	return err
}
// ============================================================================
// PHASE UPDATES
// ============================================================================
// CompletePhase marks a phase as completed, stamping its completion time
// and item count on the phase-specific columns. Terminal/control phases
// (pending, completed, failed, paused) are rejected with an error.
func (r *PostgresRepository) CompletePhase(ctx context.Context, universityID uuid.UUID, phase CrawlPhase, count int) error {
	now := time.Now()
	var query string
	switch phase {
	case PhaseDiscovery:
		query = `UPDATE crawl_queue SET discovery_completed = true, discovery_completed_at = $2, discovery_count = $3, updated_at = NOW() WHERE university_id = $1`
	case PhaseProfessors:
		query = `UPDATE crawl_queue SET professors_completed = true, professors_completed_at = $2, professors_count = $3, updated_at = NOW() WHERE university_id = $1`
	case PhaseAllStaff:
		query = `UPDATE crawl_queue SET all_staff_completed = true, all_staff_completed_at = $2, staff_count = $3, updated_at = NOW() WHERE university_id = $1`
	case PhasePublications:
		query = `UPDATE crawl_queue SET publications_completed = true, publications_completed_at = $2, publications_count = $3, updated_at = NOW() WHERE university_id = $1`
	default:
		return fmt.Errorf("unknown phase: %s", phase)
	}
	_, err := r.pool.Exec(ctx, query, universityID, now, count)
	return err
}
// FailPhase records a phase failure: it increments retry_count, stores the
// error message, and flips the row to 'failed' once the retry budget is
// exhausted. Column references in the SET expressions see the pre-update
// values, so "retry_count + 1" compares the new count against max_retries.
// The phase argument is currently not persisted (the row keeps its
// current_phase unless it transitions to 'failed').
func (r *PostgresRepository) FailPhase(ctx context.Context, universityID uuid.UUID, phase CrawlPhase, errMsg string) error {
	query := `
	UPDATE crawl_queue SET
		retry_count = retry_count + 1,
		last_error = $2,
		current_phase = CASE
			WHEN retry_count + 1 >= max_retries THEN 'failed'
			ELSE current_phase
		END,
		updated_at = NOW()
	WHERE university_id = $1
	`
	_, err := r.pool.Exec(ctx, query, universityID, errMsg)
	return err
}
// ============================================================================
// STATS
// ============================================================================
// GetCompletedTodayCount returns how many universities finished crawling
// since midnight (DB server time).
func (r *PostgresRepository) GetCompletedTodayCount(ctx context.Context) (int, error) {
	const query = `
	SELECT COUNT(*) FROM crawl_queue
	WHERE current_phase = 'completed'
	AND completed_at >= CURRENT_DATE
	`
	var completed int
	if err := r.pool.QueryRow(ctx, query).Scan(&completed); err != nil {
		return 0, err
	}
	return completed, nil
}
// GetTotalProcessedCount returns the all-time number of universities whose
// crawl reached the 'completed' phase.
func (r *PostgresRepository) GetTotalProcessedCount(ctx context.Context) (int, error) {
	const query = `SELECT COUNT(*) FROM crawl_queue WHERE current_phase = 'completed'`
	var total int
	if err := r.pool.QueryRow(ctx, query).Scan(&total); err != nil {
		return 0, err
	}
	return total, nil
}

View File

@@ -0,0 +1,301 @@
package pipeline
import (
"context"
"log"
"strings"
"sync"
"time"
"github.com/breakpilot/edu-search-service/internal/crawler"
"github.com/breakpilot/edu-search-service/internal/extractor"
"github.com/breakpilot/edu-search-service/internal/indexer"
"github.com/breakpilot/edu-search-service/internal/tagger"
)
// Pipeline orchestrates crawling, extraction, tagging, and indexing.
// It fans work out to a fixed pool of concurrent workers and bulk-indexes
// the resulting documents.
type Pipeline struct {
	crawler     *crawler.Crawler // fetches pages and tracks seeds
	tagger      *tagger.Tagger   // classifies extracted content
	indexClient *indexer.Client  // writes documents to the search index
	maxPages    int              // upper bound on processed URLs per run
	workers     int              // number of concurrent fetch workers
}
// Stats tracks pipeline execution statistics for a single Run invocation.
type Stats struct {
	StartTime        time.Time // when Run started
	EndTime          time.Time // when Run finished
	URLsProcessed    int       // total results consumed (success + failed + skipped)
	URLsSuccessful   int       // URLs that produced an indexable document
	URLsFailed       int       // fetch or extraction errors
	URLsSkipped      int       // unsupported content type or too little text
	DocumentsIndexed int       // documents confirmed written via bulk indexing
}
// NewPipeline wires a crawler, tagger, and index client into a crawl
// pipeline that processes at most maxPages URLs with a fixed pool of
// five concurrent workers.
func NewPipeline(
	crawlerInstance *crawler.Crawler,
	taggerInstance *tagger.Tagger,
	indexClient *indexer.Client,
	maxPages int,
) *Pipeline {
	p := &Pipeline{
		crawler:     crawlerInstance,
		tagger:      taggerInstance,
		indexClient: indexClient,
		maxPages:    maxPages,
	}
	p.workers = 5 // concurrent workers
	return p
}
// Run executes the crawl pipeline: it seeds the URL queue, starts the
// worker pool, consumes results, and bulk-indexes documents in batches of
// 50. It returns aggregate statistics for the run.
//
// NOTE(review): three lifecycle hazards in this loop —
//  1. close(urlQueue) below runs while workers may still be inside
//     processURL, whose non-blocking `case urlQueue <- link` send will
//     panic on a closed channel (a send to a closed channel panics even
//     inside a select).
//  2. urlQueue is only ever closed when maxPages is reached; if the crawl
//     exhausts all reachable URLs first, workers block forever on the open
//     queue, results is never closed, and Run deadlocks — confirm against
//     real seed sets.
//  3. After the `break`, remaining workers can block sending into the
//     bounded results channel once its buffer fills, leaking goroutines
//     and dropping their fetched documents.
func (p *Pipeline) Run(ctx context.Context, seedsDir string) (*Stats, error) {
	stats := &Stats{
		StartTime: time.Now(),
	}
	// Load seed URLs
	seeds, err := p.crawler.LoadSeeds(seedsDir)
	if err != nil {
		return nil, err
	}
	log.Printf("Pipeline starting with %d seeds, max %d pages", len(seeds), p.maxPages)
	// Create URL queue; buffered so workers can enqueue discovered links
	// without blocking (overflow links are dropped in processURL).
	urlQueue := make(chan string, len(seeds)*10)
	visited := &sync.Map{}
	// Add seeds to queue; LoadOrStore deduplicates by normalized URL.
	for _, seed := range seeds {
		normalized := crawler.NormalizeURL(seed)
		if _, loaded := visited.LoadOrStore(normalized, true); !loaded {
			urlQueue <- seed
		}
	}
	// Results channel
	results := make(chan *processResult, p.workers*2)
	var wg sync.WaitGroup
	// Start workers
	for i := 0; i < p.workers; i++ {
		wg.Add(1)
		go p.worker(ctx, i, urlQueue, results, visited, &wg)
	}
	// Close results when all workers done
	go func() {
		wg.Wait()
		close(results)
	}()
	// Process results and collect stats
	var documents []indexer.Document
	processed := 0
	for result := range results {
		stats.URLsProcessed++
		if result.err != nil {
			stats.URLsFailed++
			continue
		}
		if result.skipped {
			stats.URLsSkipped++
			continue
		}
		if result.document != nil {
			documents = append(documents, *result.document)
			stats.URLsSuccessful++
			// Bulk index every 50 documents; on error the batch is dropped
			// but the run continues.
			if len(documents) >= 50 {
				if err := p.indexClient.BulkIndex(ctx, documents); err != nil {
					log.Printf("Bulk index error: %v", err)
				} else {
					stats.DocumentsIndexed += len(documents)
				}
				documents = nil
			}
		}
		processed++
		if processed >= p.maxPages {
			log.Printf("Reached max pages limit (%d)", p.maxPages)
			close(urlQueue)
			break
		}
	}
	// Index remaining documents
	if len(documents) > 0 {
		if err := p.indexClient.BulkIndex(ctx, documents); err != nil {
			log.Printf("Final bulk index error: %v", err)
		} else {
			stats.DocumentsIndexed += len(documents)
		}
	}
	stats.EndTime = time.Now()
	log.Printf("Pipeline completed: %d processed, %d indexed, %d failed, %d skipped in %v",
		stats.URLsProcessed, stats.DocumentsIndexed, stats.URLsFailed, stats.URLsSkipped,
		stats.EndTime.Sub(stats.StartTime))
	return stats, nil
}
// processResult is the outcome of processing a single URL. Exactly one of
// document, err, or skipped is meaningful: err for fetch/extract failures,
// skipped for unsupported or low-content pages, document for success.
type processResult struct {
	url      string
	document *indexer.Document
	err      error
	skipped  bool
}
// worker drains urlQueue until the queue is closed or ctx is cancelled,
// emitting one processResult per consumed URL.
func (p *Pipeline) worker(
	ctx context.Context,
	id int,
	urlQueue chan string,
	results chan<- *processResult,
	visited *sync.Map,
	wg *sync.WaitGroup,
) {
	defer wg.Done()
	for url := range urlQueue {
		// Stop promptly once the context is cancelled; remaining queued
		// URLs are abandoned.
		if ctx.Err() != nil {
			return
		}
		results <- p.processURL(ctx, url, urlQueue, visited)
	}
}
// processURL fetches a single URL, extracts and tags its content, builds an
// index document, and enqueues newly discovered same-domain links. The
// returned processResult carries exactly one of: err, skipped, or document.
func (p *Pipeline) processURL(
	ctx context.Context,
	url string,
	urlQueue chan<- string,
	visited *sync.Map,
) *processResult {
	result := &processResult{url: url}
	// Fetch URL
	fetchResult, err := p.crawler.Fetch(ctx, url)
	if err != nil {
		result.err = err
		return result
	}
	// Check content type: only HTML and PDF are processed.
	contentType := strings.ToLower(fetchResult.ContentType)
	if !strings.Contains(contentType, "text/html") && !strings.Contains(contentType, "application/pdf") {
		result.skipped = true
		return result
	}
	// Extract content with the matching extractor.
	var extracted *extractor.ExtractedContent
	if strings.Contains(contentType, "text/html") {
		extracted, err = extractor.ExtractHTML(fetchResult.Body)
	} else if strings.Contains(contentType, "application/pdf") {
		extracted, err = extractor.ExtractPDF(fetchResult.Body)
	}
	if err != nil {
		result.err = err
		return result
	}
	// Skip if too little content (less than 100 extracted characters —
	// presumably boilerplate-only pages; threshold is a heuristic).
	if extracted.ContentLength < 100 {
		result.skipped = true
		return result
	}
	// Tag content using extraction-derived features.
	features := tagger.ContentFeatures{
		AdDensity:     extracted.Features.AdDensity,
		LinkDensity:   extracted.Features.LinkDensity,
		ContentLength: extracted.ContentLength,
	}
	tags := p.tagger.Tag(fetchResult.CanonicalURL, extracted.Title, extracted.ContentText, features)
	// Create document; the canonical URL from the fetch (not the request
	// URL) is what gets indexed.
	doc := &indexer.Document{
		DocID:        crawler.GenerateDocID(),
		URL:          fetchResult.CanonicalURL,
		Domain:       crawler.ExtractDomain(fetchResult.CanonicalURL),
		Title:        extracted.Title,
		ContentText:  extracted.ContentText,
		SnippetText:  extracted.SnippetText,
		ContentHash:  fetchResult.ContentHash,
		DocType:      tags.DocType,
		Subjects:     tags.Subjects,
		SchoolLevel:  tags.SchoolLevel,
		State:        tags.State,
		Language:     extracted.Language,
		TrustScore:   tags.TrustScore,
		QualityScore: calculateQualityScore(extracted, tags),
		FetchedAt:    fetchResult.FetchTime,
		UpdatedAt:    time.Now(),
	}
	result.document = doc
	// Extract and queue new links (limited to same domain for now).
	// NOTE(review): the non-blocking send below panics if Run has already
	// closed urlQueue (send on closed channel) — see the note on Run.
	docDomain := crawler.ExtractDomain(url)
	for _, link := range extracted.Links {
		linkDomain := crawler.ExtractDomain(link)
		if linkDomain == docDomain {
			normalized := crawler.NormalizeURL(link)
			if _, loaded := visited.LoadOrStore(normalized, true); !loaded {
				select {
				case urlQueue <- link:
				default:
					// Queue full, skip
				}
			}
		}
	}
	return result
}
// calculateQualityScore derives a heuristic quality score in [0.5, 1.0]
// from extraction features: a 0.5 base plus 0.1 per satisfied signal
// (content over 1000 chars, over 5000 chars, any headings, low ad density,
// healthy text/HTML ratio), clamped to 1.0.
func calculateQualityScore(extracted *extractor.ExtractedContent, tags tagger.TagResult) float64 {
	signals := []bool{
		extracted.ContentLength > 1000,              // substantial content
		extracted.ContentLength > 5000,              // long-form content
		len(extracted.Headings) > 0,                 // structured document
		extracted.Features.AdDensity < 0.1,          // low ad density
		extracted.Features.TextToHTMLRatio > 0.2,    // good text/HTML ratio
	}
	score := 0.5 // base
	for _, hit := range signals {
		if hit {
			score += 0.1
		}
	}
	if score > 1 {
		score = 1
	}
	return score
}

View File

@@ -0,0 +1,255 @@
package policy
import (
"context"
"encoding/json"
"github.com/google/uuid"
)
// Auditor provides audit logging functionality for the policy system,
// persisting policy changes and blocked-content events through the store.
type Auditor struct {
	store *Store
}
// NewAuditor creates a new Auditor instance backed by the given store.
func NewAuditor(store *Store) *Auditor {
	return &Auditor{store: store}
}
// LogChange writes one entry to the policy audit trail. Old and new values
// are JSON-serialized when non-nil; unserializable values are stored as
// nil (see toJSON). All pointer parameters may be nil.
func (a *Auditor) LogChange(ctx context.Context, action AuditAction, entityType AuditEntityType, entityID *uuid.UUID, oldValue, newValue interface{}, userEmail, ipAddress, userAgent *string) error {
	var oldJSON, newJSON json.RawMessage
	if oldValue != nil {
		oldJSON = toJSON(oldValue)
	}
	if newValue != nil {
		newJSON = toJSON(newValue)
	}
	return a.store.CreateAuditLog(ctx, &PolicyAuditLog{
		Action:     action,
		EntityType: entityType,
		EntityID:   entityID,
		UserEmail:  userEmail,
		IPAddress:  ipAddress,
		UserAgent:  userAgent,
		OldValue:   oldJSON,
		NewValue:   newJSON,
	})
}
// LogBlocked writes a blocked-URL entry to the blocked content log.
// Optional details are JSON-serialized when provided.
func (a *Auditor) LogBlocked(ctx context.Context, url, domain string, reason BlockReason, ruleID *uuid.UUID, details map[string]interface{}) error {
	entry := BlockedContentLog{
		URL:           url,
		Domain:        domain,
		BlockReason:   reason,
		MatchedRuleID: ruleID,
	}
	if details != nil {
		entry.Details = toJSON(details)
	}
	return a.store.CreateBlockedContentLog(ctx, &entry)
}
// =============================================================================
// CONVENIENCE METHODS
// =============================================================================
// LogPolicyCreated logs a policy creation event (action "create",
// entity "source_policy"); the new policy becomes the entry's new value.
func (a *Auditor) LogPolicyCreated(ctx context.Context, policy *SourcePolicy, userEmail *string) error {
	return a.LogChange(ctx, AuditActionCreate, AuditEntitySourcePolicy, &policy.ID, nil, policy, userEmail, nil, nil)
}
// LogPolicyUpdated logs a policy update event, recording both the previous
// and the updated policy state.
func (a *Auditor) LogPolicyUpdated(ctx context.Context, oldPolicy, newPolicy *SourcePolicy, userEmail *string) error {
	return a.LogChange(ctx, AuditActionUpdate, AuditEntitySourcePolicy, &newPolicy.ID, oldPolicy, newPolicy, userEmail, nil, nil)
}
// LogPolicyDeleted logs a policy deletion event; the deleted policy is
// preserved as the entry's old value.
func (a *Auditor) LogPolicyDeleted(ctx context.Context, policy *SourcePolicy, userEmail *string) error {
	return a.LogChange(ctx, AuditActionDelete, AuditEntitySourcePolicy, &policy.ID, policy, nil, userEmail, nil, nil)
}
// LogPolicyActivated logs a policy activation event; the activated policy
// is stored as the entry's new value.
func (a *Auditor) LogPolicyActivated(ctx context.Context, policy *SourcePolicy, userEmail *string) error {
	return a.LogChange(ctx, AuditActionActivate, AuditEntitySourcePolicy, &policy.ID, nil, policy, userEmail, nil, nil)
}
// LogPolicyDeactivated logs a policy deactivation event; the deactivated
// policy is stored as the entry's old value.
func (a *Auditor) LogPolicyDeactivated(ctx context.Context, policy *SourcePolicy, userEmail *string) error {
	return a.LogChange(ctx, AuditActionDeactivate, AuditEntitySourcePolicy, &policy.ID, policy, nil, userEmail, nil, nil)
}
// LogSourceCreated logs an allowed-source creation event (entity
// "allowed_source"); the new source becomes the entry's new value.
func (a *Auditor) LogSourceCreated(ctx context.Context, source *AllowedSource, userEmail *string) error {
	return a.LogChange(ctx, AuditActionCreate, AuditEntityAllowedSource, &source.ID, nil, source, userEmail, nil, nil)
}
// LogSourceUpdated logs an allowed-source update event, recording both the
// previous and the updated source state.
func (a *Auditor) LogSourceUpdated(ctx context.Context, oldSource, newSource *AllowedSource, userEmail *string) error {
	return a.LogChange(ctx, AuditActionUpdate, AuditEntityAllowedSource, &newSource.ID, oldSource, newSource, userEmail, nil, nil)
}
// LogSourceDeleted logs an allowed-source deletion event; the deleted
// source is preserved as the entry's old value.
func (a *Auditor) LogSourceDeleted(ctx context.Context, source *AllowedSource, userEmail *string) error {
	return a.LogChange(ctx, AuditActionDelete, AuditEntityAllowedSource, &source.ID, source, nil, userEmail, nil, nil)
}
// LogOperationUpdated logs an operation-permission update event (entity
// "operation_permission") with both old and new permission state.
func (a *Auditor) LogOperationUpdated(ctx context.Context, oldOp, newOp *OperationPermission, userEmail *string) error {
	return a.LogChange(ctx, AuditActionUpdate, AuditEntityOperationPermission, &newOp.ID, oldOp, newOp, userEmail, nil, nil)
}
// LogPIIRuleCreated logs a PII rule creation event (entity "pii_rule");
// the new rule becomes the entry's new value.
func (a *Auditor) LogPIIRuleCreated(ctx context.Context, rule *PIIRule, userEmail *string) error {
	return a.LogChange(ctx, AuditActionCreate, AuditEntityPIIRule, &rule.ID, nil, rule, userEmail, nil, nil)
}
// LogPIIRuleUpdated logs a PII rule update event with both old and new
// rule state.
func (a *Auditor) LogPIIRuleUpdated(ctx context.Context, oldRule, newRule *PIIRule, userEmail *string) error {
	return a.LogChange(ctx, AuditActionUpdate, AuditEntityPIIRule, &newRule.ID, oldRule, newRule, userEmail, nil, nil)
}
// LogPIIRuleDeleted logs a PII rule deletion event; the deleted rule is
// preserved as the entry's old value.
func (a *Auditor) LogPIIRuleDeleted(ctx context.Context, rule *PIIRule, userEmail *string) error {
	return a.LogChange(ctx, AuditActionDelete, AuditEntityPIIRule, &rule.ID, rule, nil, userEmail, nil, nil)
}
// LogContentBlocked logs a blocked content event, attaching the patterns
// that triggered the block as detail payload.
func (a *Auditor) LogContentBlocked(ctx context.Context, url, domain string, reason BlockReason, matchedPatterns []string, ruleID *uuid.UUID) error {
	details := map[string]interface{}{
		"matched_patterns": matchedPatterns,
	}
	return a.LogBlocked(ctx, url, domain, reason, ruleID, details)
}
// LogPIIBlocked logs content blocked due to PII detection. The matched PII
// text is masked before being written to the log, and the FIRST match's
// rule ID is recorded as the triggering rule.
func (a *Auditor) LogPIIBlocked(ctx context.Context, url, domain string, matches []PIIMatch) error {
	matchDetails := make([]map[string]interface{}, len(matches))
	var ruleID *uuid.UUID
	for i := range matches {
		m := &matches[i]
		matchDetails[i] = map[string]interface{}{
			"rule_name": m.RuleName,
			"severity":  m.Severity,
			"match":     maskPII(m.Match), // Mask the actual PII in logs
		}
		if ruleID == nil {
			// Take the address of the slice element, not of a range
			// variable: with Go < 1.22 the range variable is reused across
			// iterations, so `&m.RuleID` of a `for _, m := range` loop
			// would end up pointing at the LAST match instead of the first.
			ruleID = &matches[i].RuleID
		}
	}
	details := map[string]interface{}{
		"pii_matches": matchDetails,
		"match_count": len(matches),
	}
	return a.LogBlocked(ctx, url, domain, BlockReasonPIIDetected, ruleID, details)
}
// =============================================================================
// HELPERS
// =============================================================================
// toJSON marshals v to JSON, returning nil when v cannot be marshaled
// (callers treat nil as "no payload").
func toJSON(v interface{}) json.RawMessage {
	if data, err := json.Marshal(v); err == nil {
		return data
	}
	return nil
}
// maskPII masks PII data for safe logging. Values of four or fewer
// characters are fully masked; longer values keep only their first two and
// last two characters.
//
// Masking is done on runes, not bytes: byte slicing (`pii[:2]`) would cut
// multi-byte UTF-8 characters in half and emit invalid text into the log.
func maskPII(pii string) string {
	runes := []rune(pii)
	if len(runes) <= 4 {
		return "****"
	}
	// Show first 2 and last 2 characters
	return string(runes[:2]) + "****" + string(runes[len(runes)-2:])
}
// =============================================================================
// AUDIT REPORT GENERATION
// =============================================================================
// AuditReport represents an audit report for compliance, bundling the raw
// policy-change and blocked-content logs with summary counts and stats.
type AuditReport struct {
	GeneratedAt    string              `json:"generated_at"`    // report creation time
	PeriodStart    string              `json:"period_start"`    // YYYY-MM-DD, empty if unbounded
	PeriodEnd      string              `json:"period_end"`      // YYYY-MM-DD, empty if unbounded
	Summary        AuditReportSummary  `json:"summary"`
	PolicyChanges  []PolicyAuditLog    `json:"policy_changes"`
	BlockedContent []BlockedContentLog `json:"blocked_content"`
	Stats          *PolicyStats        `json:"stats"`
}
// AuditReportSummary contains summary statistics for the audit report:
// totals plus per-action and per-block-reason breakdowns.
type AuditReportSummary struct {
	TotalPolicyChanges int            `json:"total_policy_changes"`
	TotalBlocked       int            `json:"total_blocked"`
	ChangesByAction    map[string]int `json:"changes_by_action"` // keyed by AuditAction
	BlocksByReason     map[string]int `json:"blocks_by_reason"`  // keyed by BlockReason
}
// GenerateAuditReport generates a compliance audit report: it collects the
// filtered audit logs, blocked-content logs, and aggregate stats, and
// summarizes them by action and block reason.
//
// filter and blockedFilter may be nil; in that case the period fields stay
// empty and the store is queried with a nil filter (assumed to mean
// "unfiltered" — confirm against the store implementation).
func (a *Auditor) GenerateAuditReport(ctx context.Context, filter *AuditLogFilter, blockedFilter *BlockedContentFilter) (*AuditReport, error) {
	// Get audit logs
	auditLogs, _, err := a.store.ListAuditLogs(ctx, filter)
	if err != nil {
		return nil, err
	}
	// Get blocked content
	blockedLogs, _, err := a.store.ListBlockedContent(ctx, blockedFilter)
	if err != nil {
		return nil, err
	}
	// Get stats
	stats, err := a.store.GetStats(ctx)
	if err != nil {
		return nil, err
	}
	// Build summary
	summary := AuditReportSummary{
		TotalPolicyChanges: len(auditLogs),
		TotalBlocked:       len(blockedLogs),
		ChangesByAction:    make(map[string]int),
		BlocksByReason:     make(map[string]int),
	}
	for _, entry := range auditLogs {
		summary.ChangesByAction[string(entry.Action)]++
	}
	for _, entry := range blockedLogs {
		summary.BlocksByReason[string(entry.BlockReason)]++
	}
	// Build report period; guard against a nil filter (previously this
	// dereferenced filter unconditionally and would panic).
	periodStart := ""
	periodEnd := ""
	if filter != nil && filter.FromDate != nil {
		periodStart = filter.FromDate.Format("2006-01-02")
	}
	if filter != nil && filter.ToDate != nil {
		periodEnd = filter.ToDate.Format("2006-01-02")
	}
	report := &AuditReport{
		// Real UTC timestamp. The previous implementation stored a
		// truncated random UUID here as a placeholder, which is not a
		// timestamp at all.
		GeneratedAt:    time.Now().UTC().Format(time.RFC3339),
		PeriodStart:    periodStart,
		PeriodEnd:      periodEnd,
		Summary:        summary,
		PolicyChanges:  auditLogs,
		BlockedContent: blockedLogs,
		Stats:          stats,
	}
	return report, nil
}

View File

@@ -0,0 +1,281 @@
package policy
import (
"context"
"net/url"
"strings"
"github.com/google/uuid"
)
// Enforcer provides policy enforcement for the crawler and pipeline:
// whitelist checks, operation permissions, PII detection, and audit/block
// logging, all backed by a single Store.
type Enforcer struct {
	store       *Store
	piiDetector *PIIDetector
	auditor     *Auditor
}
// NewEnforcer creates a new Enforcer instance, wiring a PII detector and
// an auditor onto the given store.
func NewEnforcer(store *Store) *Enforcer {
	return &Enforcer{
		store:       store,
		piiDetector: NewPIIDetector(store),
		auditor:     NewAuditor(store),
	}
}
// =============================================================================
// SOURCE CHECKING
// =============================================================================
// CheckSource verifies whether a URL is allowed based on the whitelist.
// It resolves the URL's domain and looks it up in the store, returning the
// matching AllowedSource, or nil when the domain is not whitelisted.
func (e *Enforcer) CheckSource(ctx context.Context, rawURL string, bundesland *Bundesland) (*AllowedSource, error) {
	domain, err := extractDomain(rawURL)
	if err != nil {
		return nil, err
	}
	return e.store.GetSourceByDomain(ctx, domain, bundesland)
}
// CheckOperation verifies whether a specific operation is permitted for a
// source. Operations preloaded on the source are consulted first; if the
// operation is absent there, the store is queried directly. A (nil, nil)
// return means no permission record exists for the operation.
func (e *Enforcer) CheckOperation(ctx context.Context, source *AllowedSource, operation Operation) (*OperationPermission, error) {
	for i := range source.Operations {
		if source.Operations[i].Operation == operation {
			perm := source.Operations[i]
			return &perm, nil
		}
	}
	// Fall back to a direct store lookup.
	ops, err := e.store.GetOperationsBySourceID(ctx, source.ID)
	if err != nil {
		return nil, err
	}
	for i := range ops {
		if ops[i].Operation == operation {
			perm := ops[i]
			return &perm, nil
		}
	}
	return nil, nil
}
// CheckCompliance performs a full compliance check for a URL and operation:
// first the whitelist (source) check, then the operation permission check.
// The response always carries a BlockReason when IsAllowed is false, and
// license/citation metadata whenever the source is whitelisted.
func (e *Enforcer) CheckCompliance(ctx context.Context, req *CheckComplianceRequest) (*CheckComplianceResponse, error) {
	response := &CheckComplianceResponse{
		IsAllowed:        false,
		RequiresCitation: false,
	}
	// Check if source is whitelisted
	source, err := e.CheckSource(ctx, req.URL, req.Bundesland)
	if err != nil {
		return nil, err
	}
	if source == nil {
		reason := BlockReasonNotWhitelisted
		response.BlockReason = &reason
		return response, nil
	}
	// Source metadata is returned even when the operation itself ends up
	// being denied below.
	response.Source = source
	response.License = &source.License
	response.CitationTemplate = source.CitationTemplate
	// Check operation permission
	opPerm, err := e.CheckOperation(ctx, source, req.Operation)
	if err != nil {
		return nil, err
	}
	if opPerm == nil || !opPerm.IsAllowed {
		// Training gets its own block reason; everything else is treated
		// as a license violation.
		var reason BlockReason
		if req.Operation == OperationTraining {
			reason = BlockReasonTrainingForbidden
		} else {
			reason = BlockReasonLicenseViolation
		}
		response.BlockReason = &reason
		return response, nil
	}
	response.IsAllowed = true
	response.RequiresCitation = opPerm.RequiresCitation
	return response, nil
}
// =============================================================================
// PII CHECKING
// =============================================================================
// DetectPII scans text for PII patterns and returns the matches,
// delegating to the enforcer's store-backed PII detector.
func (e *Enforcer) DetectPII(ctx context.Context, text string) (*PIITestResponse, error) {
	return e.piiDetector.Detect(ctx, text)
}
// ShouldBlockForPII reports whether content should be blocked based on the
// outcome of a PII scan. A nil response never blocks.
func (e *Enforcer) ShouldBlockForPII(response *PIITestResponse) bool {
	return response != nil && response.ShouldBlock
}
// =============================================================================
// LOGGING
// =============================================================================
// LogBlocked logs a blocked URL to the blocked content log. Domain
// extraction is best-effort: if the URL cannot be parsed, the entry is
// still written with an empty domain rather than failing the log call.
func (e *Enforcer) LogBlocked(ctx context.Context, rawURL string, reason BlockReason, ruleID *uuid.UUID, details map[string]interface{}) error {
	domain, _ := extractDomain(rawURL) // best-effort; log even unparsable URLs
	return e.auditor.LogBlocked(ctx, rawURL, domain, reason, ruleID, details)
}
// LogChange logs a policy change to the audit log, delegating to the
// auditor without IP address or user agent information.
func (e *Enforcer) LogChange(ctx context.Context, action AuditAction, entityType AuditEntityType, entityID *uuid.UUID, oldValue, newValue interface{}, userEmail *string) error {
	return e.auditor.LogChange(ctx, action, entityType, entityID, oldValue, newValue, userEmail, nil, nil)
}
// =============================================================================
// BATCH OPERATIONS
// =============================================================================
// FilterURLs filters a list of URLs, returning only whitelisted ones.
func (e *Enforcer) FilterURLs(ctx context.Context, urls []string, bundesland *Bundesland, operation Operation) ([]FilteredURL, error) {
results := make([]FilteredURL, 0, len(urls))
for _, u := range urls {
result := FilteredURL{
URL: u,
IsAllowed: false,
}
source, err := e.CheckSource(ctx, u, bundesland)
if err != nil {
result.Error = err.Error()
results = append(results, result)
continue
}
if source == nil {
result.BlockReason = BlockReasonNotWhitelisted
results = append(results, result)
continue
}
opPerm, err := e.CheckOperation(ctx, source, operation)
if err != nil {
result.Error = err.Error()
results = append(results, result)
continue
}
if opPerm == nil || !opPerm.IsAllowed {
if operation == OperationTraining {
result.BlockReason = BlockReasonTrainingForbidden
} else {
result.BlockReason = BlockReasonLicenseViolation
}
results = append(results, result)
continue
}
result.IsAllowed = true
result.Source = source
result.RequiresCitation = opPerm.RequiresCitation
results = append(results, result)
}
return results, nil
}
// FilteredURL represents the result of filtering a single URL. When
// IsAllowed is false, either BlockReason or Error explains why; Source is
// only populated for whitelisted URLs.
type FilteredURL struct {
	URL              string         `json:"url"`
	IsAllowed        bool           `json:"is_allowed"`
	Source           *AllowedSource `json:"source,omitempty"`
	BlockReason      BlockReason    `json:"block_reason,omitempty"`
	RequiresCitation bool           `json:"requires_citation"`
	Error            string         `json:"error,omitempty"`
}
// =============================================================================
// HELPERS
// =============================================================================
// extractDomain extracts the registrable host from a URL: it defaults to
// https for scheme-less inputs, lower-cases the host, and strips a leading
// "www." prefix.
//
// Lower-casing is required for correct whitelist matching: hostnames are
// case-insensitive (RFC 4343), but the previous implementation preserved
// case, so "WWW.Example.COM" never matched a whitelist entry "example.com"
// (and its "WWW." prefix was not trimmed either).
func extractDomain(rawURL string) (string, error) {
	// Handle URLs without scheme (e.g. "kmk.org/x"), which url.Parse would
	// otherwise treat as a relative path with an empty host.
	if !strings.Contains(rawURL, "://") {
		rawURL = "https://" + rawURL
	}
	parsed, err := url.Parse(rawURL)
	if err != nil {
		return "", err
	}
	host := strings.ToLower(parsed.Hostname())
	// Remove www. prefix
	return strings.TrimPrefix(host, "www."), nil
}
// IsTrainingAllowed checks whether training is allowed for ANY source.
// Training must never be permitted, so a true return signals a policy
// misconfiguration that callers should treat as an alarm, not a go-ahead.
func (e *Enforcer) IsTrainingAllowed(ctx context.Context) (bool, error) {
	// Training should NEVER be allowed - this is a safeguard
	matrix, err := e.store.GetOperationsMatrix(ctx)
	if err != nil {
		return false, err
	}
	for _, source := range matrix {
		for _, op := range source.Operations {
			if op.Operation == OperationTraining && op.IsAllowed {
				// This should never happen - log a warning
				return true, nil
			}
		}
	}
	return false, nil
}
// GetSourceByURL is a convenience alias for CheckSource: it resolves a raw
// URL to its whitelisted AllowedSource, or nil if not whitelisted.
func (e *Enforcer) GetSourceByURL(ctx context.Context, rawURL string, bundesland *Bundesland) (*AllowedSource, error) {
	return e.CheckSource(ctx, rawURL, bundesland)
}
// GetCitationForURL generates a citation string for a whitelisted URL.
// When the source has no citation template, a generic "Quelle: ..." line
// is produced; otherwise the template's {title}, {date}, {url}, {domain},
// and {source} placeholders are substituted in that order. For URLs that
// are not whitelisted (or fail to parse) an empty citation is returned.
func (e *Enforcer) GetCitationForURL(ctx context.Context, rawURL string, bundesland *Bundesland, title string, date string) (string, error) {
	source, err := e.CheckSource(ctx, rawURL, bundesland)
	if err != nil || source == nil {
		return "", err
	}
	tmpl := source.CitationTemplate
	if tmpl == nil || *tmpl == "" {
		// Default citation format
		return "Quelle: " + source.Name + ", " + title + ", " + date, nil
	}
	// Substitute placeholders sequentially, preserving the original
	// replacement order.
	citation := *tmpl
	for _, sub := range [][2]string{
		{"{title}", title},
		{"{date}", date},
		{"{url}", rawURL},
		{"{domain}", source.Domain},
		{"{source}", source.Name},
	} {
		citation = strings.ReplaceAll(citation, sub[0], sub[1])
	}
	return citation, nil
}

View File

@@ -0,0 +1,255 @@
package policy
import (
"context"
"fmt"
"os"
"gopkg.in/yaml.v3"
)
// Loader handles loading policy configuration from YAML files into the
// policy store.
type Loader struct {
	store *Store
}
// NewLoader creates a new Loader instance backed by the given store.
func NewLoader(store *Store) *Loader {
	return &Loader{store: store}
}
// LoadFromFile reads the YAML file at path, parses it into a
// BundeslaenderConfig, and persists the result through the store.
func (l *Loader) LoadFromFile(ctx context.Context, path string) error {
	data, err := os.ReadFile(path)
	if err != nil {
		return fmt.Errorf("failed to read YAML file: %w", err)
	}
	config, err := ParseYAML(data)
	if err != nil {
		return fmt.Errorf("failed to parse YAML: %w", err)
	}
	return l.store.LoadFromYAML(ctx, config)
}
// ParseYAML parses YAML configuration data into a BundeslaenderConfig.
// The document is decoded into a generic map first because the sixteen
// Bundesland sections live at the top level next to "federal",
// "default_operations", and "pii_rules", keyed by their 2-letter codes.
// Unknown top-level keys are ignored.
func ParseYAML(data []byte) (*BundeslaenderConfig, error) {
	// First, parse as a generic map to handle the inline Bundeslaender
	var rawConfig map[string]interface{}
	if err := yaml.Unmarshal(data, &rawConfig); err != nil {
		return nil, fmt.Errorf("failed to parse YAML: %w", err)
	}
	config := &BundeslaenderConfig{
		Bundeslaender: make(map[string]PolicyConfig),
	}
	// Parse federal
	if federal, ok := rawConfig["federal"]; ok {
		if federalMap, ok := federal.(map[string]interface{}); ok {
			config.Federal = parsePolicyConfig(federalMap)
		}
	}
	// Parse default_operations
	if ops, ok := rawConfig["default_operations"]; ok {
		if opsMap, ok := ops.(map[string]interface{}); ok {
			config.DefaultOperations = parseOperationsConfig(opsMap)
		}
	}
	// Parse pii_rules
	if rules, ok := rawConfig["pii_rules"]; ok {
		if rulesSlice, ok := rules.([]interface{}); ok {
			for _, rule := range rulesSlice {
				if ruleMap, ok := rule.(map[string]interface{}); ok {
					config.PIIRules = append(config.PIIRules, parsePIIRuleConfig(ruleMap))
				}
			}
		}
	}
	// Parse Bundeslaender (2-letter codes)
	bundeslaender := []string{"BW", "BY", "BE", "BB", "HB", "HH", "HE", "MV", "NI", "NW", "RP", "SL", "SN", "ST", "SH", "TH"}
	for _, bl := range bundeslaender {
		if blConfig, ok := rawConfig[bl]; ok {
			if blMap, ok := blConfig.(map[string]interface{}); ok {
				config.Bundeslaender[bl] = parsePolicyConfig(blMap)
			}
		}
	}
	return config, nil
}
// parsePolicyConfig converts one raw YAML section (federal or a
// Bundesland) into a PolicyConfig; malformed entries are silently skipped.
func parsePolicyConfig(m map[string]interface{}) PolicyConfig {
	pc := PolicyConfig{}
	if name, ok := m["name"].(string); ok {
		pc.Name = name
	}
	if sources, ok := m["sources"].([]interface{}); ok {
		for _, src := range sources {
			if srcMap, ok := src.(map[string]interface{}); ok {
				pc.Sources = append(pc.Sources, parseSourceConfig(srcMap))
			}
		}
	}
	return pc
}
// parseSourceConfig converts a raw YAML source entry into a SourceConfig.
// trust_boost defaults to 0.5 when absent or of an unexpected type.
func parseSourceConfig(m map[string]interface{}) SourceConfig {
	sc := SourceConfig{
		TrustBoost: 0.5, // Default
	}
	if domain, ok := m["domain"].(string); ok {
		sc.Domain = domain
	}
	if name, ok := m["name"].(string); ok {
		sc.Name = name
	}
	if license, ok := m["license"].(string); ok {
		sc.License = license
	}
	if legalBasis, ok := m["legal_basis"].(string); ok {
		sc.LegalBasis = legalBasis
	}
	if citation, ok := m["citation_template"].(string); ok {
		sc.CitationTemplate = citation
	}
	// YAML decodes whole numbers (e.g. "trust_boost: 1") as int, not
	// float64; accept both so integral boosts are not silently dropped
	// back to the default.
	switch tb := m["trust_boost"].(type) {
	case float64:
		sc.TrustBoost = tb
	case int:
		sc.TrustBoost = float64(tb)
	}
	return sc
}
// parseOperationsConfig converts the "default_operations" YAML section
// into an OperationsConfig (lookup, rag, training, export); missing
// sub-sections leave the corresponding zero value.
func parseOperationsConfig(m map[string]interface{}) OperationsConfig {
	oc := OperationsConfig{}
	if lookup, ok := m["lookup"].(map[string]interface{}); ok {
		oc.Lookup = parseOperationConfig(lookup)
	}
	if rag, ok := m["rag"].(map[string]interface{}); ok {
		oc.RAG = parseOperationConfig(rag)
	}
	if training, ok := m["training"].(map[string]interface{}); ok {
		oc.Training = parseOperationConfig(training)
	}
	if export, ok := m["export"].(map[string]interface{}); ok {
		oc.Export = parseOperationConfig(export)
	}
	return oc
}
// parseOperationConfig reads the "allowed" and "requires_citation" flags
// from one raw YAML operation entry; non-bool or missing values leave the
// zero value (false).
func parseOperationConfig(m map[string]interface{}) OperationConfig {
	var oc OperationConfig
	if v, ok := m["allowed"].(bool); ok {
		oc.Allowed = v
	}
	if v, ok := m["requires_citation"].(bool); ok {
		oc.RequiresCitation = v
	}
	return oc
}
// parsePIIRuleConfig converts a raw YAML PII rule entry into a
// PIIRuleConfig. Severity defaults to "block" when absent so unlabeled
// rules fail safe.
func parsePIIRuleConfig(m map[string]interface{}) PIIRuleConfig {
	rc := PIIRuleConfig{
		Severity: "block", // Default
	}
	if name, ok := m["name"].(string); ok {
		rc.Name = name
	}
	if ruleType, ok := m["type"].(string); ok {
		rc.Type = ruleType
	}
	if pattern, ok := m["pattern"].(string); ok {
		rc.Pattern = pattern
	}
	if severity, ok := m["severity"].(string); ok {
		rc.Severity = severity
	}
	return rc
}
// LoadDefaults seeds a minimal data set (for testing or when no YAML
// exists): one federal policy containing the KMK source (§5 UrhG,
// trust boost 0.95) plus the built-in default PII rules.
func (l *Loader) LoadDefaults(ctx context.Context) error {
	// Create federal policy with KMK
	federalPolicy, err := l.store.CreatePolicy(ctx, &CreateSourcePolicyRequest{
		Name: "KMK & Bundesebene",
	})
	if err != nil {
		return fmt.Errorf("failed to create federal policy: %w", err)
	}
	trustBoost := 0.95
	legalBasis := "Amtliche Werke (§5 UrhG)"
	citation := "Quelle: KMK, {title}, {date}"
	_, err = l.store.CreateSource(ctx, &CreateAllowedSourceRequest{
		PolicyID:         federalPolicy.ID,
		Domain:           "kmk.org",
		Name:             "Kultusministerkonferenz",
		License:          LicenseParagraph5,
		LegalBasis:       &legalBasis,
		CitationTemplate: &citation,
		TrustBoost:       &trustBoost,
	})
	if err != nil {
		return fmt.Errorf("failed to create KMK source: %w", err)
	}
	// Create default PII rules; abort on the first failure so a partial
	// rule set is surfaced to the caller.
	defaultRules := DefaultPIIRules()
	for _, rule := range defaultRules {
		_, err := l.store.CreatePIIRule(ctx, &CreatePIIRuleRequest{
			Name:     rule.Name,
			RuleType: PIIRuleType(rule.Type),
			Pattern:  rule.Pattern,
			Severity: PIISeverity(rule.Severity),
		})
		if err != nil {
			return fmt.Errorf("failed to create PII rule %s: %w", rule.Name, err)
		}
	}
	return nil
}
// HasData reports whether the policy tables already contain data, by
// probing the store for at least one stored policy.
func (l *Loader) HasData(ctx context.Context) (bool, error) {
	policies, _, err := l.store.ListPolicies(ctx, &PolicyListFilter{Limit: 1})
	if err != nil {
		return false, err
	}
	return len(policies) > 0, nil
}
// LoadIfEmpty populates the policy tables only when they are empty:
// it loads from the YAML file at path when it exists, and falls back to
// the built-in defaults when it does not. Existing data is never touched.
func (l *Loader) LoadIfEmpty(ctx context.Context, path string) error {
	hasData, err := l.HasData(ctx)
	if err != nil {
		return err
	}
	if hasData {
		return nil // Already has data, skip loading
	}
	// Check if file exists
	if _, err := os.Stat(path); os.IsNotExist(err) {
		// File doesn't exist, load defaults
		return l.LoadDefaults(ctx)
	}
	return l.LoadFromFile(ctx, path)
}

View File

@@ -0,0 +1,445 @@
// Package policy provides whitelist-based data source management for the edu-search-service.
// It implements source policies, operation permissions, PII detection, and audit logging
// for compliance with German data protection regulations.
package policy
import (
"encoding/json"
"time"
"github.com/google/uuid"
)
// =============================================================================
// ENUMS AND CONSTANTS
// =============================================================================
// Bundesland represents German federal states (2-letter official codes).
type Bundesland string

const (
	BundeslandBW Bundesland = "BW" // Baden-Wuerttemberg
	BundeslandBY Bundesland = "BY" // Bayern
	BundeslandBE Bundesland = "BE" // Berlin
	BundeslandBB Bundesland = "BB" // Brandenburg
	BundeslandHB Bundesland = "HB" // Bremen
	BundeslandHH Bundesland = "HH" // Hamburg
	BundeslandHE Bundesland = "HE" // Hessen
	BundeslandMV Bundesland = "MV" // Mecklenburg-Vorpommern
	BundeslandNI Bundesland = "NI" // Niedersachsen
	BundeslandNW Bundesland = "NW" // Nordrhein-Westfalen
	BundeslandRP Bundesland = "RP" // Rheinland-Pfalz
	BundeslandSL Bundesland = "SL" // Saarland
	BundeslandSN Bundesland = "SN" // Sachsen
	BundeslandST Bundesland = "ST" // Sachsen-Anhalt
	BundeslandSH Bundesland = "SH" // Schleswig-Holstein
	BundeslandTH Bundesland = "TH" // Thueringen
)

// ValidBundeslaender contains all sixteen valid German federal state codes.
var ValidBundeslaender = []Bundesland{
	BundeslandBW, BundeslandBY, BundeslandBE, BundeslandBB,
	BundeslandHB, BundeslandHH, BundeslandHE, BundeslandMV,
	BundeslandNI, BundeslandNW, BundeslandRP, BundeslandSL,
	BundeslandSN, BundeslandST, BundeslandSH, BundeslandTH,
}

// License represents allowed license types for data sources.
type License string

const (
	LicenseDLDEBY20  License = "DL-DE-BY-2.0" // Datenlizenz Deutschland - Namensnennung
	LicenseCCBY      License = "CC-BY"        // Creative Commons Attribution
	LicenseCCBYSA    License = "CC-BY-SA"     // Creative Commons Attribution-ShareAlike
	LicenseCCBYNC    License = "CC-BY-NC"     // Creative Commons Attribution-NonCommercial
	LicenseCCBYNCSA  License = "CC-BY-NC-SA"  // Creative Commons Attribution-NonCommercial-ShareAlike
	LicenseCC0       License = "CC0"          // Public Domain
	LicenseParagraph5 License = "§5 UrhG"     // Amtliche Werke (German Copyright Act)
	LicenseCustom    License = "Custom"       // Custom license (requires legal basis)
)

// Operation represents the types of operations that can be performed on data.
type Operation string

const (
	OperationLookup   Operation = "lookup"   // Display/Search
	OperationRAG      Operation = "rag"      // RAG (Retrieval-Augmented Generation)
	OperationTraining Operation = "training" // Model Training (VERBOTEN by default)
	OperationExport   Operation = "export"   // Data Export
)

// ValidOperations contains all valid operation types.
var ValidOperations = []Operation{
	OperationLookup,
	OperationRAG,
	OperationTraining,
	OperationExport,
}

// PIIRuleType represents the type of PII detection rule.
type PIIRuleType string

const (
	PIIRuleTypeRegex   PIIRuleType = "regex"   // Regular expression pattern
	PIIRuleTypeKeyword PIIRuleType = "keyword" // Keyword matching
)

// PIISeverity represents the severity level of a PII match.
type PIISeverity string

const (
	PIISeverityBlock  PIISeverity = "block"  // Block content completely
	PIISeverityWarn   PIISeverity = "warn"   // Warn but allow
	PIISeverityRedact PIISeverity = "redact" // Redact matched content
)

// AuditAction represents the type of action logged in the audit trail.
type AuditAction string

const (
	AuditActionCreate     AuditAction = "create"
	AuditActionUpdate     AuditAction = "update"
	AuditActionDelete     AuditAction = "delete"
	AuditActionActivate   AuditAction = "activate"
	AuditActionDeactivate AuditAction = "deactivate"
	AuditActionApprove    AuditAction = "approve"
)

// AuditEntityType represents the type of entity being audited.
type AuditEntityType string

const (
	AuditEntitySourcePolicy        AuditEntityType = "source_policy"
	AuditEntityAllowedSource       AuditEntityType = "allowed_source"
	AuditEntityOperationPermission AuditEntityType = "operation_permission"
	AuditEntityPIIRule             AuditEntityType = "pii_rule"
)

// BlockReason represents the reason why content was blocked.
type BlockReason string

const (
	BlockReasonNotWhitelisted    BlockReason = "not_whitelisted"
	BlockReasonPIIDetected       BlockReason = "pii_detected"
	BlockReasonTrainingForbidden BlockReason = "training_forbidden"
	BlockReasonLicenseViolation  BlockReason = "license_violation"
	BlockReasonManualBlock       BlockReason = "manual_block"
)
// =============================================================================
// CORE MODELS
// =============================================================================

// SourcePolicy represents a versioned policy for data source management.
// Policies can be scoped to a specific Bundesland or apply federally (bundesland = nil).
type SourcePolicy struct {
	ID          uuid.UUID   `json:"id" db:"id"`
	Version     int         `json:"version" db:"version"`
	Name        string      `json:"name" db:"name"`
	Description *string     `json:"description,omitempty" db:"description"`
	Bundesland  *Bundesland `json:"bundesland,omitempty" db:"bundesland"` // nil = federal scope
	IsActive    bool        `json:"is_active" db:"is_active"`
	CreatedAt   time.Time   `json:"created_at" db:"created_at"`
	UpdatedAt   time.Time   `json:"updated_at" db:"updated_at"`
	ApprovedBy  *uuid.UUID  `json:"approved_by,omitempty" db:"approved_by"`
	ApprovedAt  *time.Time  `json:"approved_at,omitempty" db:"approved_at"`
	// Joined fields (populated by queries)
	Sources []AllowedSource `json:"sources,omitempty"`
}

// AllowedSource represents a whitelisted data source with license information.
type AllowedSource struct {
	ID               uuid.UUID `json:"id" db:"id"`
	PolicyID         uuid.UUID `json:"policy_id" db:"policy_id"`
	Domain           string    `json:"domain" db:"domain"`
	Name             string    `json:"name" db:"name"`
	Description      *string   `json:"description,omitempty" db:"description"`
	License          License   `json:"license" db:"license"`
	LegalBasis       *string   `json:"legal_basis,omitempty" db:"legal_basis"` // required when License is LicenseCustom
	CitationTemplate *string   `json:"citation_template,omitempty" db:"citation_template"`
	TrustBoost       float64   `json:"trust_boost" db:"trust_boost"`
	IsActive         bool      `json:"is_active" db:"is_active"`
	CreatedAt        time.Time `json:"created_at" db:"created_at"`
	UpdatedAt        time.Time `json:"updated_at" db:"updated_at"`
	// Joined fields (populated by queries)
	Operations []OperationPermission `json:"operations,omitempty"`
	PolicyName *string               `json:"policy_name,omitempty"`
}

// OperationPermission represents the permission matrix for a specific source.
type OperationPermission struct {
	ID               uuid.UUID `json:"id" db:"id"`
	SourceID         uuid.UUID `json:"source_id" db:"source_id"`
	Operation        Operation `json:"operation" db:"operation"`
	IsAllowed        bool      `json:"is_allowed" db:"is_allowed"`
	RequiresCitation bool      `json:"requires_citation" db:"requires_citation"`
	Notes            *string   `json:"notes,omitempty" db:"notes"`
	CreatedAt        time.Time `json:"created_at" db:"created_at"`
	UpdatedAt        time.Time `json:"updated_at" db:"updated_at"`
}

// PIIRule represents a rule for detecting personally identifiable information.
type PIIRule struct {
	ID          uuid.UUID   `json:"id" db:"id"`
	Name        string      `json:"name" db:"name"`
	Description *string     `json:"description,omitempty" db:"description"`
	RuleType    PIIRuleType `json:"rule_type" db:"rule_type"`
	Pattern     string      `json:"pattern" db:"pattern"` // regex source, or comma/pipe keyword list
	Severity    PIISeverity `json:"severity" db:"severity"`
	IsActive    bool        `json:"is_active" db:"is_active"`
	CreatedAt   time.Time   `json:"created_at" db:"created_at"`
	UpdatedAt   time.Time   `json:"updated_at" db:"updated_at"`
}
// =============================================================================
// AUDIT AND LOGGING MODELS
// =============================================================================

// PolicyAuditLog represents an immutable audit log entry for policy changes.
type PolicyAuditLog struct {
	ID         uuid.UUID       `json:"id" db:"id"`
	Action     AuditAction     `json:"action" db:"action"`
	EntityType AuditEntityType `json:"entity_type" db:"entity_type"`
	EntityID   *uuid.UUID      `json:"entity_id,omitempty" db:"entity_id"`
	OldValue   json.RawMessage `json:"old_value,omitempty" db:"old_value"` // snapshot before the change
	NewValue   json.RawMessage `json:"new_value,omitempty" db:"new_value"` // snapshot after the change
	UserID     *uuid.UUID      `json:"user_id,omitempty" db:"user_id"`
	UserEmail  *string         `json:"user_email,omitempty" db:"user_email"`
	IPAddress  *string         `json:"ip_address,omitempty" db:"ip_address"`
	UserAgent  *string         `json:"user_agent,omitempty" db:"user_agent"`
	CreatedAt  time.Time       `json:"created_at" db:"created_at"`
}

// BlockedContentLog represents a log entry for blocked URLs.
type BlockedContentLog struct {
	ID            uuid.UUID       `json:"id" db:"id"`
	URL           string          `json:"url" db:"url"`
	Domain        string          `json:"domain" db:"domain"`
	BlockReason   BlockReason     `json:"block_reason" db:"block_reason"`
	MatchedRuleID *uuid.UUID      `json:"matched_rule_id,omitempty" db:"matched_rule_id"`
	Details       json.RawMessage `json:"details,omitempty" db:"details"`
	CreatedAt     time.Time       `json:"created_at" db:"created_at"`
}
// =============================================================================
// REQUEST/RESPONSE MODELS
// =============================================================================

// CreateSourcePolicyRequest represents a request to create a new policy.
type CreateSourcePolicyRequest struct {
	Name        string      `json:"name" binding:"required"`
	Description *string     `json:"description"`
	Bundesland  *Bundesland `json:"bundesland"` // nil = federal policy
}

// UpdateSourcePolicyRequest represents a request to update a policy.
// All fields are pointers: nil means "leave unchanged".
type UpdateSourcePolicyRequest struct {
	Name        *string     `json:"name"`
	Description *string     `json:"description"`
	Bundesland  *Bundesland `json:"bundesland"`
	IsActive    *bool       `json:"is_active"`
}

// CreateAllowedSourceRequest represents a request to create a new allowed source.
type CreateAllowedSourceRequest struct {
	PolicyID         uuid.UUID `json:"policy_id" binding:"required"`
	Domain           string    `json:"domain" binding:"required"`
	Name             string    `json:"name" binding:"required"`
	Description      *string   `json:"description"`
	License          License   `json:"license" binding:"required"`
	LegalBasis       *string   `json:"legal_basis"`
	CitationTemplate *string   `json:"citation_template"`
	TrustBoost       *float64  `json:"trust_boost"` // nil = use server default
}

// UpdateAllowedSourceRequest represents a request to update an allowed source.
// All fields are pointers: nil means "leave unchanged".
type UpdateAllowedSourceRequest struct {
	Domain           *string  `json:"domain"`
	Name             *string  `json:"name"`
	Description      *string  `json:"description"`
	License          *License `json:"license"`
	LegalBasis       *string  `json:"legal_basis"`
	CitationTemplate *string  `json:"citation_template"`
	TrustBoost       *float64 `json:"trust_boost"`
	IsActive         *bool    `json:"is_active"`
}

// UpdateOperationPermissionRequest represents a request to update operation permissions.
type UpdateOperationPermissionRequest struct {
	IsAllowed        *bool   `json:"is_allowed"`
	RequiresCitation *bool   `json:"requires_citation"`
	Notes            *string `json:"notes"`
}

// CreatePIIRuleRequest represents a request to create a new PII rule.
type CreatePIIRuleRequest struct {
	Name        string      `json:"name" binding:"required"`
	Description *string     `json:"description"`
	RuleType    PIIRuleType `json:"rule_type" binding:"required"`
	Pattern     string      `json:"pattern" binding:"required"`
	Severity    PIISeverity `json:"severity"`
}

// UpdatePIIRuleRequest represents a request to update a PII rule.
// All fields are pointers: nil means "leave unchanged".
type UpdatePIIRuleRequest struct {
	Name        *string      `json:"name"`
	Description *string      `json:"description"`
	RuleType    *PIIRuleType `json:"rule_type"`
	Pattern     *string      `json:"pattern"`
	Severity    *PIISeverity `json:"severity"`
	IsActive    *bool        `json:"is_active"`
}

// CheckComplianceRequest represents a request to check URL compliance.
type CheckComplianceRequest struct {
	URL        string      `json:"url" binding:"required"`
	Operation  Operation   `json:"operation" binding:"required"`
	Bundesland *Bundesland `json:"bundesland"`
}

// CheckComplianceResponse represents the compliance check result.
type CheckComplianceResponse struct {
	IsAllowed        bool           `json:"is_allowed"`
	Source           *AllowedSource `json:"source,omitempty"` // only set when the URL matched a whitelisted source
	BlockReason      *BlockReason   `json:"block_reason,omitempty"`
	RequiresCitation bool           `json:"requires_citation"`
	CitationTemplate *string        `json:"citation_template,omitempty"`
	License          *License       `json:"license,omitempty"`
}

// PIITestRequest represents a request to test PII detection.
type PIITestRequest struct {
	Text string `json:"text" binding:"required"`
}

// PIIMatch represents a single PII match in text.
// StartIndex/EndIndex are byte offsets into the scanned text.
type PIIMatch struct {
	RuleID     uuid.UUID   `json:"rule_id"`
	RuleName   string      `json:"rule_name"`
	RuleType   PIIRuleType `json:"rule_type"`
	Severity   PIISeverity `json:"severity"`
	Match      string      `json:"match"`
	StartIndex int         `json:"start_index"`
	EndIndex   int         `json:"end_index"`
}

// PIITestResponse represents the result of PII detection test.
type PIITestResponse struct {
	HasPII      bool        `json:"has_pii"`
	Matches     []PIIMatch  `json:"matches"`
	BlockLevel  PIISeverity `json:"block_level"` // highest severity among all matches
	ShouldBlock bool        `json:"should_block"`
}
// =============================================================================
// LIST/FILTER MODELS
// =============================================================================
// All filter fields are pointers (nil = no filtering on that field); they are
// bound from query-string parameters via the `form` tags.

// PolicyListFilter represents filters for listing policies.
type PolicyListFilter struct {
	Bundesland *Bundesland `form:"bundesland"`
	IsActive   *bool       `form:"is_active"`
	Limit      int         `form:"limit"`
	Offset     int         `form:"offset"`
}

// SourceListFilter represents filters for listing sources.
type SourceListFilter struct {
	PolicyID *uuid.UUID `form:"policy_id"`
	Domain   *string    `form:"domain"`
	License  *License   `form:"license"`
	IsActive *bool      `form:"is_active"`
	Limit    int        `form:"limit"`
	Offset   int        `form:"offset"`
}

// AuditLogFilter represents filters for querying audit logs.
type AuditLogFilter struct {
	EntityType *AuditEntityType `form:"entity_type"`
	EntityID   *uuid.UUID       `form:"entity_id"`
	Action     *AuditAction     `form:"action"`
	UserEmail  *string          `form:"user_email"`
	FromDate   *time.Time       `form:"from"`
	ToDate     *time.Time       `form:"to"`
	Limit      int              `form:"limit"`
	Offset     int              `form:"offset"`
}

// BlockedContentFilter represents filters for querying blocked content logs.
type BlockedContentFilter struct {
	Domain      *string      `form:"domain"`
	BlockReason *BlockReason `form:"block_reason"`
	FromDate    *time.Time   `form:"from"`
	ToDate      *time.Time   `form:"to"`
	Limit       int          `form:"limit"`
	Offset      int          `form:"offset"`
}
// =============================================================================
// STATISTICS MODELS
// =============================================================================

// PolicyStats represents aggregated statistics for the policy system.
type PolicyStats struct {
	ActivePolicies   int            `json:"active_policies"`
	TotalSources     int            `json:"total_sources"`
	ActiveSources    int            `json:"active_sources"`
	BlockedToday     int            `json:"blocked_today"`
	BlockedTotal     int            `json:"blocked_total"`
	PIIRulesActive   int            `json:"pii_rules_active"`
	SourcesByLicense map[string]int `json:"sources_by_license"` // license string -> source count
	BlocksByReason   map[string]int `json:"blocks_by_reason"`   // BlockReason string -> block count
	ComplianceScore  float64        `json:"compliance_score"`
}
// =============================================================================
// YAML CONFIGURATION MODELS
// =============================================================================

// BundeslaenderConfig represents the YAML configuration for initial data loading.
// The inline map captures every top-level key that is not matched by another
// tagged field, i.e. the per-Bundesland sections keyed by their abbreviation.
type BundeslaenderConfig struct {
	Federal           PolicyConfig            `yaml:"federal"`
	Bundeslaender     map[string]PolicyConfig `yaml:",inline"`
	DefaultOperations OperationsConfig        `yaml:"default_operations"`
	PIIRules          []PIIRuleConfig         `yaml:"pii_rules"`
}

// PolicyConfig represents a policy configuration in YAML.
type PolicyConfig struct {
	Name    string         `yaml:"name"`
	Sources []SourceConfig `yaml:"sources"`
}

// SourceConfig represents a source configuration in YAML.
type SourceConfig struct {
	Domain           string  `yaml:"domain"`
	Name             string  `yaml:"name"`
	License          string  `yaml:"license"`
	LegalBasis       string  `yaml:"legal_basis,omitempty"`
	CitationTemplate string  `yaml:"citation_template,omitempty"`
	TrustBoost       float64 `yaml:"trust_boost,omitempty"`
}

// OperationsConfig represents default operation permissions in YAML.
type OperationsConfig struct {
	Lookup   OperationConfig `yaml:"lookup"`
	RAG      OperationConfig `yaml:"rag"`
	Training OperationConfig `yaml:"training"`
	Export   OperationConfig `yaml:"export"`
}

// OperationConfig represents a single operation permission in YAML.
type OperationConfig struct {
	Allowed          bool `yaml:"allowed"`
	RequiresCitation bool `yaml:"requires_citation"`
}

// PIIRuleConfig represents a PII rule configuration in YAML.
type PIIRuleConfig struct {
	Name     string `yaml:"name"`
	Type     string `yaml:"type"`     // "regex" or "keyword"
	Pattern  string `yaml:"pattern"`
	Severity string `yaml:"severity"` // "block", "warn" or "redact"
}

View File

@@ -0,0 +1,350 @@
package policy
import (
	"context"
	"regexp"
	"sort"
	"strings"
	"sync"
)
// PIIDetector detects personally identifiable information in text.
// It loads active rules from the store on each Detect call and caches
// compiled regular expressions keyed by rule ID.
type PIIDetector struct {
	store         *Store
	compiledRules map[string]*regexp.Regexp // rule ID -> compiled pattern
	rulesMu       sync.RWMutex              // guards compiledRules
}
// NewPIIDetector creates a new PIIDetector backed by the given store,
// with an empty regex cache.
func NewPIIDetector(store *Store) *PIIDetector {
	d := &PIIDetector{store: store}
	d.compiledRules = make(map[string]*regexp.Regexp)
	return d
}
// Detect scans text against all active PII rules and returns every match,
// together with the highest severity seen and whether the text should be blocked.
func (d *PIIDetector) Detect(ctx context.Context, text string) (*PIITestResponse, error) {
	rules, err := d.store.ListPIIRules(ctx, true)
	if err != nil {
		return nil, err
	}
	result := &PIITestResponse{Matches: []PIIMatch{}}
	var worst PIISeverity
	for i := range rules {
		found := d.findMatches(text, &rules[i])
		if len(found) == 0 {
			continue
		}
		result.HasPII = true
		result.Matches = append(result.Matches, found...)
		// Keep the most severe level seen across all matching rules.
		if compareSeverity(rules[i].Severity, worst) > 0 {
			worst = rules[i].Severity
		}
	}
	result.BlockLevel = worst
	result.ShouldBlock = worst == PIISeverityBlock
	return result, nil
}
// findMatches dispatches to the matcher for the rule's type.
// Unknown rule types yield no matches.
func (d *PIIDetector) findMatches(text string, rule *PIIRule) []PIIMatch {
	if rule.RuleType == PIIRuleTypeRegex {
		return d.findRegexMatches(text, rule)
	}
	if rule.RuleType == PIIRuleTypeKeyword {
		return d.findKeywordMatches(text, rule)
	}
	return nil
}
// findRegexMatches returns one PIIMatch per non-overlapping regex hit in text.
// An uncompilable pattern produces no matches.
func (d *PIIDetector) findRegexMatches(text string, rule *PIIRule) []PIIMatch {
	re := d.getCompiledRegex(rule.ID.String(), rule.Pattern)
	if re == nil {
		return nil
	}
	var out []PIIMatch
	for _, span := range re.FindAllStringIndex(text, -1) {
		start, end := span[0], span[1]
		out = append(out, PIIMatch{
			RuleID:     rule.ID,
			RuleName:   rule.Name,
			RuleType:   rule.RuleType,
			Severity:   rule.Severity,
			Match:      text[start:end],
			StartIndex: start,
			EndIndex:   end,
		})
	}
	return out
}
// findKeywordMatches finds all keyword matches in text (case-insensitive).
// The rule pattern is a list of keywords separated by commas or pipes.
//
// Indices are computed on the lowercased text. For ASCII input the byte
// offsets are identical to the original text; for rare Unicode case
// foldings that change byte length (e.g. U+0130), the end offset is
// clamped so slicing the original text can never panic.
func (d *PIIDetector) findKeywordMatches(text string, rule *PIIRule) []PIIMatch {
	var matches []PIIMatch
	lowerText := strings.ToLower(text)
	// Split pattern by commas or pipes for multiple keywords.
	keywords := strings.FieldsFunc(rule.Pattern, func(r rune) bool {
		return r == ',' || r == '|'
	})
	for _, keyword := range keywords {
		keyword = strings.TrimSpace(keyword)
		if keyword == "" {
			continue
		}
		lowerKeyword := strings.ToLower(keyword)
		startIdx := 0
		for {
			idx := strings.Index(lowerText[startIdx:], lowerKeyword)
			if idx == -1 {
				break
			}
			actualIdx := startIdx + idx
			// Advance and slice using the lowered keyword's length, since the
			// search happened in the lowered text; clamp against the original
			// text to stay in bounds when case folding changed byte lengths.
			end := actualIdx + len(lowerKeyword)
			if end > len(text) {
				end = len(text)
			}
			if actualIdx < len(text) {
				matches = append(matches, PIIMatch{
					RuleID:     rule.ID,
					RuleName:   rule.Name,
					RuleType:   rule.RuleType,
					Severity:   rule.Severity,
					Match:      text[actualIdx:end],
					StartIndex: actualIdx,
					EndIndex:   end,
				})
			}
			startIdx = actualIdx + len(lowerKeyword)
		}
	}
	return matches
}
// getCompiledRegex returns the compiled regex for a rule, compiling and
// caching it on first use. Returns nil for invalid patterns (not cached,
// so a later pattern fix takes effect).
func (d *PIIDetector) getCompiledRegex(ruleID, pattern string) *regexp.Regexp {
	// Fast path: read-locked cache lookup.
	d.rulesMu.RLock()
	cached, hit := d.compiledRules[ruleID]
	d.rulesMu.RUnlock()
	if hit {
		return cached
	}

	d.rulesMu.Lock()
	defer d.rulesMu.Unlock()
	// Another goroutine may have compiled it between the two locks.
	if cached, hit = d.compiledRules[ruleID]; hit {
		return cached
	}
	re, err := regexp.Compile(pattern)
	if err != nil {
		return nil
	}
	d.compiledRules[ruleID] = re
	return re
}
// ClearCache drops all cached compiled regexes (call after rule updates).
func (d *PIIDetector) ClearCache() {
	d.rulesMu.Lock()
	defer d.rulesMu.Unlock()
	d.compiledRules = map[string]*regexp.Regexp{}
}
// RefreshRules reloads rules and clears the cache.
// Rules themselves are fetched from the store on every Detect call, so
// refreshing only needs to discard stale compiled regexes.
func (d *PIIDetector) RefreshRules() {
	d.ClearCache()
}
// compareSeverity compares two severity levels using the ordering
// "" < warn < redact < block.
// Returns: 1 if a > b, -1 if a < b, 0 if equal.
func compareSeverity(a, b PIISeverity) int {
	rank := func(s PIISeverity) int {
		switch s {
		case PIISeverityWarn:
			return 1
		case PIISeverityRedact:
			return 2
		case PIISeverityBlock:
			return 3
		default:
			return 0
		}
	}
	switch ra, rb := rank(a), rank(b); {
	case ra > rb:
		return 1
	case ra < rb:
		return -1
	default:
		return 0
	}
}
// =============================================================================
// PREDEFINED PII PATTERNS (German Context)
// =============================================================================

// DefaultPIIRules returns a set of default PII detection rules for German context.
// These are regex-only rules; severities follow the PIISeverity levels
// ("block" stops the content, "warn" allows but flags it).
func DefaultPIIRules() []PIIRuleConfig {
	return []PIIRuleConfig{
		// Email Addresses
		{
			Name:     "Email Addresses",
			Type:     "regex",
			Pattern:  `[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}`,
			Severity: "block",
		},
		// German Phone Numbers
		{
			Name:     "German Phone Numbers",
			Type:     "regex",
			Pattern:  `(?:\+49|0)[\s.-]?\d{2,4}[\s.-]?\d{3,}[\s.-]?\d{2,}`,
			Severity: "block",
		},
		// German Mobile Numbers (015x/016x/017x prefixes)
		{
			Name:     "German Mobile Numbers",
			Type:     "regex",
			Pattern:  `(?:\+49|0)1[567]\d[\s.-]?\d{3,}[\s.-]?\d{2,}`,
			Severity: "block",
		},
		// IBAN (German): DE + 2 check digits + 18 digits
		{
			Name:     "German IBAN",
			Type:     "regex",
			Pattern:  `DE\d{2}\s?\d{4}\s?\d{4}\s?\d{4}\s?\d{4}\s?\d{2}`,
			Severity: "block",
		},
		// German Social Security Number (Sozialversicherungsnummer)
		{
			Name:     "German Social Security Number",
			Type:     "regex",
			Pattern:  `\d{2}[0-3]\d[01]\d{2}[A-Z]\d{3}`,
			Severity: "block",
		},
		// German Tax ID (Steuer-ID, 11 digits)
		// NOTE(review): this also matches any 11-digit run; expect false positives.
		{
			Name:     "German Tax ID",
			Type:     "regex",
			Pattern:  `\d{2}\s?\d{3}\s?\d{3}\s?\d{3}`,
			Severity: "block",
		},
		// Credit Card Numbers (16 digits in groups of 4; no Luhn validation here)
		{
			Name:     "Credit Card Numbers",
			Type:     "regex",
			Pattern:  `(?:\d{4}[\s.-]?){3}\d{4}`,
			Severity: "block",
		},
		// German Postal Code + City Pattern (potential address)
		{
			Name:     "German Address Pattern",
			Type:     "regex",
			Pattern:  `\d{5}\s+[A-ZÄÖÜ][a-zäöüß]+`,
			Severity: "warn",
		},
		// Date of Birth Patterns (DD.MM.YYYY, only when preceded by a DoB keyword)
		{
			Name:     "Date of Birth",
			Type:     "regex",
			Pattern:  `(?:geboren|geb\.|Geburtsdatum|DoB)[\s:]*\d{1,2}[\./]\d{1,2}[\./]\d{2,4}`,
			Severity: "warn",
		},
		// Personal Names with Titles (Herr/Frau/Dr./Prof. + first + last name)
		{
			Name:     "Personal Names with Titles",
			Type:     "regex",
			Pattern:  `(?:Herr|Frau|Dr\.|Prof\.)\s+[A-ZÄÖÜ][a-zäöüß]+\s+[A-ZÄÖÜ][a-zäöüß]+`,
			Severity: "warn",
		},
		// German Health Insurance Number (letter + 9 digits)
		// NOTE(review): very broad pattern; matches many letter+digit codes.
		{
			Name:     "Health Insurance Number",
			Type:     "regex",
			Pattern:  `[A-Z]\d{9}`,
			Severity: "block",
		},
		// Vehicle Registration (German), optional H/E suffix for historic/electric
		{
			Name:     "German Vehicle Registration",
			Type:     "regex",
			Pattern:  `[A-ZÄÖÜ]{1,3}[\s-]?[A-Z]{1,2}[\s-]?\d{1,4}[HE]?`,
			Severity: "warn",
		},
	}
}
// =============================================================================
// REDACTION
// =============================================================================

// RedactText replaces block/redact-severity matches in text with asterisks.
// Warn-severity matches are left untouched. The input matches may arrive in
// any order; they are processed from the end of the text backwards so earlier
// indices stay valid as the string is rewritten.
//
// Uses sort.Slice instead of the previous hand-rolled bubble sort, and
// validates match bounds so a malformed match cannot panic.
func (d *PIIDetector) RedactText(text string, matches []PIIMatch) string {
	if len(matches) == 0 {
		return text
	}
	// Work on a copy so the caller's slice order is not disturbed.
	sorted := make([]PIIMatch, len(matches))
	copy(sorted, matches)
	// Descending by start index: replace from the end of the string first.
	sort.Slice(sorted, func(i, j int) bool {
		return sorted[i].StartIndex > sorted[j].StartIndex
	})
	result := text
	for _, m := range sorted {
		if m.Severity != PIISeverityRedact && m.Severity != PIISeverityBlock {
			continue
		}
		// Skip matches whose indices don't describe a valid span of result.
		if m.StartIndex < 0 || m.EndIndex > len(result) || m.StartIndex >= m.EndIndex {
			continue
		}
		replacement := strings.Repeat("*", m.EndIndex-m.StartIndex)
		result = result[:m.StartIndex] + replacement + result[m.EndIndex:]
	}
	return result
}
// FilterContent filters content based on PII detection.
// Returns the (possibly redacted) content and whether it should be blocked.
func (d *PIIDetector) FilterContent(ctx context.Context, content string) (string, bool, error) {
	detection, err := d.Detect(ctx, content)
	switch {
	case err != nil:
		return content, false, err
	case !detection.HasPII:
		return content, false, nil
	case detection.ShouldBlock:
		// Block-severity PII: suppress the content entirely.
		return "", true, nil
	}
	// Warn/redact severity: mask the matched spans and pass the rest through.
	return d.RedactText(content, detection.Matches), false, nil
}

View File

@@ -0,0 +1,489 @@
package policy
import (
"regexp"
"testing"
)
// =============================================================================
// MODEL TESTS
// =============================================================================

// TestBundeslandValidation checks that known Bundesland codes appear in
// ValidBundeslaender.
func TestBundeslandValidation(t *testing.T) {
	tests := []struct {
		name     string
		bl       Bundesland
		expected bool
	}{
		{"valid NI", BundeslandNI, true},
		{"valid BY", BundeslandBY, true},
		{"valid BW", BundeslandBW, true},
		{"valid NW", BundeslandNW, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			found := false
			for _, valid := range ValidBundeslaender {
				if valid == tt.bl {
					found = true
					break
				}
			}
			if found != tt.expected {
				t.Errorf("Expected %v to be valid=%v, got valid=%v", tt.bl, tt.expected, found)
			}
		})
	}
}

// TestLicenseValues ensures the license constants are non-empty strings.
func TestLicenseValues(t *testing.T) {
	licenses := []License{
		LicenseDLDEBY20,
		LicenseCCBY,
		LicenseCCBYSA,
		LicenseCC0,
		LicenseParagraph5,
	}
	for _, l := range licenses {
		if l == "" {
			t.Errorf("License should not be empty")
		}
	}
}

// TestOperationValues verifies ValidOperations contains exactly the four
// known operations.
func TestOperationValues(t *testing.T) {
	if len(ValidOperations) != 4 {
		t.Errorf("Expected 4 operations, got %d", len(ValidOperations))
	}
	expectedOps := []Operation{OperationLookup, OperationRAG, OperationTraining, OperationExport}
	for _, expected := range expectedOps {
		found := false
		for _, op := range ValidOperations {
			if op == expected {
				found = true
				break
			}
		}
		if !found {
			t.Errorf("Expected operation %s not found in ValidOperations", expected)
		}
	}
}
// =============================================================================
// PII DETECTOR TESTS
// =============================================================================
// These tests construct PIIDetector and PIIRule values directly (no store/DB)
// and exercise the matchers with the default regex patterns inlined.

// TestPIIDetector_EmailDetection checks the email regex against positive and
// negative samples.
func TestPIIDetector_EmailDetection(t *testing.T) {
	tests := []struct {
		name     string
		text     string
		hasEmail bool
	}{
		{"simple email", "Contact: test@example.com", true},
		{"email with plus", "Email: user+tag@domain.org", true},
		{"no email", "This is plain text", false},
		{"partial email", "user@ is not an email", false},
		{"multiple emails", "Send to a@b.com and x@y.de", true},
	}
	// Test using regex pattern directly since we don't have a store
	emailPattern := `[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}`
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Simple test without database
			rule := &PIIRule{
				Name:     "Email",
				RuleType: PIIRuleTypeRegex,
				Pattern:  emailPattern,
				Severity: PIISeverityBlock,
			}
			detector := &PIIDetector{
				compiledRules: make(map[string]*regexp.Regexp),
			}
			matches := detector.findMatches(tt.text, rule)
			hasMatch := len(matches) > 0
			if hasMatch != tt.hasEmail {
				t.Errorf("Expected hasEmail=%v, got %v for text: %s", tt.hasEmail, hasMatch, tt.text)
			}
		})
	}
}

// TestPIIDetector_PhoneDetection checks the German phone regex; US-style
// numbers must not match.
func TestPIIDetector_PhoneDetection(t *testing.T) {
	tests := []struct {
		name     string
		text     string
		hasPhone bool
	}{
		{"german mobile", "Call +49 170 1234567", true},
		{"german landline", "Tel: 030-12345678", true},
		{"with spaces", "Phone: 0170 123 4567", true},
		{"no phone", "This is just text", false},
		{"US format", "Call 555-123-4567", false}, // Should not match German pattern
	}
	phonePattern := `(?:\+49|0)[\s.-]?\d{2,4}[\s.-]?\d{3,}[\s.-]?\d{2,}`
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			rule := &PIIRule{
				Name:     "Phone",
				RuleType: PIIRuleTypeRegex,
				Pattern:  phonePattern,
				Severity: PIISeverityBlock,
			}
			detector := &PIIDetector{
				compiledRules: make(map[string]*regexp.Regexp),
			}
			matches := detector.findMatches(tt.text, rule)
			hasMatch := len(matches) > 0
			if hasMatch != tt.hasPhone {
				t.Errorf("Expected hasPhone=%v, got %v for text: %s", tt.hasPhone, hasMatch, tt.text)
			}
		})
	}
}

// TestPIIDetector_IBANDetection checks the German IBAN regex in spaced and
// compact form.
func TestPIIDetector_IBANDetection(t *testing.T) {
	tests := []struct {
		name    string
		text    string
		hasIBAN bool
	}{
		{"valid IBAN", "IBAN: DE89 3704 0044 0532 0130 00", true},
		{"compact IBAN", "DE89370400440532013000", true},
		{"no IBAN", "Just a number: 12345678", false},
		{"partial", "DE12 is not complete", false},
	}
	ibanPattern := `DE\d{2}\s?\d{4}\s?\d{4}\s?\d{4}\s?\d{4}\s?\d{2}`
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			rule := &PIIRule{
				Name:     "IBAN",
				RuleType: PIIRuleTypeRegex,
				Pattern:  ibanPattern,
				Severity: PIISeverityBlock,
			}
			detector := &PIIDetector{
				compiledRules: make(map[string]*regexp.Regexp),
			}
			matches := detector.findMatches(tt.text, rule)
			hasMatch := len(matches) > 0
			if hasMatch != tt.hasIBAN {
				t.Errorf("Expected hasIBAN=%v, got %v for text: %s", tt.hasIBAN, hasMatch, tt.text)
			}
		})
	}
}

// TestPIIDetector_KeywordMatching checks comma-separated keyword rules,
// including case-insensitive matching and match counts.
func TestPIIDetector_KeywordMatching(t *testing.T) {
	tests := []struct {
		name     string
		text     string
		keywords string
		expected int
	}{
		{"single keyword", "The password is secret", "password", 1},
		{"multiple keywords", "Password and secret", "password,secret", 2},
		{"case insensitive", "PASSWORD and Secret", "password,secret", 2},
		{"no match", "This is safe text", "password,secret", 0},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			rule := &PIIRule{
				Name:     "Keywords",
				RuleType: PIIRuleTypeKeyword,
				Pattern:  tt.keywords,
				Severity: PIISeverityWarn,
			}
			detector := &PIIDetector{
				compiledRules: make(map[string]*regexp.Regexp),
			}
			matches := detector.findKeywordMatches(tt.text, rule)
			if len(matches) != tt.expected {
				t.Errorf("Expected %d matches, got %d for text: %s", tt.expected, len(matches), tt.text)
			}
		})
	}
}
// TestPIIDetector_Redaction verifies RedactText masks block-severity spans
// with asterisks and leaves text without matches untouched.
func TestPIIDetector_Redaction(t *testing.T) {
	detector := &PIIDetector{
		compiledRules: make(map[string]*regexp.Regexp),
	}
	tests := []struct {
		name     string
		text     string
		matches  []PIIMatch
		expected string
	}{
		{
			"single redaction",
			"Email: test@example.com",
			[]PIIMatch{{StartIndex: 7, EndIndex: 23, Severity: PIISeverityBlock}},
			"Email: ****************",
		},
		{
			"no matches",
			"Plain text",
			[]PIIMatch{},
			"Plain text",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := detector.RedactText(tt.text, tt.matches)
			if result != tt.expected {
				t.Errorf("Expected '%s', got '%s'", tt.expected, result)
			}
		})
	}
}

// TestCompareSeverity pins the severity ordering warn < redact < block.
func TestCompareSeverity(t *testing.T) {
	tests := []struct {
		a, b     PIISeverity
		expected int
	}{
		{PIISeverityBlock, PIISeverityWarn, 1},
		{PIISeverityWarn, PIISeverityBlock, -1},
		{PIISeverityBlock, PIISeverityBlock, 0},
		{PIISeverityRedact, PIISeverityWarn, 1},
		{PIISeverityRedact, PIISeverityBlock, -1},
	}
	for _, tt := range tests {
		t.Run(string(tt.a)+"_vs_"+string(tt.b), func(t *testing.T) {
			result := compareSeverity(tt.a, tt.b)
			if result != tt.expected {
				t.Errorf("Expected %d, got %d for %s vs %s", tt.expected, result, tt.a, tt.b)
			}
		})
	}
}
// =============================================================================
// ENFORCER TESTS
// =============================================================================

// TestExtractDomain checks domain extraction: scheme/port/path stripping and
// removal of a leading "www." prefix (deeper subdomains are preserved).
func TestExtractDomain(t *testing.T) {
	tests := []struct {
		name     string
		url      string
		expected string
		hasError bool
	}{
		{"full URL", "https://www.example.com/path", "example.com", false},
		{"with port", "http://example.com:8080/path", "example.com", false},
		{"subdomain", "https://sub.domain.example.com", "sub.domain.example.com", false},
		{"no scheme", "example.com/path", "example.com", false},
		{"www prefix", "https://www.test.de", "test.de", false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := extractDomain(tt.url)
			if tt.hasError && err == nil {
				t.Error("Expected error, got nil")
			}
			if !tt.hasError && err != nil {
				t.Errorf("Expected no error, got %v", err)
			}
			if result != tt.expected {
				t.Errorf("Expected '%s', got '%s'", tt.expected, result)
			}
		})
	}
}
// =============================================================================
// YAML LOADER TESTS
// =============================================================================

// TestParseYAML round-trips a representative config document through
// ParseYAML and asserts federal, per-Bundesland (inline map), operations,
// and PII-rule sections are all populated.
func TestParseYAML(t *testing.T) {
	yamlData := `
federal:
  name: "Test Federal"
  sources:
    - domain: "test.gov"
      name: "Test Source"
      license: "§5 UrhG"
      trust_boost: 0.9
NI:
  name: "Niedersachsen"
  sources:
    - domain: "ni.gov"
      name: "NI Source"
      license: "DL-DE-BY-2.0"
default_operations:
  lookup:
    allowed: true
    requires_citation: true
  training:
    allowed: false
    requires_citation: false
pii_rules:
  - name: "Test Rule"
    type: "regex"
    pattern: "test.*pattern"
    severity: "block"
`
	config, err := ParseYAML([]byte(yamlData))
	if err != nil {
		t.Fatalf("Failed to parse YAML: %v", err)
	}
	// Test federal
	if config.Federal.Name != "Test Federal" {
		t.Errorf("Expected federal name 'Test Federal', got '%s'", config.Federal.Name)
	}
	if len(config.Federal.Sources) != 1 {
		t.Errorf("Expected 1 federal source, got %d", len(config.Federal.Sources))
	}
	if config.Federal.Sources[0].Domain != "test.gov" {
		t.Errorf("Expected domain 'test.gov', got '%s'", config.Federal.Sources[0].Domain)
	}
	if config.Federal.Sources[0].TrustBoost != 0.9 {
		t.Errorf("Expected trust_boost 0.9, got %f", config.Federal.Sources[0].TrustBoost)
	}
	// Test Bundesland (captured by the inline map)
	if len(config.Bundeslaender) != 1 {
		t.Errorf("Expected 1 Bundesland, got %d", len(config.Bundeslaender))
	}
	ni, ok := config.Bundeslaender["NI"]
	if !ok {
		t.Error("Expected NI in Bundeslaender")
	}
	if ni.Name != "Niedersachsen" {
		t.Errorf("Expected name 'Niedersachsen', got '%s'", ni.Name)
	}
	// Test operations
	if !config.DefaultOperations.Lookup.Allowed {
		t.Error("Expected lookup to be allowed")
	}
	if config.DefaultOperations.Training.Allowed {
		t.Error("Expected training to be NOT allowed")
	}
	// Test PII rules
	if len(config.PIIRules) != 1 {
		t.Errorf("Expected 1 PII rule, got %d", len(config.PIIRules))
	}
	if config.PIIRules[0].Name != "Test Rule" {
		t.Errorf("Expected rule name 'Test Rule', got '%s'", config.PIIRules[0].Name)
	}
}
// =============================================================================
// AUDIT TESTS
// =============================================================================

// TestMaskPII pins maskPII's masking behavior: very short values become
// "****"; longer values keep two leading and two trailing characters.
func TestMaskPII(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		expected string
	}{
		{"short", "ab", "****"},
		{"medium", "test@email.com", "te****om"},
		{"long", "very-long-email@example.com", "ve****om"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := maskPII(tt.input)
			if result != tt.expected {
				t.Errorf("Expected '%s', got '%s'", tt.expected, result)
			}
		})
	}
}
// =============================================================================
// DEFAULT PII RULES TEST
// =============================================================================
func TestDefaultPIIRules(t *testing.T) {
rules := DefaultPIIRules()
if len(rules) == 0 {
t.Error("Expected default PII rules, got none")
}
// Check that each rule has required fields
for _, rule := range rules {
if rule.Name == "" {
t.Error("Rule name should not be empty")
}
if rule.Type == "" {
t.Error("Rule type should not be empty")
}
if rule.Pattern == "" {
t.Error("Rule pattern should not be empty")
}
}
// Check for email rule
hasEmailRule := false
for _, rule := range rules {
if rule.Name == "Email Addresses" {
hasEmailRule = true
break
}
}
if !hasEmailRule {
t.Error("Expected email addresses rule in defaults")
}
}
// =============================================================================
// INTEGRATION TEST HELPERS
// =============================================================================
// TestFilteredURL exercises construction of the FilteredURL struct and
// verifies that all three fields round-trip exactly as assigned.
func TestFilteredURL(t *testing.T) {
	filtered := FilteredURL{
		URL:              "https://example.com",
		IsAllowed:        true,
		RequiresCitation: true,
	}
	if filtered.URL != "https://example.com" {
		t.Error("URL not set correctly")
	}
	if !filtered.IsAllowed {
		t.Error("IsAllowed should be true")
	}
	if !filtered.RequiresCitation {
		t.Error("RequiresCitation should be true")
	}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,369 @@
package publications
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/breakpilot/edu-search-service/internal/database"
"github.com/google/uuid"
)
// CrossRefClient is a client for the CrossRef REST API (api.crossref.org).
// Supplying a contact email in the User-Agent grants access to CrossRef's
// "polite pool", which offers more reliable service.
type CrossRefClient struct {
	client    *http.Client
	baseURL   string
	userAgent string
	email     string // For polite pool access
}

// CrossRefResponse represents the top-level API response envelope returned
// by list endpoints such as /works?query=...
type CrossRefResponse struct {
	Status         string         `json:"status"`
	MessageType    string         `json:"message-type"`
	MessageVersion string         `json:"message-version"`
	Message        CrossRefResult `json:"message"`
}

// CrossRefResult contains the actual results of a list query.
type CrossRefResult struct {
	TotalResults int            `json:"total-results"`
	Items        []CrossRefWork `json:"items"`
	Query        *CrossRefQuery `json:"query,omitempty"`
}

// CrossRefQuery contains query info echoed back by the API.
type CrossRefQuery struct {
	StartIndex  int    `json:"start-index"`
	SearchTerms string `json:"search-terms"`
}

// CrossRefWork represents a single work/publication as returned by CrossRef.
// Note that Title and ContainerTitle are arrays in the API; only the first
// element is used downstream (see convertToPub).
type CrossRefWork struct {
	DOI            string           `json:"DOI"`
	Title          []string         `json:"title"`
	ContainerTitle []string         `json:"container-title"`
	Publisher      string           `json:"publisher"`
	Type           string           `json:"type"`
	Author         []CrossRefAuthor `json:"author"`
	Issued         CrossRefDate     `json:"issued"`
	PublishedPrint CrossRefDate     `json:"published-print"`
	Abstract       string           `json:"abstract"` // may contain JATS markup; see cleanHTML
	URL            string           `json:"URL"`
	Link           []CrossRefLink   `json:"link"`
	Subject        []string         `json:"subject"`
	ISSN           []string         `json:"ISSN"`
	ISBN           []string         `json:"ISBN"`
	IsCitedByCount int              `json:"is-referenced-by-count"`
}

// CrossRefAuthor represents an author of a work.
type CrossRefAuthor struct {
	Given       string `json:"given"`
	Family      string `json:"family"`
	ORCID       string `json:"ORCID"`
	Affiliation []struct {
		Name string `json:"name"`
	} `json:"affiliation"`
	Sequence string `json:"sequence"` // "first" or "additional"
}

// CrossRefDate represents a date as nested parts, e.g. [[2024, 1, 15]];
// month and day may be absent (convertToPub reads [0][0] and [0][1]).
type CrossRefDate struct {
	DateParts [][]int `json:"date-parts"`
}

// CrossRefLink represents a full-text link attached to the work.
type CrossRefLink struct {
	URL         string `json:"URL"`
	ContentType string `json:"content-type"` // e.g. "application/pdf"
}
// NewCrossRefClient creates a new CrossRef API client. The given contact
// email is embedded in the User-Agent header so that requests are served
// from CrossRef's polite pool.
func NewCrossRefClient(email string) *CrossRefClient {
	agent := fmt.Sprintf("BreakPilot-EduBot/1.0 (https://breakpilot.de; mailto:%s)", email)
	return &CrossRefClient{
		client:    &http.Client{Timeout: 30 * time.Second},
		baseURL:   "https://api.crossref.org",
		userAgent: agent,
		email:     email,
	}
}
// GetWorkByDOI retrieves a single work from CrossRef by its DOI.
//
// The DOI may be given bare ("10.1000/xyz") or with a common prefix such as
// "doi:", "https://doi.org/" or "http://doi.org/"; prefixes are stripped
// before the lookup. Returns an error when the DOI is unknown (HTTP 404) or
// the API responds with any other non-200 status.
func (c *CrossRefClient) GetWorkByDOI(ctx context.Context, doi string) (*database.Publication, error) {
	// Normalize: accept bare DOIs as well as URL- or "doi:"-prefixed forms.
	doi = strings.TrimSpace(doi)
	doi = strings.TrimPrefix(doi, "https://doi.org/")
	doi = strings.TrimPrefix(doi, "http://doi.org/")
	doi = strings.TrimPrefix(doi, "doi:")
	// PathEscape keeps the slash inside the DOI safe within the URL path.
	endpoint := fmt.Sprintf("%s/works/%s", c.baseURL, url.PathEscape(doi))
	req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("User-Agent", c.userAgent)
	resp, err := c.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode == http.StatusNotFound {
		return nil, fmt.Errorf("DOI not found: %s", doi)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("CrossRef API error: %d", resp.StatusCode)
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	// The single-work endpoint wraps the work in {"status":..., "message":{...}}
	// (unlike list endpoints, whose message contains items).
	var result struct {
		Status  string       `json:"status"`
		Message CrossRefWork `json:"message"`
	}
	if err := json.Unmarshal(body, &result); err != nil {
		return nil, err
	}
	return c.convertToPub(&result.Message), nil
}
// SearchByAuthor searches CrossRef for publications matching an author name,
// newest first. A non-positive limit defaults to 20 results.
func (c *CrossRefClient) SearchByAuthor(ctx context.Context, authorName string, limit int) ([]*database.Publication, error) {
	if limit <= 0 {
		limit = 20
	}
	endpoint := c.baseURL + "/works?query.author=" + url.QueryEscape(authorName) +
		fmt.Sprintf("&rows=%d&sort=published&order=desc", limit)
	return c.searchWorks(ctx, endpoint)
}
// SearchByAffiliation searches CrossRef for publications by author
// affiliation (e.g. a university name), newest first. A non-positive limit
// defaults to 20 results.
func (c *CrossRefClient) SearchByAffiliation(ctx context.Context, affiliation string, limit int) ([]*database.Publication, error) {
	if limit <= 0 {
		limit = 20
	}
	endpoint := c.baseURL + "/works?query.affiliation=" + url.QueryEscape(affiliation) +
		fmt.Sprintf("&rows=%d&sort=published&order=desc", limit)
	return c.searchWorks(ctx, endpoint)
}
// SearchByORCID searches CrossRef for publications filtered by a
// researcher's ORCID iD, newest first. A non-positive limit defaults to 100
// results. The ORCID may be passed bare (0000-0000-0000-0000) or as an
// orcid.org URL in either the canonical https or the legacy http form.
func (c *CrossRefClient) SearchByORCID(ctx context.Context, orcid string, limit int) ([]*database.Publication, error) {
	if limit <= 0 {
		limit = 100
	}
	// ORCID format: 0000-0000-0000-0000; strip both URL prefixes.
	orcid = strings.TrimPrefix(orcid, "https://orcid.org/")
	orcid = strings.TrimPrefix(orcid, "http://orcid.org/")
	endpoint := fmt.Sprintf("%s/works?filter=orcid:%s&rows=%d&sort=published&order=desc",
		c.baseURL, url.QueryEscape(orcid), limit)
	return c.searchWorks(ctx, endpoint)
}
// SearchByTitle searches CrossRef for publications whose title matches the
// given query. A non-positive limit defaults to 10 results.
func (c *CrossRefClient) SearchByTitle(ctx context.Context, title string, limit int) ([]*database.Publication, error) {
	if limit <= 0 {
		limit = 10
	}
	endpoint := c.baseURL + "/works?query.title=" + url.QueryEscape(title) +
		fmt.Sprintf("&rows=%d", limit)
	return c.searchWorks(ctx, endpoint)
}
// searchWorks issues a GET against the given CrossRef list endpoint and
// converts every returned work into a Publication. Returns an error for any
// non-200 response.
func (c *CrossRefClient) searchWorks(ctx context.Context, endpoint string) ([]*database.Publication, error) {
	req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("User-Agent", c.userAgent)
	resp, err := c.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("CrossRef API error: %d", resp.StatusCode)
	}
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	var parsed CrossRefResponse
	if err := json.Unmarshal(raw, &parsed); err != nil {
		return nil, err
	}
	// Convert by index to avoid copying each work before taking its address.
	var pubs []*database.Publication
	for i := range parsed.Message.Items {
		pubs = append(pubs, c.convertToPub(&parsed.Message.Items[i]))
	}
	return pubs, nil
}
// convertToPub converts a CrossRef work to our Publication model.
//
// The mapping is best-effort: optional CrossRef fields only populate the
// corresponding (mostly pointer-typed) Publication fields when present.
// Every converted publication gets a fresh UUID and the current crawl time.
func (c *CrossRefClient) convertToPub(work *CrossRefWork) *database.Publication {
	pub := &database.Publication{
		ID:            uuid.New(),
		CitationCount: work.IsCitedByCount,
		CrawledAt:     time.Now(),
	}
	// Title: CrossRef returns an array; only the first entry is used.
	if len(work.Title) > 0 {
		pub.Title = work.Title[0]
	}
	// DOI
	if work.DOI != "" {
		pub.DOI = &work.DOI
	}
	// URL
	if work.URL != "" {
		pub.URL = &work.URL
	}
	// Abstract (clean HTML/JATS markup first)
	if work.Abstract != "" {
		abstract := cleanHTML(work.Abstract)
		pub.Abstract = &abstract
	}
	// Year/month from the "issued" date parts ([[year, month, day]]).
	if len(work.Issued.DateParts) > 0 && len(work.Issued.DateParts[0]) > 0 {
		year := work.Issued.DateParts[0][0]
		pub.Year = &year
		if len(work.Issued.DateParts[0]) > 1 {
			month := work.Issued.DateParts[0][1]
			pub.Month = &month
		}
	}
	// Type: normalized to the internal vocabulary ("journal", "book", ...).
	pubType := mapCrossRefType(work.Type)
	pub.PubType = &pubType
	// Venue: the journal/proceedings title, first entry only.
	if len(work.ContainerTitle) > 0 {
		venue := work.ContainerTitle[0]
		pub.Venue = &venue
	}
	// Publisher
	if work.Publisher != "" {
		pub.Publisher = &work.Publisher
	}
	// ISBN: only the first of possibly several identifiers is kept.
	if len(work.ISBN) > 0 {
		pub.ISBN = &work.ISBN[0]
	}
	// ISSN: likewise only the first entry.
	if len(work.ISSN) > 0 {
		pub.ISSN = &work.ISSN[0]
	}
	// Keywords/Subjects
	if len(work.Subject) > 0 {
		pub.Keywords = work.Subject
	}
	// PDF URL: first link whose content type mentions "pdf".
	// Taking &link.URL is safe here because the loop breaks right after the
	// match, so the loop variable still holds the matching link.
	for _, link := range work.Link {
		if strings.Contains(link.ContentType, "pdf") {
			pub.PDFURL = &link.URL
			break
		}
	}
	// Authors: "Given Family", skipping entries that are entirely empty.
	var authors []string
	for _, author := range work.Author {
		name := strings.TrimSpace(author.Given + " " + author.Family)
		if name != "" {
			authors = append(authors, name)
		}
	}
	pub.Authors = authors
	// Source marker identifying where this record was crawled from.
	source := "crossref"
	pub.Source = &source
	// Store raw data for debugging/reprocessing. Marshalling a struct we just
	// decoded cannot realistically fail, so the error is deliberately ignored.
	rawData, _ := json.Marshal(work)
	pub.RawData = rawData
	return pub
}
// crossRefTypeNames maps CrossRef work types onto the internal publication
// type vocabulary; anything not listed is classified as "other".
var crossRefTypeNames = map[string]string{
	"journal-article":     "journal",
	"proceedings-article": "conference",
	"conference-paper":    "conference",
	"book":                "book",
	"book-chapter":        "book_chapter",
	"dissertation":        "thesis",
	"posted-content":      "preprint",
}

// mapCrossRefType maps CrossRef types to our types.
func mapCrossRefType(crType string) string {
	if mapped, ok := crossRefTypeNames[crType]; ok {
		return mapped
	}
	return "other"
}
// jatsTagReplacer strips the small, fixed set of JATS/HTML tags that
// CrossRef abstracts commonly contain. Closing block tags are replaced with
// a space so adjacent paragraphs stay separated after whitespace collapsing.
var jatsTagReplacer = strings.NewReplacer(
	"<jats:p>", "",
	"</jats:p>", " ",
	"<jats:italic>", "",
	"</jats:italic>", "",
	"<jats:bold>", "",
	"</jats:bold>", "",
	"<p>", "",
	"</p>", " ",
)

// cleanHTML removes HTML tags from text.
func cleanHTML(html string) string {
	stripped := jatsTagReplacer.Replace(html)
	// Collapse any whitespace runs into single spaces; Fields+Join also
	// trims leading/trailing whitespace.
	return strings.Join(strings.Fields(stripped), " ")
}

View File

@@ -0,0 +1,268 @@
package publications
import (
	"context"
	"fmt"
	"log"
	"strings"
	"sync"
	"time"

	"github.com/breakpilot/edu-search-service/internal/database"
	"github.com/google/uuid"
)
// PublicationCrawler crawls publications for university staff by querying
// the CrossRef API and persisting the results via the repository.
type PublicationCrawler struct {
	repo        *database.Repository
	crossref    *CrossRefClient
	rateLimit   time.Duration // minimum delay enforced between CrossRef requests
	mu          sync.Mutex    // guards lastRequest (see waitForRateLimit)
	lastRequest time.Time     // time of the most recent CrossRef request
}

// CrawlResult contains the result of a publication crawl for one staff member.
type CrawlResult struct {
	StaffID     uuid.UUID
	PubsFound   int // publications successfully saved
	PubsNew     int // NOTE(review): never written in this file — confirm intended use
	PubsUpdated int // NOTE(review): never written in this file — confirm intended use
	Errors      []string // non-fatal errors collected during the crawl
	Duration    time.Duration
}
// NewPublicationCrawler creates a new publication crawler that talks to
// CrossRef using the given contact email and a conservative
// one-request-per-second rate limit.
func NewPublicationCrawler(repo *database.Repository, email string) *PublicationCrawler {
	crawler := &PublicationCrawler{
		repo:      repo,
		crossref:  NewCrossRefClient(email),
		rateLimit: time.Second, // deliberately well below CrossRef's polite-pool allowance
	}
	return crawler
}
// CrawlForStaff crawls publications for a single staff member.
//
// Lookup strategy: first by ORCID (most reliable), then by full name, with
// name-search results deduplicated against the ORCID hits. Every found
// publication is persisted and linked to the staff member. Per-publication
// failures are collected in the result's Errors slice instead of aborting
// the crawl; the returned error is currently always nil.
func (c *PublicationCrawler) CrawlForStaff(ctx context.Context, staff *database.UniversityStaff) (*CrawlResult, error) {
	start := time.Now()
	result := &CrawlResult{
		StaffID: staff.ID,
	}
	// FullName is a *string and may be nil (the nil check below only guards
	// the name search); derive a safe display name once so the log statements
	// cannot panic on a nil dereference.
	displayName := staff.LastName
	if staff.FullName != nil {
		displayName = *staff.FullName
	}
	log.Printf("Starting publication crawl for %s", displayName)
	var pubs []*database.Publication
	// Strategy 1: Search by ORCID (most reliable)
	if staff.ORCID != nil && *staff.ORCID != "" {
		c.waitForRateLimit()
		orcidPubs, err := c.crossref.SearchByORCID(ctx, *staff.ORCID, 100)
		if err != nil {
			result.Errors = append(result.Errors, fmt.Sprintf("ORCID search error: %v", err))
		} else {
			pubs = append(pubs, orcidPubs...)
			log.Printf("Found %d publications via ORCID for %s", len(orcidPubs), displayName)
		}
	}
	// Strategy 2: Search by author name
	if staff.FullName != nil && *staff.FullName != "" {
		c.waitForRateLimit()
		namePubs, err := c.crossref.SearchByAuthor(ctx, *staff.FullName, 50)
		if err != nil {
			result.Errors = append(result.Errors, fmt.Sprintf("Name search error: %v", err))
		} else {
			// Deduplicate against publications already found via ORCID.
			for _, pub := range namePubs {
				if !containsPub(pubs, pub) {
					pubs = append(pubs, pub)
				}
			}
			log.Printf("Found %d additional publications via name search for %s", len(namePubs), displayName)
		}
	}
	// Save publications and create staff links.
	for _, pub := range pubs {
		if err := c.repo.CreatePublication(ctx, pub); err != nil {
			result.Errors = append(result.Errors, fmt.Sprintf("Save error for %s: %v", pub.Title, err))
			continue
		}
		result.PubsFound++
		link := &database.StaffPublication{
			StaffID:       staff.ID,
			PublicationID: pub.ID,
		}
		// Record the author position when the staff member can be matched in
		// the author list (1-based; 0 means "not found" and is omitted).
		pos := findAuthorPosition(pub, staff)
		if pos > 0 {
			link.AuthorPosition = &pos
		}
		if err := c.repo.LinkStaffPublication(ctx, link); err != nil {
			result.Errors = append(result.Errors, fmt.Sprintf("Link error: %v", err))
		}
	}
	result.Duration = time.Since(start)
	log.Printf("Completed publication crawl for %s: found=%d, duration=%v",
		displayName, result.PubsFound, result.Duration)
	return result, nil
}
// CrawlForUniversity crawls publications for all staff at a university.
//
// It fetches up to limit staff members, crawls each one sequentially
// (rate-limited via CrawlForStaff), aggregates the publication counts and
// errors into a UniversityCrawlStatus, persists that status (best-effort),
// and returns it. Cancellation of ctx stops the loop between staff members
// and returns a "cancelled" status together with ctx.Err().
func (c *PublicationCrawler) CrawlForUniversity(ctx context.Context, uniID uuid.UUID, limit int) (*database.UniversityCrawlStatus, error) {
	log.Printf("Starting publication crawl for university %s", uniID)
	// NOTE(review): an earlier comment claimed ORCID-holding staff are
	// fetched first, but these params apply no ORCID filter or ordering —
	// confirm against SearchStaff's implementation.
	params := database.StaffSearchParams{
		UniversityID: &uniID,
		Limit:        limit,
	}
	result, err := c.repo.SearchStaff(ctx, params)
	if err != nil {
		return nil, err
	}
	status := &database.UniversityCrawlStatus{
		UniversityID:   uniID,
		PubCrawlStatus: "running",
	}
	var totalPubs int
	var errors []string
	for _, staff := range result.Staff {
		// Check for cancellation between staff members without blocking.
		select {
		case <-ctx.Done():
			status.PubCrawlStatus = "cancelled"
			status.PubErrors = append(errors, "Crawl cancelled")
			return status, ctx.Err()
		default:
		}
		// &staff aliases the loop variable; safe here because the call is
		// synchronous and the pointer is not retained across iterations.
		crawlResult, err := c.CrawlForStaff(ctx, &staff)
		if err != nil {
			errors = append(errors, fmt.Sprintf("%s: %v", staff.LastName, err))
			continue
		}
		totalPubs += crawlResult.PubsFound
		errors = append(errors, crawlResult.Errors...)
	}
	now := time.Now()
	status.LastPubCrawl = &now
	status.PubCrawlStatus = "completed"
	status.PubCount = totalPubs
	status.PubErrors = errors
	// Persist the status; a failure here is logged but does not fail the crawl.
	if err := c.repo.UpdateCrawlStatus(ctx, status); err != nil {
		log.Printf("Warning: Failed to update crawl status: %v", err)
	}
	log.Printf("Completed publication crawl for university %s: %d publications found", uniID, totalPubs)
	return status, nil
}
// ResolveDOI looks up a DOI via CrossRef (respecting the crawler's rate
// limit) and persists the resulting publication before returning it.
func (c *PublicationCrawler) ResolveDOI(ctx context.Context, doi string) (*database.Publication, error) {
	c.waitForRateLimit()
	pub, err := c.crossref.GetWorkByDOI(ctx, doi)
	if err != nil {
		return nil, err
	}
	if saveErr := c.repo.CreatePublication(ctx, pub); saveErr != nil {
		return nil, saveErr
	}
	return pub, nil
}
// waitForRateLimit blocks until at least rateLimit has elapsed since the
// previous CrossRef request, then records the current time. The mutex
// serializes concurrent callers.
func (c *PublicationCrawler) waitForRateLimit() {
	c.mu.Lock()
	defer c.mu.Unlock()
	if wait := c.rateLimit - time.Since(c.lastRequest); wait > 0 {
		time.Sleep(wait)
	}
	c.lastRequest = time.Now()
}
// containsPub reports whether pub is already present in pubs, matching
// either by identical DOI or by exact title equality.
func containsPub(pubs []*database.Publication, pub *database.Publication) bool {
	for _, candidate := range pubs {
		sameDOI := pub.DOI != nil && candidate.DOI != nil && *pub.DOI == *candidate.DOI
		if sameDOI || candidate.Title == pub.Title {
			return true
		}
	}
	return false
}
// findAuthorPosition returns the 1-based position of the staff member in
// the publication's author list, matched case-insensitively on last name,
// or 0 when no author matches (or the staff member has no last name).
func findAuthorPosition(pub *database.Publication, staff *database.UniversityStaff) int {
	if staff.LastName == "" {
		return 0
	}
	for idx, author := range pub.Authors {
		if containsIgnoreCase(author, staff.LastName) {
			return idx + 1
		}
	}
	return 0
}
// containsIgnoreCase checks if s contains substr (case insensitive).
//
// Uses Unicode-aware lowercasing, so German names compare correctly
// ("MÜLLER" contains "Müller"); the previous hand-rolled fold handled only
// ASCII letters and missed such matches.
func containsIgnoreCase(s, substr string) bool {
	return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
}
// containsIgnoreCaseHelper scans every window of len(substr) bytes in s and
// reports whether any window equals substr under ASCII case folding.
func containsIgnoreCaseHelper(s, substr string) bool {
	last := len(s) - len(substr)
	for start := 0; start <= last; start++ {
		if equalFold(s[start:start+len(substr)], substr) {
			return true
		}
	}
	return false
}
// equalFold reports whether s1 and s2 are equal under ASCII-only case
// folding; non-ASCII bytes must match exactly.
func equalFold(s1, s2 string) bool {
	if len(s1) != len(s2) {
		return false
	}
	for i := 0; i < len(s1); i++ {
		if lowerASCIIByte(s1[i]) != lowerASCIIByte(s2[i]) {
			return false
		}
	}
	return true
}

// lowerASCIIByte folds a single ASCII upper-case letter to lower case and
// returns every other byte unchanged.
func lowerASCIIByte(b byte) byte {
	if b >= 'A' && b <= 'Z' {
		return b + ('a' - 'A')
	}
	return b
}

View File

@@ -0,0 +1,188 @@
package publications
import (
"testing"
"github.com/breakpilot/edu-search-service/internal/database"
)
// TestContainsPub_ByDOI verifies that containsPub treats two publications
// with the same DOI as duplicates even when their titles differ.
func TestContainsPub_ByDOI(t *testing.T) {
	doi1 := "10.1000/test1"
	doi2 := "10.1000/test2"
	doi3 := "10.1000/test3"
	pubs := []*database.Publication{
		{Title: "Paper 1", DOI: &doi1},
		{Title: "Paper 2", DOI: &doi2},
	}
	tests := []struct {
		name     string
		pub      *database.Publication
		expected bool
	}{
		{
			// Same DOI as an existing entry: the differing title must not matter.
			name:     "DOI exists in list",
			pub:      &database.Publication{Title: "Different Title", DOI: &doi1},
			expected: true,
		},
		{
			name:     "DOI does not exist",
			pub:      &database.Publication{Title: "New Paper", DOI: &doi3},
			expected: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := containsPub(pubs, tt.pub)
			if result != tt.expected {
				t.Errorf("Expected %v, got %v", tt.expected, result)
			}
		})
	}
}
// TestContainsPub_ByTitle verifies the DOI-less fallback of containsPub:
// duplicates are detected by exact title equality.
func TestContainsPub_ByTitle(t *testing.T) {
	pubs := []*database.Publication{
		{Title: "Machine Learning Applications"},
		{Title: "Deep Neural Networks"},
	}
	tests := []struct {
		name     string
		pub      *database.Publication
		expected bool
	}{
		{
			name:     "Title exists in list",
			pub:      &database.Publication{Title: "Machine Learning Applications"},
			expected: true,
		},
		{
			name:     "Title does not exist",
			pub:      &database.Publication{Title: "New Research Paper"},
			expected: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := containsPub(pubs, tt.pub)
			if result != tt.expected {
				t.Errorf("Expected %v, got %v", tt.expected, result)
			}
		})
	}
}
// TestContainsIgnoreCase covers case-insensitive substring matching,
// including the empty-string edge cases (an empty substring is always
// contained; a non-empty substring is never contained in an empty string).
func TestContainsIgnoreCase(t *testing.T) {
	tests := []struct {
		name     string
		s        string
		substr   string
		expected bool
	}{
		{"Exact match", "Hello World", "Hello", true},
		{"Case insensitive", "Hello World", "hello", true},
		{"Case insensitive uppercase", "HELLO WORLD", "world", true},
		{"Substring in middle", "The quick brown fox", "brown", true},
		{"No match", "Hello World", "xyz", false},
		{"Empty substring", "Hello", "", true},
		{"Empty string", "", "test", false},
		{"Both empty", "", "", true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := containsIgnoreCase(tt.s, tt.substr)
			if result != tt.expected {
				t.Errorf("containsIgnoreCase(%q, %q) = %v, expected %v",
					tt.s, tt.substr, result, tt.expected)
			}
		})
	}
}
// TestEqualFold covers the ASCII case-folding equality helper: equal-length
// strings compare equal regardless of letter case; differing content or
// length does not. (Note: all vectors are ASCII; the helper does not fold
// non-ASCII characters.)
func TestEqualFold(t *testing.T) {
	tests := []struct {
		name     string
		s1       string
		s2       string
		expected bool
	}{
		{"Same string", "hello", "hello", true},
		{"Different case", "Hello", "hello", true},
		{"All uppercase", "HELLO", "hello", true},
		{"Mixed case", "HeLLo", "hEllO", true},
		{"Different strings", "hello", "world", false},
		{"Different length", "hello", "hi", false},
		{"Empty strings", "", "", true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := equalFold(tt.s1, tt.s2)
			if result != tt.expected {
				t.Errorf("equalFold(%q, %q) = %v, expected %v",
					tt.s1, tt.s2, result, tt.expected)
			}
		})
	}
}
// TestFindAuthorPosition verifies that a staff member's last name is located
// in a publication's author list at the correct 1-based position, and that
// an unmatched name yields position 0.
func TestFindAuthorPosition(t *testing.T) {
	pub := &database.Publication{
		Title: "Test Paper",
		Authors: []string{
			"John Smith",
			"Maria Müller",
			"Hans Weber",
		},
	}
	tests := []struct {
		name     string
		staff    *database.UniversityStaff
		expected int
	}{
		{
			name: "First author",
			staff: &database.UniversityStaff{
				LastName: "Smith",
			},
			expected: 1,
		},
		{
			name: "Second author",
			staff: &database.UniversityStaff{
				LastName: "Müller",
			},
			expected: 2,
		},
		{
			name: "Third author",
			staff: &database.UniversityStaff{
				LastName: "Weber",
			},
			expected: 3,
		},
		{
			// No author contains this last name: 0 signals "not found".
			name: "Author not found",
			staff: &database.UniversityStaff{
				LastName: "Unknown",
			},
			expected: 0,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := findAuthorPosition(pub, tt.staff)
			if result != tt.expected {
				t.Errorf("Expected position %d, got %d for author %s",
					tt.expected, result, tt.staff.LastName)
			}
		})
	}
}

View File

@@ -0,0 +1,326 @@
package quality
import (
"regexp"
"strings"
)
// Scorer calculates quality scores for documents from a weighted set of
// content heuristics.
type Scorer struct {
	weights Weights
}

// Weights defines the contribution of each factor to the quality score.
// The defaults (see DefaultWeights) sum to 1.0 so the weighted total stays
// within [0, 1].
type Weights struct {
	ContentLength    float64 // 0.20 - longer content often more valuable
	HeadingStructure float64 // 0.15 - well-structured documents
	LinkQuality      float64 // 0.15 - low ad/external link density
	TextToHTMLRatio  float64 // 0.15 - content-rich pages
	MetadataPresence float64 // 0.10 - proper title, description
	LanguageClarity  float64 // 0.10 - German content, no mixed languages
	ContentFreshness float64 // 0.10 - indication of update/recency
	PDFSpecific      float64 // 0.05 - PDF-specific quality signals
}
// DefaultWeights returns the default quality score weights. The individual
// weights sum to 1.0 so that the weighted total stays within [0, 1].
func DefaultWeights() Weights {
	var w Weights
	w.ContentLength = 0.20
	w.HeadingStructure = 0.15
	w.LinkQuality = 0.15
	w.TextToHTMLRatio = 0.15
	w.MetadataPresence = 0.10
	w.LanguageClarity = 0.10
	w.ContentFreshness = 0.10
	w.PDFSpecific = 0.05
	return w
}
// ContentFeatures holds extracted features for quality scoring. Callers are
// expected to populate these from the crawled document before calling
// Scorer.Calculate.
type ContentFeatures struct {
	ContentLength   int     // extracted text length in characters
	HeadingCount    int     // number of headings found
	HeadingDepth    int     // max heading level depth (h1-h6)
	LinkDensity     float64 // fraction of content that is links
	AdDensity       float64 // fraction of content that is advertising
	TextToHTMLRatio float64 // visible text relative to total markup
	HasTitle        bool
	HasDescription  bool
	HasCanonical    bool
	Language        string // detected language, e.g. "de" or "en"
	IsPDF           bool
	PageCount       int      // for PDFs
	HasTOC          bool     // table of contents
	DateIndicators  []string // found date patterns (see ExtractDateIndicators)
}

// Score represents the quality score breakdown. Each component lies in
// [0, 1]; Total is the weighted combination, clamped to [0, 1].
type Score struct {
	Total            float64 `json:"total"`
	ContentLength    float64 `json:"content_length"`
	HeadingStructure float64 `json:"heading_structure"`
	LinkQuality      float64 `json:"link_quality"`
	TextToHTMLRatio  float64 `json:"text_html_ratio"`
	MetadataPresence float64 `json:"metadata_presence"`
	LanguageClarity  float64 `json:"language_clarity"`
	ContentFreshness float64 `json:"content_freshness"`
	PDFSpecific      float64 `json:"pdf_specific"`
}
// NewScorer creates a quality scorer using DefaultWeights.
func NewScorer() *Scorer {
	return NewScorerWithWeights(DefaultWeights())
}
// NewScorerWithWeights creates a scorer that combines the individual factor
// scores using the supplied custom weights.
func NewScorerWithWeights(w Weights) *Scorer {
	s := &Scorer{}
	s.weights = w
	return s
}
// Calculate computes the quality score for the given features. Each factor
// is scored in [0, 1] individually, the factors are combined via the
// scorer's weights, and the final total is clamped to [0, 1].
func (s *Scorer) Calculate(features ContentFeatures) Score {
	result := Score{
		ContentLength:    s.calculateContentLengthScore(features.ContentLength),
		HeadingStructure: s.calculateHeadingScore(features.HeadingCount, features.HeadingDepth, features.HasTOC),
		LinkQuality:      s.calculateLinkQualityScore(features.LinkDensity, features.AdDensity),
		TextToHTMLRatio:  s.calculateTextRatioScore(features.TextToHTMLRatio),
		MetadataPresence: s.calculateMetadataScore(features.HasTitle, features.HasDescription, features.HasCanonical),
		LanguageClarity:  s.calculateLanguageScore(features.Language),
		ContentFreshness: s.calculateFreshnessScore(features.DateIndicators),
		// Non-PDFs get full marks here so that HTML pages are not penalized.
		PDFSpecific: 1.0,
	}
	if features.IsPDF {
		result.PDFSpecific = s.calculatePDFScore(features.PageCount, features.ContentLength)
	}
	w := s.weights
	total := result.ContentLength*w.ContentLength +
		result.HeadingStructure*w.HeadingStructure +
		result.LinkQuality*w.LinkQuality +
		result.TextToHTMLRatio*w.TextToHTMLRatio +
		result.MetadataPresence*w.MetadataPresence +
		result.LanguageClarity*w.LanguageClarity +
		result.ContentFreshness*w.ContentFreshness +
		result.PDFSpecific*w.PDFSpecific
	// Clamp the weighted sum into [0, 1].
	if total > 1.0 {
		total = 1.0
	}
	if total < 0 {
		total = 0
	}
	result.Total = total
	return result
}
// calculateContentLengthScore rates document length in characters. Roughly
// 1000-10000 characters is ideal; shorter documents score progressively
// lower, and very long ones are slightly penalized as likely boilerplate.
func (s *Scorer) calculateContentLengthScore(length int) float64 {
	bands := []struct {
		below int
		score float64
	}{
		{200, 0.1},
		{500, 0.3},
		{1000, 0.6},
		{3000, 0.8},
		{10000, 1.0},
		{20000, 0.9},
	}
	for _, band := range bands {
		if length < band.below {
			return band.score
		}
	}
	return 0.7 // very long documents might have quality issues
}
// calculateHeadingScore rates heading structure. Credit accrues for having
// any headings, having at least three, spanning multiple heading levels,
// and exposing a table of contents; the sum is capped at 1.0.
func (s *Scorer) calculateHeadingScore(count, depth int, hasTOC bool) float64 {
	credits := []struct {
		earned bool
		bonus  float64
	}{
		{count > 0, 0.4},  // any headings at all
		{count >= 3, 0.2}, // enough headings for real structure
		{depth >= 2, 0.2}, // proper hierarchy across levels
		{hasTOC, 0.2},     // a TOC indicates a well-structured document
	}
	total := 0.0
	for _, c := range credits {
		if c.earned {
			total += c.bonus
		}
	}
	if total > 1.0 {
		total = 1.0
	}
	return total
}
// calculateLinkQualityScore starts from a perfect score and deducts for
// high link density and, more heavily, for any advertising density.
func (s *Scorer) calculateLinkQualityScore(linkDensity, adDensity float64) float64 {
	score := 1.0
	// Penalize link-heavy pages.
	switch {
	case linkDensity > 0.3:
		score -= 0.3
	case linkDensity > 0.2:
		score -= 0.1
	}
	// Any advertising at all costs points, scaled by severity.
	switch {
	case adDensity > 0.1:
		score -= 0.4
	case adDensity > 0.05:
		score -= 0.2
	case adDensity > 0:
		score -= 0.1
	}
	if score < 0 {
		score = 0
	}
	return score
}
// calculateTextRatioScore rates the text-to-HTML ratio. A ratio between
// 0.2 and 0.6 is ideal; lower suggests markup-heavy pages, higher suggests
// an unstructured plain-text dump.
func (s *Scorer) calculateTextRatioScore(ratio float64) float64 {
	if ratio < 0.1 {
		return 0.3
	}
	if ratio < 0.2 {
		return 0.6
	}
	if ratio < 0.6 {
		return 1.0
	}
	if ratio < 0.8 {
		return 0.8
	}
	return 0.6
}
// calculateMetadataScore rewards page metadata: the title carries half the
// score, the description 0.3 and a canonical URL the remaining 0.2.
func (s *Scorer) calculateMetadataScore(hasTitle, hasDescription, hasCanonical bool) float64 {
	total := 0.0
	parts := []struct {
		present bool
		value   float64
	}{
		{hasTitle, 0.5},
		{hasDescription, 0.3},
		{hasCanonical, 0.2},
	}
	for _, p := range parts {
		if p.present {
			total += p.value
		}
	}
	return total
}
// calculateLanguageScore prefers German content (1.0), accepts English
// (0.8), treats an unknown language as neutral (0.5) and penalizes all
// other languages (0.3).
func (s *Scorer) calculateLanguageScore(language string) float64 {
	normalized := strings.ToLower(language)
	if normalized == "de" || normalized == "german" || normalized == "deutsch" {
		return 1.0
	}
	if normalized == "en" || normalized == "english" || normalized == "englisch" {
		return 0.8 // English is acceptable
	}
	if normalized == "" {
		return 0.5 // unknown
	}
	return 0.3 // other languages
}
// Freshness patterns are compiled once at package level; compiling them on
// every call was needless work on the per-document scoring path.
var (
	// recentYearRE matches any year in the 2020s. The previous pattern
	// (202[0-5]) stopped at 2025 and would have mis-scored 2026+ content
	// as stale.
	recentYearRE = regexp.MustCompile(`202[0-9]`)
	// modernYearRE matches the years 2015-2019.
	modernYearRE = regexp.MustCompile(`201[5-9]`)
)

// calculateFreshnessScore scores content freshness from the date strings
// found in the document: any 2020s year is fresh (1.0), 2015-2019 is
// moderately fresh (0.7), anything older scores 0.4, and the absence of any
// date indicator is neutral (0.5).
func (s *Scorer) calculateFreshnessScore(dateIndicators []string) float64 {
	if len(dateIndicators) == 0 {
		return 0.5 // neutral
	}
	for _, indicator := range dateIndicators {
		if recentYearRE.MatchString(indicator) {
			return 1.0
		}
	}
	for _, indicator := range dateIndicators {
		if modernYearRE.MatchString(indicator) {
			return 0.7
		}
	}
	// Older content
	return 0.4
}
// calculatePDFScore rates PDF-specific quality signals: multiple pages and
// successful text extraction each raise the score above the 0.5 baseline,
// capped at 1.0.
func (s *Scorer) calculatePDFScore(pageCount, contentLength int) float64 {
	score := 0.5 // base
	// Multi-page documents earn a bonus, with extra credit past five pages.
	if pageCount > 1 {
		score += 0.2
	}
	if pageCount > 5 {
		score += 0.1
	}
	// More than ~100 characters of text means extraction actually worked.
	if contentLength > 100 {
		score += 0.2
	}
	if score > 1.0 {
		return 1.0
	}
	return score
}
// dateIndicatorPatterns match German (DD.MM.YYYY) and ISO (YYYY-MM-DD)
// dates plus bare years 2000-2029. Compiled once at package level instead
// of on every call, since extraction runs for each crawled document.
var dateIndicatorPatterns = []*regexp.Regexp{
	regexp.MustCompile(`\d{2}\.\d{2}\.\d{4}`),
	regexp.MustCompile(`\d{4}-\d{2}-\d{2}`),
	regexp.MustCompile(`\b20[012][0-9]\b`), // years 2000-2029
}

// ExtractDateIndicators finds date patterns in text. At most five matches
// per pattern are collected, in pattern order; the result is nil when the
// text contains no recognizable dates.
func ExtractDateIndicators(text string) []string {
	var indicators []string
	for _, pattern := range dateIndicatorPatterns {
		indicators = append(indicators, pattern.FindAllString(text, 5)...)
	}
	return indicators
}

View File

@@ -0,0 +1,333 @@
package quality
import (
"testing"
)
// TestNewScorer ensures the default constructor returns a usable scorer.
func TestNewScorer(t *testing.T) {
	s := NewScorer()
	if s == nil {
		t.Fatal("Expected non-nil scorer")
	}
}
// TestNewScorerWithWeights ensures custom weights are stored unchanged.
func TestNewScorerWithWeights(t *testing.T) {
	custom := Weights{
		ContentLength:    0.5,
		HeadingStructure: 0.5,
	}
	s := NewScorerWithWeights(custom)
	if s.weights.ContentLength != 0.5 {
		t.Errorf("Expected weight 0.5, got %f", s.weights.ContentLength)
	}
}
// TestCalculate_HighQualityDocument scores a feature set with near-ideal
// values on every axis (optimal length, good structure, no ads, complete
// metadata, German, recent) and expects a total above 0.8.
func TestCalculate_HighQualityDocument(t *testing.T) {
	scorer := NewScorer()
	features := ContentFeatures{
		ContentLength:   5000,
		HeadingCount:    5,
		HeadingDepth:    3,
		LinkDensity:     0.1,
		AdDensity:       0,
		TextToHTMLRatio: 0.4,
		HasTitle:        true,
		HasDescription:  true,
		HasCanonical:    true,
		Language:        "de",
		DateIndicators:  []string{"2024-01-15"},
	}
	score := scorer.Calculate(features)
	if score.Total < 0.8 {
		t.Errorf("Expected high quality score (>0.8), got %f", score.Total)
	}
}
// TestCalculate_LowQualityDocument scores a feature set that is weak on
// every axis (tiny content, no headings, ad-heavy, markup-heavy, no
// metadata, unknown language) and expects a total below 0.5.
func TestCalculate_LowQualityDocument(t *testing.T) {
	scorer := NewScorer()
	features := ContentFeatures{
		ContentLength:   100,
		HeadingCount:    0,
		LinkDensity:     0.5,
		AdDensity:       0.2,
		TextToHTMLRatio: 0.05,
		HasTitle:        false,
		HasDescription:  false,
		Language:        "",
	}
	score := scorer.Calculate(features)
	if score.Total > 0.5 {
		t.Errorf("Expected low quality score (<0.5), got %f", score.Total)
	}
}
// TestCalculateContentLengthScore checks the banded length scoring: very
// short and very long documents score lower, while the 3k-10k character
// range is optimal. Each case pins an expected score interval.
func TestCalculateContentLengthScore(t *testing.T) {
	scorer := NewScorer()
	tests := []struct {
		length   int
		minScore float64
		maxScore float64
	}{
		{100, 0.0, 0.2},   // very short
		{500, 0.5, 0.7},   // short-medium
		{2000, 0.7, 0.9},  // good
		{5000, 0.9, 1.0},  // optimal
		{30000, 0.6, 0.8}, // very long
	}
	for _, tt := range tests {
		t.Run("", func(t *testing.T) {
			score := scorer.calculateContentLengthScore(tt.length)
			if score < tt.minScore || score > tt.maxScore {
				t.Errorf("Length %d: expected score in [%f, %f], got %f",
					tt.length, tt.minScore, tt.maxScore, score)
			}
		})
	}
}
// TestCalculateHeadingScore checks both extremes of the heading heuristic:
// no headings at all versus a deep, TOC-backed structure.
func TestCalculateHeadingScore(t *testing.T) {
	s := NewScorer()
	if got := s.calculateHeadingScore(0, 0, false); got > 0.1 {
		t.Errorf("Expected low score for no headings, got %f", got)
	}
	if got := s.calculateHeadingScore(5, 3, true); got < 0.9 {
		t.Errorf("Expected high score for good headings, got %f", got)
	}
}
// TestCalculateLinkQualityScore checks that clean pages keep a high link
// quality score while ad-heavy pages are penalized.
func TestCalculateLinkQualityScore(t *testing.T) {
	s := NewScorer()
	if got := s.calculateLinkQualityScore(0.1, 0); got < 0.9 {
		t.Errorf("Expected high score for good link quality, got %f", got)
	}
	if got := s.calculateLinkQualityScore(0.1, 0.2); got > 0.6 {
		t.Errorf("Expected low score for high ad density, got %f", got)
	}
}
// TestCalculateTextRatioScore spot-checks the text-to-HTML-ratio bands:
// markup-heavy (low ratio) and plain-text-dump (high ratio) pages score
// lower than the 0.2-0.6 sweet spot.
func TestCalculateTextRatioScore(t *testing.T) {
	scorer := NewScorer()
	tests := []struct {
		ratio    float64
		minScore float64
	}{
		{0.05, 0.0}, // too low
		{0.3, 0.9},  // optimal
		{0.9, 0.5},  // too high (plain text dump)
	}
	for _, tt := range tests {
		score := scorer.calculateTextRatioScore(tt.ratio)
		if score < tt.minScore {
			t.Errorf("Ratio %f: expected score >= %f, got %f", tt.ratio, tt.minScore, score)
		}
	}
}
// TestCalculateMetadataScore pins the exact weighting of title (0.5),
// description (0.3) and canonical link (0.2).
func TestCalculateMetadataScore(t *testing.T) {
	s := NewScorer()
	if got := s.calculateMetadataScore(true, true, true); got != 1.0 {
		t.Errorf("Expected 1.0 for all metadata, got %f", got)
	}
	if got := s.calculateMetadataScore(false, false, false); got != 0.0 {
		t.Errorf("Expected 0.0 for no metadata, got %f", got)
	}
	if got := s.calculateMetadataScore(true, false, false); got != 0.5 {
		t.Errorf("Expected 0.5 for only title, got %f", got)
	}
}
// TestCalculateLanguageScore pins the exact score for each language tier:
// German (1.0), English (0.8), unknown (0.5), and other languages (0.3).
func TestCalculateLanguageScore(t *testing.T) {
	scorer := NewScorer()
	tests := []struct {
		language string
		expected float64
	}{
		{"de", 1.0},
		{"german", 1.0},
		{"en", 0.8},
		{"", 0.5},
		{"fr", 0.3},
	}
	for _, tt := range tests {
		score := scorer.calculateLanguageScore(tt.language)
		if score != tt.expected {
			t.Errorf("Language '%s': expected %f, got %f", tt.language, tt.expected, score)
		}
	}
}
// TestCalculateFreshnessScore checks the three freshness tiers: a 2020s
// date scores high, a 2015-2019 date scores moderately, and the absence of
// any date indicator yields the neutral 0.5.
func TestCalculateFreshnessScore(t *testing.T) {
	scorer := NewScorer()
	// Recent date
	score := scorer.calculateFreshnessScore([]string{"2024-06-15"})
	if score < 0.9 {
		t.Errorf("Expected high score for recent date, got %f", score)
	}
	// Older date
	score = scorer.calculateFreshnessScore([]string{"2016-01-01"})
	if score > 0.8 {
		t.Errorf("Expected moderate score for 2016, got %f", score)
	}
	// No date indicators
	score = scorer.calculateFreshnessScore(nil)
	if score != 0.5 {
		t.Errorf("Expected neutral score for no dates, got %f", score)
	}
}
// TestCalculatePDFScore checks both ends of the PDF heuristic: a multi-page
// document with real extracted text versus a single near-empty page.
func TestCalculatePDFScore(t *testing.T) {
	s := NewScorer()
	if got := s.calculatePDFScore(10, 5000); got < 0.8 {
		t.Errorf("Expected high score for good PDF, got %f", got)
	}
	if got := s.calculatePDFScore(1, 50); got > 0.6 {
		t.Errorf("Expected lower score for poor PDF, got %f", got)
	}
}
// TestExtractDateIndicators feeds text containing German-style and ISO
// dates and expects at least one of them to be detected.
func TestExtractDateIndicators(t *testing.T) {
	text := "Lehrplan gültig ab 01.08.2023 - Stand: 2024-01-15. Aktualisiert 2024."
	indicators := ExtractDateIndicators(text)
	if len(indicators) == 0 {
		t.Error("Expected to find date indicators")
	}
	// Any of the year/date patterns present in the text will do.
	accepted := map[string]bool{
		"2024":       true,
		"2023":       true,
		"2024-01-15": true,
		"01.08.2023": true,
	}
	matched := false
	for _, ind := range indicators {
		if accepted[ind] {
			matched = true
			break
		}
	}
	if !matched {
		t.Errorf("Expected to find 2024 or 2023, got: %v", indicators)
	}
}
// TestExtractDateIndicators_Empty ensures date-free text yields nothing.
func TestExtractDateIndicators_Empty(t *testing.T) {
	indicators := ExtractDateIndicators("This text has no dates whatsoever.")
	if len(indicators) != 0 {
		t.Errorf("Expected no indicators, got: %v", indicators)
	}
}
// TestCalculate_PDFDocument scores a plausible German 8-page PDF and
// expects both the PDF-specific and total scores to be reasonable.
func TestCalculate_PDFDocument(t *testing.T) {
	sc := NewScorer()
	score := sc.Calculate(ContentFeatures{
		ContentLength:  3000,
		HeadingCount:   3,
		HeadingDepth:   2,
		Language:       "de",
		IsPDF:          true,
		PageCount:      8,
		DateIndicators: []string{"2023"},
	})
	if score.PDFSpecific < 0.8 {
		t.Errorf("Expected good PDF-specific score, got %f", score.PDFSpecific)
	}
	if score.Total < 0.5 {
		t.Errorf("Expected reasonable score for PDF, got %f", score.Total)
	}
}
// TestCalculate_ScoreClamping feeds near-perfect features and checks
// that the total score stays inside [0, 1].
func TestCalculate_ScoreClamping(t *testing.T) {
	sc := NewScorer()
	perfect := ContentFeatures{
		ContentLength:   5000,
		HeadingCount:    10,
		HeadingDepth:    4,
		HasTOC:          true,
		LinkDensity:     0,
		AdDensity:       0,
		TextToHTMLRatio: 0.4,
		HasTitle:        true,
		HasDescription:  true,
		HasCanonical:    true,
		Language:        "de",
		DateIndicators:  []string{"2024"},
	}
	score := sc.Calculate(perfect)
	if score.Total > 1.0 {
		t.Errorf("Score should be clamped to 1.0, got %f", score.Total)
	}
	if score.Total < 0 {
		t.Errorf("Score should not be negative, got %f", score.Total)
	}
}
// TestDefaultWeights confirms the weights form a weighted average,
// i.e. they sum to approximately 1.0.
func TestDefaultWeights(t *testing.T) {
	w := DefaultWeights()
	sum := w.ContentLength + w.HeadingStructure + w.LinkQuality +
		w.TextToHTMLRatio + w.MetadataPresence + w.LanguageClarity +
		w.ContentFreshness + w.PDFSpecific
	if sum < 0.99 || sum > 1.01 {
		t.Errorf("Default weights should sum to 1.0, got %f", sum)
	}
}

View File

@@ -0,0 +1,282 @@
package robots
import (
"bufio"
"context"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"sync"
"time"
)
// Checker handles robots.txt parsing and checking.
// It fetches robots.txt once per host, keeps the parsed rules in an
// in-memory cache for cacheTTL, and answers IsAllowed/GetCrawlDelay
// queries from that cache.
type Checker struct {
	mu        sync.RWMutex           // guards cache
	cache     map[string]*RobotsData // parsed robots.txt, keyed by host
	userAgent string                 // sent when fetching robots.txt and matched against User-agent groups
	client    *http.Client
	cacheTTL  time.Duration // lifetime of a cache entry (failed fetches are cached too)
}
// RobotsData holds parsed robots.txt data for a host.
type RobotsData struct {
	DisallowPatterns []string  // path patterns that block crawling
	AllowPatterns    []string  // exceptions that override DisallowPatterns
	CrawlDelay       int       // seconds between requests; 0 when unspecified
	FetchedAt        time.Time // when the fetch (or fetch attempt) happened
	Error            error     // non-nil when the fetch failed; callers then allow everything
}
// NewChecker creates a new robots.txt checker that identifies itself
// with the given user agent and caches per-host results for 24 hours.
func NewChecker(userAgent string) *Checker {
	httpClient := &http.Client{Timeout: 10 * time.Second}
	return &Checker{
		cache:     map[string]*RobotsData{},
		userAgent: userAgent,
		client:    httpClient,
		cacheTTL:  24 * time.Hour, // refresh robots.txt daily
	}
}
// IsAllowed checks if a URL is allowed to be crawled according to the
// host's robots.txt. Fetch failures are treated leniently (allowed);
// an error is returned only when the URL itself cannot be parsed.
func (c *Checker) IsAllowed(ctx context.Context, urlStr string) (bool, error) {
	parsed, err := url.Parse(urlStr)
	if err != nil {
		return false, fmt.Errorf("invalid URL: %w", err)
	}
	requestPath := parsed.Path
	if requestPath == "" {
		requestPath = "/"
	}

	robotsData, err := c.getRobotsData(ctx, parsed.Scheme, parsed.Host)
	if err != nil || robotsData.Error != nil {
		// robots.txt unavailable: err on the side of allowing.
		return true, nil
	}

	// Allow rules take precedence over disallow rules.
	for _, pattern := range robotsData.AllowPatterns {
		if matchPattern(pattern, requestPath) {
			return true, nil
		}
	}
	for _, pattern := range robotsData.DisallowPatterns {
		if matchPattern(pattern, requestPath) {
			return false, nil
		}
	}

	// No rule matched: crawling is permitted.
	return true, nil
}
// GetCrawlDelay returns the crawl delay (in seconds) declared by the
// host's robots.txt, or 0 when none is declared or the file could not
// be fetched.
func (c *Checker) GetCrawlDelay(ctx context.Context, urlStr string) (int, error) {
	parsed, err := url.Parse(urlStr)
	if err != nil {
		return 0, err
	}
	data, err := c.getRobotsData(ctx, parsed.Scheme, parsed.Host)
	if err != nil || data.Error != nil {
		return 0, nil
	}
	return data.CrawlDelay, nil
}
// getRobotsData fetches and caches robots.txt for a host.
// Entries (including failed fetches, recorded via RobotsData.Error)
// are reused until cacheTTL expires. The returned error is currently
// always nil; fetch failures surface through RobotsData.Error.
// NOTE(review): concurrent callers that miss the cache may fetch the
// same robots.txt in parallel; last write wins, which is harmless but
// wasteful — confirm whether single-flight is worth adding.
func (c *Checker) getRobotsData(ctx context.Context, scheme, host string) (*RobotsData, error) {
	c.mu.RLock()
	data, exists := c.cache[host]
	c.mu.RUnlock()
	// Return cached data if not expired
	if exists && time.Since(data.FetchedAt) < c.cacheTTL {
		return data, nil
	}
	// Fetch robots.txt
	robotsURL := fmt.Sprintf("%s://%s/robots.txt", scheme, host)
	data = c.fetchRobots(ctx, robotsURL)
	// Cache the result
	c.mu.Lock()
	c.cache[host] = data
	c.mu.Unlock()
	return data, nil
}
// fetchRobots downloads and parses a single robots.txt URL. Network
// and HTTP errors are recorded on the returned RobotsData rather than
// returned; a 404 yields an empty (allow-everything) rule set.
func (c *Checker) fetchRobots(ctx context.Context, robotsURL string) *RobotsData {
	result := &RobotsData{FetchedAt: time.Now()}

	req, err := http.NewRequestWithContext(ctx, "GET", robotsURL, nil)
	if err != nil {
		result.Error = err
		return result
	}
	req.Header.Set("User-Agent", c.userAgent)

	resp, err := c.client.Do(req)
	if err != nil {
		result.Error = err
		return result
	}
	defer resp.Body.Close()

	switch {
	case resp.StatusCode == http.StatusNotFound:
		// No robots.txt means everything is allowed.
		return result
	case resp.StatusCode != http.StatusOK:
		result.Error = fmt.Errorf("HTTP %d", resp.StatusCode)
		return result
	}

	c.parseRobotsTxt(result, resp.Body)
	return result
}
// parseRobotsTxt parses robots.txt content into data.
//
// Section tracking is deliberately simple: rules are collected while we
// are inside the wildcard ("*") group or a group whose User-agent
// matches ours (or names "breakpilot"/"edubot", which this crawler
// always honors). NOTE(review): entering a matching specific group
// does not clear inWildcardSection, so a preceding wildcard group's
// rules stay in effect alongside the specific group's — stricter than
// RFC 9309, which applies only the most specific matching group.
// Confirm this merging is intended before changing it.
func (c *Checker) parseRobotsTxt(data *RobotsData, reader io.Reader) {
	scanner := bufio.NewScanner(reader)
	// Track which user-agent section we're in
	inRelevantSection := false
	inWildcardSection := false
	// Normalize our user agent for matching
	ourAgent := strings.ToLower(c.userAgent)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		// Skip empty lines and comments
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		// Split on first colon
		parts := strings.SplitN(line, ":", 2)
		if len(parts) != 2 {
			continue
		}
		directive := strings.ToLower(strings.TrimSpace(parts[0]))
		value := strings.TrimSpace(parts[1])
		// Remove inline comments
		if idx := strings.Index(value, "#"); idx >= 0 {
			value = strings.TrimSpace(value[:idx])
		}
		switch directive {
		case "user-agent":
			agent := strings.ToLower(value)
			if agent == "*" {
				inWildcardSection = true
				inRelevantSection = false
			} else if strings.Contains(ourAgent, agent) || strings.Contains(agent, "breakpilot") || strings.Contains(agent, "edubot") {
				inRelevantSection = true
			} else {
				// Some other bot's section: ignore its rules entirely.
				inRelevantSection = false
				inWildcardSection = false
			}
		case "disallow":
			// An empty Disallow means "allow everything" and adds no rule.
			if value != "" && (inRelevantSection || inWildcardSection) {
				data.DisallowPatterns = append(data.DisallowPatterns, value)
			}
		case "allow":
			if value != "" && (inRelevantSection || inWildcardSection) {
				data.AllowPatterns = append(data.AllowPatterns, value)
			}
		case "crawl-delay":
			if inRelevantSection || inWildcardSection {
				var delay int
				// Non-numeric values leave delay at 0 and are ignored.
				fmt.Sscanf(value, "%d", &delay)
				if delay > 0 {
					data.CrawlDelay = delay
				}
			}
		}
	}
	// NOTE(review): scanner.Err() is not checked; a read error silently
	// truncates the rule set.
}
// matchPattern reports whether a URL path matches a robots.txt pattern.
// Patterns are prefix matches; '*' matches any run of characters and a
// trailing '$' anchors the pattern to the end of the path.
func matchPattern(pattern, path string) bool {
	switch {
	case pattern == "":
		// An empty pattern never matches anything.
		return false
	case strings.Contains(pattern, "*"):
		// Translate robots wildcard syntax into an anchored regexp.
		quoted := regexp.QuoteMeta(pattern)
		quoted = strings.ReplaceAll(quoted, `\*`, ".*")
		if strings.HasSuffix(quoted, `\$`) {
			quoted = strings.TrimSuffix(quoted, `\$`) + "$"
		}
		re, err := regexp.Compile("^" + quoted)
		if err != nil {
			return false
		}
		return re.MatchString(path)
	case strings.HasSuffix(pattern, "$"):
		// '$' without wildcards requires an exact match.
		return path == strings.TrimSuffix(pattern, "$")
	default:
		return strings.HasPrefix(path, pattern)
	}
}
// ClearCache drops all cached robots.txt entries.
func (c *Checker) ClearCache() {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.cache = make(map[string]*RobotsData)
}
// CacheStats reports how many robots.txt entries are cached and for
// which hosts. The hosts slice is nil when the cache is empty.
func (c *Checker) CacheStats() (int, []string) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	var hosts []string
	for host := range c.cache {
		hosts = append(hosts, host)
	}
	return len(c.cache), hosts
}

View File

@@ -0,0 +1,324 @@
package robots
import (
"context"
"net/http"
"net/http/httptest"
"testing"
)
// TestNewChecker ensures the constructor always returns a checker.
func TestNewChecker(t *testing.T) {
	if c := NewChecker("TestBot/1.0"); c == nil {
		t.Fatal("Expected non-nil checker")
	}
}
// TestIsAllowed_NoRobots: a host without robots.txt (404) allows all.
func TestIsAllowed_NoRobots(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusNotFound)
	}))
	defer server.Close()

	allowed, err := NewChecker("TestBot/1.0").IsAllowed(context.Background(), server.URL+"/some/page")
	if err != nil {
		t.Errorf("Unexpected error: %v", err)
	}
	if !allowed {
		t.Error("Should be allowed when robots.txt doesn't exist")
	}
}
// TestIsAllowed_AllowAll: a blanket Allow rule permits any path.
func TestIsAllowed_AllowAll(t *testing.T) {
	robotsTxt := `User-agent: *
Allow: /
`
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/robots.txt" {
			w.Write([]byte(robotsTxt))
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	if allowed, _ := NewChecker("TestBot/1.0").IsAllowed(context.Background(), server.URL+"/any/path"); !allowed {
		t.Error("Should be allowed with Allow: /")
	}
}
// TestIsAllowed_DisallowPath: disallowed prefixes block, others pass.
func TestIsAllowed_DisallowPath(t *testing.T) {
	robotsTxt := `User-agent: *
Disallow: /private/
Disallow: /admin/
`
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/robots.txt" {
			w.Write([]byte(robotsTxt))
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	checker := NewChecker("TestBot/1.0")
	ctx := context.Background()

	// Paths under the disallowed prefixes are blocked …
	if allowed, _ := checker.IsAllowed(ctx, server.URL+"/private/secret"); allowed {
		t.Error("/private/secret should be disallowed")
	}
	if allowed, _ := checker.IsAllowed(ctx, server.URL+"/admin/users"); allowed {
		t.Error("/admin/users should be disallowed")
	}
	// … everything else is not.
	if allowed, _ := checker.IsAllowed(ctx, server.URL+"/public/page"); !allowed {
		t.Error("/public/page should be allowed")
	}
}
// TestIsAllowed_AllowTakesPrecedence: an Allow rule overrides an
// overlapping Disallow rule.
func TestIsAllowed_AllowTakesPrecedence(t *testing.T) {
	robotsTxt := `User-agent: *
Disallow: /api/
Allow: /api/public/
`
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/robots.txt" {
			w.Write([]byte(robotsTxt))
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	checker := NewChecker("TestBot/1.0")
	ctx := context.Background()

	if allowed, _ := checker.IsAllowed(ctx, server.URL+"/api/public/docs"); !allowed {
		t.Error("/api/public/docs should be allowed (Allow takes precedence)")
	}
	// Outside the allowed sub-tree the Disallow still applies.
	if allowed, _ := checker.IsAllowed(ctx, server.URL+"/api/internal"); allowed {
		t.Error("/api/internal should be disallowed")
	}
}
// TestIsAllowed_SpecificUserAgent: a ban on another bot's section must
// not affect us when the wildcard section allows everything.
func TestIsAllowed_SpecificUserAgent(t *testing.T) {
	robotsTxt := `User-agent: BadBot
Disallow: /
User-agent: *
Allow: /
`
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/robots.txt" {
			w.Write([]byte(robotsTxt))
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	if allowed, _ := NewChecker("GoodBot/1.0").IsAllowed(context.Background(), server.URL+"/page"); !allowed {
		t.Error("GoodBot should be allowed")
	}
}
// TestGetCrawlDelay verifies the Crawl-delay directive is parsed.
func TestGetCrawlDelay(t *testing.T) {
	robotsTxt := `User-agent: *
Crawl-delay: 5
`
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/robots.txt" {
			w.Write([]byte(robotsTxt))
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	delay, err := NewChecker("TestBot/1.0").GetCrawlDelay(context.Background(), server.URL+"/page")
	if err != nil {
		t.Errorf("Unexpected error: %v", err)
	}
	if delay != 5 {
		t.Errorf("Expected delay 5, got %d", delay)
	}
}
// TestMatchPattern_Simple: plain patterns are prefix matches and the
// empty pattern never matches.
func TestMatchPattern_Simple(t *testing.T) {
	cases := []struct {
		pattern string
		path    string
		match   bool
	}{
		{"/private/", "/private/secret", true},
		{"/private/", "/public/", false},
		{"/", "/anything", true},
		{"", "/anything", false},
	}
	for _, tc := range cases {
		if got := matchPattern(tc.pattern, tc.path); got != tc.match {
			t.Errorf("Pattern '%s' vs Path '%s': expected %v, got %v",
				tc.pattern, tc.path, tc.match, got)
		}
	}
}
// TestMatchPattern_Wildcard: '*' matches any run of characters.
func TestMatchPattern_Wildcard(t *testing.T) {
	cases := []struct {
		pattern string
		path    string
		match   bool
	}{
		{"/*.pdf", "/document.pdf", true},
		{"/*.pdf", "/folder/doc.pdf", true},
		{"/*.pdf", "/document.html", false},
		{"/dir/*/page", "/dir/sub/page", true},
		{"/dir/*/page", "/dir/other/page", true},
	}
	for _, tc := range cases {
		if got := matchPattern(tc.pattern, tc.path); got != tc.match {
			t.Errorf("Pattern '%s' vs Path '%s': expected %v, got %v",
				tc.pattern, tc.path, tc.match, got)
		}
	}
}
// TestMatchPattern_EndAnchor: a trailing '$' requires an exact match.
func TestMatchPattern_EndAnchor(t *testing.T) {
	cases := []struct {
		pattern string
		path    string
		match   bool
	}{
		{"/exact$", "/exact", true},
		{"/exact$", "/exactmore", false},
		{"/exact$", "/exact/more", false},
	}
	for _, tc := range cases {
		if got := matchPattern(tc.pattern, tc.path); got != tc.match {
			t.Errorf("Pattern '%s' vs Path '%s': expected %v, got %v",
				tc.pattern, tc.path, tc.match, got)
		}
	}
}
// TestCacheStats: the cache starts empty and records one entry per host.
func TestCacheStats(t *testing.T) {
	robotsTxt := `User-agent: *
Allow: /
`
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte(robotsTxt))
	}))
	defer server.Close()

	checker := NewChecker("TestBot/1.0")

	if count, _ := checker.CacheStats(); count != 0 {
		t.Errorf("Expected 0 cached entries, got %d", count)
	}

	// A lookup populates the cache for the server's host.
	checker.IsAllowed(context.Background(), server.URL+"/page")
	count, hosts := checker.CacheStats()
	if count != 1 {
		t.Errorf("Expected 1 cached entry, got %d", count)
	}
	if len(hosts) != 1 {
		t.Errorf("Expected 1 host, got %v", hosts)
	}
}
// TestClearCache: clearing leaves the cache empty again.
func TestClearCache(t *testing.T) {
	robotsTxt := `User-agent: *
Allow: /
`
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte(robotsTxt))
	}))
	defer server.Close()

	checker := NewChecker("TestBot/1.0")
	checker.IsAllowed(context.Background(), server.URL+"/page")
	if count, _ := checker.CacheStats(); count != 1 {
		t.Errorf("Expected 1 cached entry, got %d", count)
	}

	checker.ClearCache()
	if count, _ := checker.CacheStats(); count != 0 {
		t.Errorf("Expected 0 cached entries after clear, got %d", count)
	}
}
// TestParseRobotsTxt_Comments: full-line and inline comments must be
// stripped before directives are interpreted.
func TestParseRobotsTxt_Comments(t *testing.T) {
	robotsTxt := `# This is a comment
User-agent: *
# Another comment
Disallow: /private/ # inline comment
Allow: /public/
`
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/robots.txt" {
			w.Write([]byte(robotsTxt))
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	checker := NewChecker("TestBot/1.0")
	ctx := context.Background()
	if allowed, _ := checker.IsAllowed(ctx, server.URL+"/public/page"); !allowed {
		t.Error("/public/page should be allowed")
	}
	if allowed, _ := checker.IsAllowed(ctx, server.URL+"/private/page"); allowed {
		t.Error("/private/page should be disallowed")
	}
}
// TestIsAllowed_InvalidURL: unparsable URLs must surface an error.
func TestIsAllowed_InvalidURL(t *testing.T) {
	if _, err := NewChecker("TestBot/1.0").IsAllowed(context.Background(), "not a valid url ://"); err == nil {
		t.Error("Expected error for invalid URL")
	}
}

View File

@@ -0,0 +1,222 @@
package scheduler
import (
"context"
"log"
"sync"
"time"
)
// CrawlFunc is the function signature for executing a crawl.
// Implementations should honor ctx cancellation (the scheduler caps a
// crawl at 4 hours).
type CrawlFunc func(ctx context.Context) error

// Status represents the current scheduler status as reported by
// Scheduler.Status.
type Status struct {
	Enabled       bool      `json:"enabled"`
	Running       bool      `json:"running"`
	LastRun       time.Time `json:"last_run,omitempty"` // NOTE(review): omitempty has no effect on struct types like time.Time; the zero time is always serialized
	LastRunStatus string    `json:"last_run_status,omitempty"`
	NextRun       time.Time `json:"next_run,omitempty"` // same omitempty caveat as LastRun
	Interval      string    `json:"interval"`
}
// Scheduler handles automatic crawl scheduling.
// A single background goroutine (started by Start) fires crawlFunc on
// a timer; TriggerCrawl allows ad-hoc runs regardless of enabled.
type Scheduler struct {
	mu            sync.RWMutex  // guards running, lastRun, lastRunStatus
	enabled       bool          // when false, Start is a no-op
	interval      time.Duration // time between scheduled crawls
	crawlFunc     CrawlFunc     // the work executed on each run
	running       bool          // true while any crawl (scheduled or manual) is in flight
	lastRun       time.Time     // start time of the most recent crawl
	lastRunStatus string        // "success" or "failed: <error>"
	stopChan      chan struct{} // closed by Stop to end the run loop
	doneChan      chan struct{} // closed by run when the loop exits
}
// Config holds scheduler configuration.
type Config struct {
	Enabled  bool          // when false, Start does nothing (TriggerCrawl still works)
	Interval time.Duration // time between automatic crawls
}
// NewScheduler creates a new crawler scheduler that will invoke
// crawlFunc according to cfg once Start is called.
func NewScheduler(cfg Config, crawlFunc CrawlFunc) *Scheduler {
	s := &Scheduler{
		enabled:   cfg.Enabled,
		interval:  cfg.Interval,
		crawlFunc: crawlFunc,
		stopChan:  make(chan struct{}),
		doneChan:  make(chan struct{}),
	}
	return s
}
// Start begins the scheduler loop in a background goroutine.
// It is a no-op when the scheduler is disabled.
func (s *Scheduler) Start() {
	if !s.enabled {
		log.Println("Scheduler is disabled")
		return
	}
	log.Printf("Scheduler starting with interval: %v", s.interval)
	go s.run()
}
// Stop gracefully stops the scheduler: it signals the run loop and
// blocks until the loop has exited. It is a no-op when disabled.
// NOTE(review): calling Stop twice panics (double close of stopChan),
// and Stop on an enabled scheduler whose Start was never called blocks
// forever on doneChan — confirm callers stop exactly once after Start.
func (s *Scheduler) Stop() {
	s.mu.Lock()
	if !s.enabled {
		s.mu.Unlock()
		return
	}
	s.mu.Unlock()
	close(s.stopChan)
	<-s.doneChan
	log.Println("Scheduler stopped")
}
// run is the main scheduler loop. The first run is scheduled by
// calculateNextRun (02:00 for daily-or-longer intervals); afterwards
// crawls repeat every interval until stopChan is closed.
func (s *Scheduler) run() {
	defer close(s.doneChan)
	// Calculate time until first run
	// Default: run at 2:00 AM to minimize impact
	now := time.Now()
	nextRun := s.calculateNextRun(now)
	log.Printf("Scheduler: first crawl scheduled for %v", nextRun)
	timer := time.NewTimer(time.Until(nextRun))
	defer timer.Stop()
	for {
		select {
		case <-s.stopChan:
			return
		case <-timer.C:
			s.executeCrawl()
			// Schedule the next run one full interval from now.
			// (Previously nextRun was assigned here but never used —
			// a dead store; we now log it so operators can see when
			// the next crawl is due.)
			nextRun = time.Now().Add(s.interval)
			timer.Reset(s.interval)
			log.Printf("Scheduler: next crawl scheduled for %v", nextRun)
		}
	}
}
// calculateNextRun determines when the next crawl should occur.
// Daily (or longer) intervals are aligned to 02:00 local time to
// minimize impact; shorter intervals start one minute from now.
func (s *Scheduler) calculateNextRun(from time.Time) time.Time {
	if s.interval < 24*time.Hour {
		// Short intervals: begin almost immediately.
		return from.Add(1 * time.Minute)
	}
	twoAM := time.Date(from.Year(), from.Month(), from.Day(), 2, 0, 0, 0, from.Location())
	if !twoAM.After(from) {
		// 02:00 already passed (or is exactly now): take tomorrow's.
		twoAM = twoAM.Add(24 * time.Hour)
	}
	return twoAM
}
// executeCrawl runs the crawl function once, guarding against
// overlapping runs and recording the outcome for Status().
func (s *Scheduler) executeCrawl() {
	s.mu.Lock()
	if s.running {
		s.mu.Unlock()
		log.Println("Scheduler: crawl already running, skipping")
		return
	}
	s.running = true
	s.mu.Unlock()

	log.Println("Scheduler: starting scheduled crawl")
	startTime := time.Now()

	// Cap a single crawl at four hours.
	ctx, cancel := context.WithTimeout(context.Background(), 4*time.Hour)
	defer cancel()
	err := s.crawlFunc(ctx)

	s.mu.Lock()
	defer s.mu.Unlock()
	s.running = false
	s.lastRun = startTime
	if err != nil {
		s.lastRunStatus = "failed: " + err.Error()
		log.Printf("Scheduler: crawl failed after %v: %v", time.Since(startTime), err)
	} else {
		s.lastRunStatus = "success"
		log.Printf("Scheduler: crawl completed successfully in %v", time.Since(startTime))
	}
}
// TriggerCrawl manually triggers a crawl in the background, even when
// the scheduler itself is disabled. It returns ErrCrawlAlreadyRunning
// if a crawl is already in flight.
func (s *Scheduler) TriggerCrawl() error {
	s.mu.Lock()
	if s.running {
		s.mu.Unlock()
		return ErrCrawlAlreadyRunning
	}
	s.running = true
	s.mu.Unlock()

	log.Println("Scheduler: manual crawl triggered")
	go func() {
		startTime := time.Now()
		// Same four-hour cap as scheduled crawls.
		ctx, cancel := context.WithTimeout(context.Background(), 4*time.Hour)
		defer cancel()
		err := s.crawlFunc(ctx)

		s.mu.Lock()
		defer s.mu.Unlock()
		s.running = false
		s.lastRun = startTime
		if err != nil {
			s.lastRunStatus = "failed: " + err.Error()
			log.Printf("Scheduler: manual crawl failed after %v: %v", time.Since(startTime), err)
		} else {
			s.lastRunStatus = "success"
			log.Printf("Scheduler: manual crawl completed successfully in %v", time.Since(startTime))
		}
	}()
	return nil
}
// Status returns a snapshot of the current scheduler state.
func (s *Scheduler) Status() Status {
	s.mu.RLock()
	defer s.mu.RUnlock()

	snapshot := Status{
		Enabled:       s.enabled,
		Running:       s.running,
		LastRun:       s.lastRun,
		LastRunStatus: s.lastRunStatus,
		Interval:      s.interval.String(),
	}
	// NextRun is only derivable once at least one crawl has happened.
	if s.enabled && !s.lastRun.IsZero() {
		snapshot.NextRun = s.lastRun.Add(s.interval)
	}
	return snapshot
}
// IsRunning returns true if a crawl is currently in progress.
func (s *Scheduler) IsRunning() bool {
	s.mu.RLock()
	running := s.running
	s.mu.RUnlock()
	return running
}
// SchedulerError is a string-based error type for the scheduler's
// sentinel errors; it satisfies the error interface.
type SchedulerError string

// Error returns the error message.
func (e SchedulerError) Error() string { return string(e) }

// Sentinel errors.
const (
	// ErrCrawlAlreadyRunning is returned by TriggerCrawl while a crawl is in flight.
	ErrCrawlAlreadyRunning = SchedulerError("crawl already running")
)

View File

@@ -0,0 +1,294 @@
package scheduler
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
)
func TestNewScheduler(t *testing.T) {
callCount := int32(0)
crawlFunc := func(ctx context.Context) error {
atomic.AddInt32(&callCount, 1)
return nil
}
cfg := Config{
Enabled: true,
Interval: 1 * time.Hour,
}
scheduler := NewScheduler(cfg, crawlFunc)
if scheduler == nil {
t.Fatal("Expected non-nil scheduler")
}
if !scheduler.enabled {
t.Error("Expected scheduler to be enabled")
}
if scheduler.interval != 1*time.Hour {
t.Errorf("Expected interval 1h, got %v", scheduler.interval)
}
}
func TestScheduler_Disabled(t *testing.T) {
callCount := int32(0)
crawlFunc := func(ctx context.Context) error {
atomic.AddInt32(&callCount, 1)
return nil
}
cfg := Config{
Enabled: false,
Interval: 1 * time.Second,
}
scheduler := NewScheduler(cfg, crawlFunc)
scheduler.Start()
// Wait a bit - crawl should not run
time.Sleep(100 * time.Millisecond)
if atomic.LoadInt32(&callCount) != 0 {
t.Error("Crawl should not run when scheduler is disabled")
}
}
func TestScheduler_TriggerCrawl(t *testing.T) {
callCount := int32(0)
crawlFunc := func(ctx context.Context) error {
atomic.AddInt32(&callCount, 1)
time.Sleep(50 * time.Millisecond) // Simulate work
return nil
}
cfg := Config{
Enabled: false, // Disabled scheduler, but manual trigger should work
Interval: 24 * time.Hour,
}
scheduler := NewScheduler(cfg, crawlFunc)
// Trigger manual crawl
err := scheduler.TriggerCrawl()
if err != nil {
t.Fatalf("TriggerCrawl failed: %v", err)
}
// Wait for crawl to complete
time.Sleep(100 * time.Millisecond)
if atomic.LoadInt32(&callCount) != 1 {
t.Errorf("Expected 1 crawl, got %d", atomic.LoadInt32(&callCount))
}
}
func TestScheduler_TriggerCrawl_AlreadyRunning(t *testing.T) {
crawlFunc := func(ctx context.Context) error {
time.Sleep(200 * time.Millisecond)
return nil
}
cfg := Config{
Enabled: false,
Interval: 24 * time.Hour,
}
scheduler := NewScheduler(cfg, crawlFunc)
// First trigger
err := scheduler.TriggerCrawl()
if err != nil {
t.Fatalf("First TriggerCrawl failed: %v", err)
}
// Wait a bit for crawl to start
time.Sleep(10 * time.Millisecond)
// Second trigger should fail
err = scheduler.TriggerCrawl()
if err != ErrCrawlAlreadyRunning {
t.Errorf("Expected ErrCrawlAlreadyRunning, got %v", err)
}
// Wait for crawl to complete
time.Sleep(250 * time.Millisecond)
// Now trigger should work again
err = scheduler.TriggerCrawl()
if err != nil {
t.Errorf("Third TriggerCrawl should succeed: %v", err)
}
}
func TestScheduler_Status(t *testing.T) {
crawlFunc := func(ctx context.Context) error {
return nil
}
cfg := Config{
Enabled: true,
Interval: 24 * time.Hour,
}
scheduler := NewScheduler(cfg, crawlFunc)
status := scheduler.Status()
if !status.Enabled {
t.Error("Expected enabled=true")
}
if status.Running {
t.Error("Expected running=false initially")
}
if status.Interval != "24h0m0s" {
t.Errorf("Expected interval '24h0m0s', got '%s'", status.Interval)
}
}
func TestScheduler_Status_AfterCrawl(t *testing.T) {
crawlFunc := func(ctx context.Context) error {
return nil
}
cfg := Config{
Enabled: false,
Interval: 24 * time.Hour,
}
scheduler := NewScheduler(cfg, crawlFunc)
// Trigger and wait
scheduler.TriggerCrawl()
time.Sleep(50 * time.Millisecond)
status := scheduler.Status()
if status.LastRun.IsZero() {
t.Error("Expected LastRun to be set")
}
if status.LastRunStatus != "success" {
t.Errorf("Expected status 'success', got '%s'", status.LastRunStatus)
}
}
func TestScheduler_Status_FailedCrawl(t *testing.T) {
crawlFunc := func(ctx context.Context) error {
return errors.New("connection failed")
}
cfg := Config{
Enabled: false,
Interval: 24 * time.Hour,
}
scheduler := NewScheduler(cfg, crawlFunc)
// Trigger and wait
scheduler.TriggerCrawl()
time.Sleep(50 * time.Millisecond)
status := scheduler.Status()
if status.LastRunStatus != "failed: connection failed" {
t.Errorf("Expected failed status, got '%s'", status.LastRunStatus)
}
}
func TestScheduler_IsRunning(t *testing.T) {
crawlFunc := func(ctx context.Context) error {
time.Sleep(100 * time.Millisecond)
return nil
}
cfg := Config{
Enabled: false,
Interval: 24 * time.Hour,
}
scheduler := NewScheduler(cfg, crawlFunc)
if scheduler.IsRunning() {
t.Error("Should not be running initially")
}
scheduler.TriggerCrawl()
time.Sleep(10 * time.Millisecond)
if !scheduler.IsRunning() {
t.Error("Should be running after trigger")
}
time.Sleep(150 * time.Millisecond)
if scheduler.IsRunning() {
t.Error("Should not be running after completion")
}
}
func TestScheduler_CalculateNextRun_Daily(t *testing.T) {
crawlFunc := func(ctx context.Context) error { return nil }
cfg := Config{
Enabled: true,
Interval: 24 * time.Hour,
}
scheduler := NewScheduler(cfg, crawlFunc)
// Test at 1 AM - should schedule for 2 AM same day
from := time.Date(2024, 1, 15, 1, 0, 0, 0, time.UTC)
next := scheduler.calculateNextRun(from)
expectedHour := 2
if next.Hour() != expectedHour {
t.Errorf("Expected hour %d, got %d", expectedHour, next.Hour())
}
if next.Day() != 15 {
t.Errorf("Expected day 15, got %d", next.Day())
}
// Test at 3 AM - should schedule for 2 AM next day
from = time.Date(2024, 1, 15, 3, 0, 0, 0, time.UTC)
next = scheduler.calculateNextRun(from)
if next.Day() != 16 {
t.Errorf("Expected day 16, got %d", next.Day())
}
}
func TestScheduler_CalculateNextRun_Hourly(t *testing.T) {
crawlFunc := func(ctx context.Context) error { return nil }
cfg := Config{
Enabled: true,
Interval: 1 * time.Hour, // Less than 24h
}
scheduler := NewScheduler(cfg, crawlFunc)
from := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
next := scheduler.calculateNextRun(from)
// Should start in about 1 minute
diff := next.Sub(from)
if diff < 30*time.Second || diff > 90*time.Second {
t.Errorf("Expected ~1 minute delay for short intervals, got %v", diff)
}
}
// TestSchedulerError: the sentinel renders its exact message.
func TestSchedulerError(t *testing.T) {
	var err error = ErrCrawlAlreadyRunning
	if err.Error() != "crawl already running" {
		t.Errorf("Unexpected error message: %s", err.Error())
	}
}

View File

@@ -0,0 +1,592 @@
package search
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/opensearch-project/opensearch-go/v2"
"github.com/opensearch-project/opensearch-go/v2/opensearchapi"
)
// SearchRequest represents an API search request.
type SearchRequest struct {
	Query   string        `json:"q"`
	Mode    string        `json:"mode"` // keyword, semantic, hybrid; empty defaults to keyword
	Limit   int           `json:"limit"`
	Offset  int           `json:"offset"`
	Filters SearchFilters `json:"filters"`
	Rerank  bool          `json:"rerank"`
	Include SearchInclude `json:"include"`
}

// SearchFilters for narrowing results; empty/zero fields disable the
// corresponding filter clause.
type SearchFilters struct {
	Language       []string `json:"language"`
	CountryHint    []string `json:"country_hint"`
	SourceCategory []string `json:"source_category"`
	DocType        []string `json:"doc_type"`
	SchoolLevel    []string `json:"school_level"`
	Subjects       []string `json:"subjects"`
	State          []string `json:"state"`
	MinTrustScore  float64  `json:"min_trust_score"` // only applied when > 0
	DateFrom       string   `json:"date_from"`       // lower bound on fetch_time; empty disables
}

// SearchInclude specifies what to include in response.
type SearchInclude struct {
	Snippets    bool `json:"snippets"`
	Highlights  bool `json:"highlights"`
	ContentText bool `json:"content_text"`
}
// SearchResult represents a single search result, populated from an
// OpenSearch hit's _source document.
type SearchResult struct {
	DocID       string   `json:"doc_id"`
	Title       string   `json:"title"`
	URL         string   `json:"url"`
	Domain      string   `json:"domain"`
	Language    string   `json:"language"`
	DocType     string   `json:"doc_type"`
	SchoolLevel string   `json:"school_level"`
	Subjects    []string `json:"subjects"`
	Scores      Scores   `json:"scores"`
	Snippet     string   `json:"snippet,omitempty"`
	Highlights  []string `json:"highlights,omitempty"`
}

// Scores contains all scoring components attached to a result.
type Scores struct {
	BM25     float64 `json:"bm25"`
	Semantic float64 `json:"semantic"`
	Rerank   float64 `json:"rerank"`
	Trust    float64 `json:"trust"`
	Quality  float64 `json:"quality"`
	Final    float64 `json:"final"`
}

// SearchResponse is the API response for one search call.
type SearchResponse struct {
	QueryID    string         `json:"query_id"`
	Results    []SearchResult `json:"results"`
	Pagination Pagination     `json:"pagination"`
}

// Pagination info echoed back with an estimated total hit count.
type Pagination struct {
	Limit         int `json:"limit"`
	Offset        int `json:"offset"`
	TotalEstimate int `json:"total_estimate"` // OpenSearch hits.total.value
}
// EmbeddingProvider interface for generating embeddings used by
// semantic and hybrid search.
type EmbeddingProvider interface {
	// Embed returns an embedding vector for the given text.
	Embed(ctx context.Context, text string) ([]float32, error)
	// IsEnabled reports whether the provider is ready for use.
	IsEnabled() bool
	// Dimension returns the embedding dimensionality (presumably the
	// length of vectors returned by Embed — confirm with implementations).
	Dimension() int
}
// Service handles search operations against a single OpenSearch index,
// optionally augmented by an embedding provider for semantic search.
type Service struct {
	client            *opensearch.Client
	indexName         string            // index all queries run against
	embeddingProvider EmbeddingProvider // nil until SetEmbeddingProvider succeeds
	semanticEnabled   bool              // true only when a working provider is set
}
// NewService creates a new search service backed by the given
// OpenSearch endpoint and index. Semantic search starts disabled until
// SetEmbeddingProvider is called with a working provider.
func NewService(url, username, password, indexName string) (*Service, error) {
	client, err := opensearch.NewClient(opensearch.Config{
		Addresses: []string{url},
		Username:  username,
		Password:  password,
	})
	if err != nil {
		return nil, err
	}
	svc := &Service{
		client:          client,
		indexName:       indexName,
		semanticEnabled: false,
	}
	return svc, nil
}
// SetEmbeddingProvider configures the embedding provider for semantic
// search; a nil or disabled provider leaves semantic search off.
func (s *Service) SetEmbeddingProvider(provider EmbeddingProvider) {
	if provider == nil || !provider.IsEnabled() {
		return
	}
	s.embeddingProvider = provider
	s.semanticEnabled = true
}
// IsSemanticEnabled returns true if semantic search is available.
func (s *Service) IsSemanticEnabled() bool {
	return s.embeddingProvider != nil && s.semanticEnabled
}
// Search performs a search query against the OpenSearch index.
//
// The requested mode ("keyword", "semantic" or "hybrid"; empty defaults
// to "keyword") is silently degraded to keyword search when semantic
// search is not configured or the query embedding cannot be generated.
func (s *Service) Search(ctx context.Context, req *SearchRequest) (*SearchResponse, error) {
	// Determine the effective search mode.
	mode := req.Mode
	if mode == "" {
		mode = "keyword" // Default to keyword search
	}
	// For semantic/hybrid modes, generate the query embedding up front.
	var queryEmbedding []float32
	if mode == "semantic" || mode == "hybrid" {
		if s.IsSemanticEnabled() {
			var embErr error
			queryEmbedding, embErr = s.embeddingProvider.Embed(ctx, req.Query)
			if embErr != nil {
				// Fall back to keyword search if embedding fails.
				mode = "keyword"
			}
		} else {
			// Semantic requested but not enabled, fall back.
			mode = "keyword"
		}
	}
	// Build the OpenSearch query for the effective mode.
	var query map[string]interface{}
	switch mode {
	case "semantic":
		query = s.buildSemanticQuery(req, queryEmbedding)
	case "hybrid":
		query = s.buildHybridQuery(req, queryEmbedding)
	default:
		query = s.buildQuery(req)
	}
	queryJSON, err := json.Marshal(query)
	if err != nil {
		return nil, err
	}
	searchReq := opensearchapi.SearchRequest{
		Index: []string{s.indexName},
		Body:  strings.NewReader(string(queryJSON)),
	}
	res, err := searchReq.Do(ctx, s.client)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()
	// Decode only the response fields we use.
	var osResponse struct {
		Hits struct {
			Total struct {
				Value int `json:"value"`
			} `json:"total"`
			Hits []struct {
				ID        string                 `json:"_id"`
				Score     float64                `json:"_score"`
				Source    map[string]interface{} `json:"_source"`
				Highlight map[string][]string    `json:"highlight,omitempty"`
			} `json:"hits"`
		} `json:"hits"`
	}
	if err := json.NewDecoder(res.Body).Decode(&osResponse); err != nil {
		return nil, err
	}
	// Convert hits to SearchResults.
	results := make([]SearchResult, 0, len(osResponse.Hits.Hits))
	for _, hit := range osResponse.Hits.Hits {
		results = append(results, s.hitToResult(hit.Source, hit.Score, hit.Highlight, req.Include))
	}
	return &SearchResponse{
		// BUG FIX: ctx.Value returns an interface value (often nil here),
		// so the previous "%d" verb produced garbage like "q-%!d(<nil>)".
		// "%v" renders any value (including nil) sanely.
		// NOTE(review): a plain string context key triggers a vet warning
		// (SA1029) at the writer's side; reading it here is harmless.
		QueryID: fmt.Sprintf("q-%v", ctx.Value("request_id")),
		Results: results,
		Pagination: Pagination{
			Limit:         req.Limit,
			Offset:        req.Offset,
			TotalEstimate: osResponse.Hits.Total.Value,
		},
	}, nil
}
// buildQuery constructs the OpenSearch keyword (BM25) query body.
//
// The bool query combines a multi_match over title (boosted 3x) and
// content_text with the structured filters, then wraps everything in a
// function_score that multiplies relevance by sqrt(trust_score) and
// sqrt(quality_score) so trustworthy, high-quality documents rank higher.
func (s *Service) buildQuery(req *SearchRequest) map[string]interface{} {
	// Full-text clause (omitted for empty queries, which then match all docs).
	must := []map[string]interface{}{}
	if req.Query != "" {
		must = append(must, map[string]interface{}{
			"multi_match": map[string]interface{}{
				"query":  req.Query,
				"fields": []string{"title^3", "content_text"},
				"type":   "best_fields",
			},
		})
	}

	// Reuse the shared filter builder instead of duplicating every filter
	// clause here — keeps keyword, semantic and hybrid paths consistent.
	filter := s.buildFilters(req)

	// Build bool query
	boolQuery := map[string]interface{}{}
	if len(must) > 0 {
		boolQuery["must"] = must
	}
	if len(filter) > 0 {
		boolQuery["filter"] = filter
	}

	// Construct full query with pagination and a trimmed _source list.
	query := map[string]interface{}{
		"query": map[string]interface{}{
			"bool": boolQuery,
		},
		"from": req.Offset,
		"size": req.Limit,
		"_source": []string{
			"doc_id", "title", "url", "domain", "language",
			"doc_type", "school_level", "subjects",
			"trust_score", "quality_score", "snippet_text",
		},
	}

	// Add highlighting if requested
	if req.Include.Highlights {
		query["highlight"] = map[string]interface{}{
			"fields": map[string]interface{}{
				"title":        map[string]interface{}{},
				"content_text": map[string]interface{}{"fragment_size": 150, "number_of_fragments": 3},
			},
		}
	}

	// Wrap in function_score for trust/quality boosting:
	// final = BM25 * sqrt(trust_score) * sqrt(quality_score).
	query["query"] = map[string]interface{}{
		"function_score": map[string]interface{}{
			"query": query["query"],
			"functions": []map[string]interface{}{
				{
					"field_value_factor": map[string]interface{}{
						"field":    "trust_score",
						"factor":   1.5,
						"modifier": "sqrt",
						"missing":  0.5, // neutral-ish default for unscored docs
					},
				},
				{
					"field_value_factor": map[string]interface{}{
						"field":    "quality_score",
						"factor":   1.0,
						"modifier": "sqrt",
						"missing":  0.5,
					},
				},
			},
			"score_mode": "multiply",
			"boost_mode": "multiply",
		},
	}
	return query
}
// buildSemanticQuery constructs a pure vector search query using k-NN.
//
// BUG FIX: the knn clause must be nested under "query" — OpenSearch rejects a
// top-level "knn" key in the search body (that placement is Elasticsearch
// syntax). The previous version produced a request OpenSearch cannot parse.
func (s *Service) buildSemanticQuery(req *SearchRequest, embedding []float32) map[string]interface{} {
	filter := s.buildFilters(req)

	// Per-field k-NN parameters; k must cover offset+limit so pagination
	// still has enough candidates to page through.
	knnField := map[string]interface{}{
		"vector": embedding,
		"k":      req.Limit + req.Offset,
	}
	// Attach structured filters inside the k-NN clause (efficient
	// pre-filtering) when any are present.
	if len(filter) > 0 {
		knnField["filter"] = map[string]interface{}{
			"bool": map[string]interface{}{
				"filter": filter,
			},
		}
	}

	query := map[string]interface{}{
		"query": map[string]interface{}{
			"knn": map[string]interface{}{
				"content_embedding": knnField,
			},
		},
		"from": req.Offset,
		"size": req.Limit,
		"_source": []string{
			"doc_id", "title", "url", "domain", "language",
			"doc_type", "school_level", "subjects",
			"trust_score", "quality_score", "snippet_text",
		},
	}
	// Add highlighting if requested
	if req.Include.Highlights {
		query["highlight"] = map[string]interface{}{
			"fields": map[string]interface{}{
				"title":        map[string]interface{}{},
				"content_text": map[string]interface{}{"fragment_size": 150, "number_of_fragments": 3},
			},
		}
	}
	return query
}
// buildHybridQuery constructs a combined BM25 + vector search query.
//
// It wraps the keyword bool query in a script_score whose script adds cosine
// similarity against content_embedding to a scaled BM25 _score. This is a
// simpler approach than OpenSearch's neural search plugin.
func (s *Service) buildHybridQuery(req *SearchRequest, embedding []float32) map[string]interface{} {
	filter := s.buildFilters(req)

	// Build the bool query for the BM25 part.
	must := []map[string]interface{}{}
	if req.Query != "" {
		must = append(must, map[string]interface{}{
			"multi_match": map[string]interface{}{
				"query":  req.Query,
				"fields": []string{"title^3", "content_text"},
				"type":   "best_fields",
			},
		})
	}
	boolQuery := map[string]interface{}{}
	if len(must) > 0 {
		boolQuery["must"] = must
	}
	if len(filter) > 0 {
		boolQuery["filter"] = filter
	}

	// Note: encoding/json marshals []float32 directly as a JSON number array,
	// so no element-by-element conversion to []interface{} is needed.
	query := map[string]interface{}{
		"query": map[string]interface{}{
			"script_score": map[string]interface{}{
				"query": map[string]interface{}{
					"bool": boolQuery,
				},
				"script": map[string]interface{}{
					// +1.0 keeps the cosine term non-negative; _score is the
					// BM25 score of the wrapped query, down-weighted by 0.5.
					"source": "cosineSimilarity(params.query_vector, 'content_embedding') + 1.0 + _score * 0.5",
					"params": map[string]interface{}{
						"query_vector": embedding,
					},
				},
			},
		},
		"from": req.Offset,
		"size": req.Limit,
		"_source": []string{
			"doc_id", "title", "url", "domain", "language",
			"doc_type", "school_level", "subjects",
			"trust_score", "quality_score", "snippet_text",
		},
	}
	// Add highlighting if requested
	if req.Include.Highlights {
		query["highlight"] = map[string]interface{}{
			"fields": map[string]interface{}{
				"title":        map[string]interface{}{},
				"content_text": map[string]interface{}{"fragment_size": 150, "number_of_fragments": 3},
			},
		}
	}
	return query
}
// buildFilters assembles the structured filter clauses shared by the keyword,
// semantic, and hybrid query builders. Each populated request filter becomes
// one terms/range clause; clause order mirrors the request struct.
func (s *Service) buildFilters(req *SearchRequest) []map[string]interface{} {
	clauses := []map[string]interface{}{}

	// Small helper: append a terms clause only when values were supplied.
	addTerms := func(field string, values []string) {
		if len(values) == 0 {
			return
		}
		clauses = append(clauses, map[string]interface{}{
			"terms": map[string]interface{}{field: values},
		})
	}

	addTerms("language", req.Filters.Language)
	addTerms("country_hint", req.Filters.CountryHint)
	addTerms("source_category", req.Filters.SourceCategory)
	addTerms("doc_type", req.Filters.DocType)
	addTerms("school_level", req.Filters.SchoolLevel)
	addTerms("subjects", req.Filters.Subjects)
	addTerms("state", req.Filters.State)

	if req.Filters.MinTrustScore > 0 {
		clauses = append(clauses, map[string]interface{}{
			"range": map[string]interface{}{
				"trust_score": map[string]interface{}{"gte": req.Filters.MinTrustScore},
			},
		})
	}
	if req.Filters.DateFrom != "" {
		clauses = append(clauses, map[string]interface{}{
			"range": map[string]interface{}{
				"fetch_time": map[string]interface{}{"gte": req.Filters.DateFrom},
			},
		})
	}
	return clauses
}
// hitToResult converts one OpenSearch hit (_source + _score + highlight) into
// the API-facing SearchResult, honoring the requested include flags.
func (s *Service) hitToResult(source map[string]interface{}, score float64, highlight map[string][]string, include SearchInclude) SearchResult {
	out := SearchResult{
		DocID:       getString(source, "doc_id"),
		Title:       getString(source, "title"),
		URL:         getString(source, "url"),
		Domain:      getString(source, "domain"),
		Language:    getString(source, "language"),
		DocType:     getString(source, "doc_type"),
		SchoolLevel: getString(source, "school_level"),
		Subjects:    getStringArray(source, "subjects"),
	}
	out.Scores = Scores{
		BM25:    score,
		Trust:   getFloat(source, "trust_score"),
		Quality: getFloat(source, "quality_score"),
		Final:   score, // MVP: final = BM25 * trust * quality (via function_score)
	}
	if include.Snippets {
		out.Snippet = getString(source, "snippet_text")
	}
	if include.Highlights && highlight != nil {
		if fragments, ok := highlight["content_text"]; ok {
			out.Highlights = fragments
		}
	}
	return out
}
// Helper functions

// getString returns m[key] as a string, or "" when the key is absent or the
// value is not a string.
func getString(m map[string]interface{}, key string) string {
	s, _ := m[key].(string)
	return s
}
// getFloat returns m[key] as a float64, or 0 when the key is absent or the
// value is not a float64 (JSON numbers decode to float64).
func getFloat(m map[string]interface{}, key string) float64 {
	f, _ := m[key].(float64)
	return f
}
// getStringArray returns m[key] as a []string, silently dropping non-string
// elements. It returns nil when the key is absent or not a JSON array.
func getStringArray(m map[string]interface{}, key string) []string {
	raw, ok := m[key].([]interface{})
	if !ok {
		return nil
	}
	out := make([]string, 0, len(raw))
	for _, elem := range raw {
		s, isStr := elem.(string)
		if isStr {
			out = append(out, s)
		}
	}
	return out
}

View File

@@ -0,0 +1,217 @@
// Package staff provides university staff crawling functionality
package staff
import (
"context"
"fmt"
"log"
"time"
"github.com/google/uuid"
"github.com/breakpilot/edu-search-service/internal/database"
"github.com/breakpilot/edu-search-service/internal/orchestrator"
)
// OrchestratorAdapter adapts the StaffCrawler to the orchestrator.StaffCrawlerInterface
// This bridges the gap between the generic StaffCrawler and the multi-phase orchestrator
type OrchestratorAdapter struct {
	crawler *StaffCrawler        // performs page discovery and staff extraction
	repo    *database.Repository // persistence layer for universities and staff records
}
// NewOrchestratorAdapter creates a new adapter that connects StaffCrawler to the orchestrator
func NewOrchestratorAdapter(crawler *StaffCrawler, repo *database.Repository) *OrchestratorAdapter {
	adapter := &OrchestratorAdapter{
		repo:    repo,
		crawler: crawler,
	}
	return adapter
}
// DiscoverSampleProfessor finds at least one professor to validate crawling works for this university
// This is Phase 1: Quick validation that the university website is crawlable.
//
// The returned CrawlProgress reports the number of crawlable staff pages in
// ItemsFound. An error is returned when the university cannot be loaded, when
// staff-page discovery fails, or when no staff pages exist at all; finding
// pages but extracting no staff is still treated as success (extraction may
// just need per-site tuning).
func (a *OrchestratorAdapter) DiscoverSampleProfessor(ctx context.Context, universityID uuid.UUID) (*orchestrator.CrawlProgress, error) {
	start := time.Now()
	progress := &orchestrator.CrawlProgress{
		Phase:     orchestrator.PhaseDiscovery,
		StartedAt: start,
	}
	log.Printf("[OrchestratorAdapter] Discovery phase for university %s", universityID)
	// Get university from database
	uni, err := a.repo.GetUniversityByID(ctx, universityID)
	if err != nil {
		progress.Errors = append(progress.Errors, fmt.Sprintf("Failed to get university: %v", err))
		return progress, fmt.Errorf("failed to get university: %w", err)
	}
	if uni == nil {
		progress.Errors = append(progress.Errors, "University not found")
		return progress, fmt.Errorf("university not found: %s", universityID)
	}
	log.Printf("[OrchestratorAdapter] Discovering staff pages for %s (%s)", uni.Name, uni.URL)
	// Use the crawler to find staff pages (discovery phase)
	staffPages, err := a.crawler.findStaffPages(ctx, uni)
	if err != nil {
		progress.Errors = append(progress.Errors, fmt.Sprintf("Failed to find staff pages: %v", err))
		return progress, fmt.Errorf("failed to find staff pages: %w", err)
	}
	log.Printf("[OrchestratorAdapter] Found %d staff pages for %s", len(staffPages), uni.Name)
	// Try to extract at least one professor as validation.
	// Per-page extraction errors are logged and skipped, not fatal.
	var sampleFound int
	for _, pageURL := range staffPages {
		if sampleFound > 0 {
			break // We just need to validate one works
		}
		staffMembers, err := a.crawler.extractStaffFromPage(ctx, pageURL, uni)
		if err != nil {
			log.Printf("[OrchestratorAdapter] Error extracting from %s: %v", pageURL, err)
			continue
		}
		// Count professors found
		for _, staff := range staffMembers {
			if staff.IsProfessor {
				sampleFound++
				log.Printf("[OrchestratorAdapter] Found sample professor: %s %s",
					stringValue(staff.FirstName), staff.LastName)
				break
			}
		}
		// Even non-professors validate the crawler works
		if sampleFound == 0 && len(staffMembers) > 0 {
			sampleFound = 1
			log.Printf("[OrchestratorAdapter] Found sample staff member (not professor): %s %s",
				stringValue(staffMembers[0].FirstName), staffMembers[0].LastName)
		}
	}
	progress.ItemsFound = len(staffPages) // Number of crawlable pages found
	now := time.Now()
	progress.CompletedAt = &now
	if sampleFound == 0 && len(staffPages) > 0 {
		// Pages found but no staff extracted - still consider it successful
		log.Printf("[OrchestratorAdapter] Discovery completed: %d pages found, extraction may need tuning", len(staffPages))
	} else if sampleFound == 0 {
		// No pages at all: the site is not crawlable with current heuristics.
		progress.Errors = append(progress.Errors, "No staff pages found")
		return progress, fmt.Errorf("no staff pages found for %s", uni.Name)
	}
	log.Printf("[OrchestratorAdapter] Discovery completed for %s: %d pages found", uni.Name, len(staffPages))
	return progress, nil
}
// CrawlProfessors crawls all professors at a university
// This is Phase 2: Focus on finding professors specifically.
//
// ItemsFound reports the professor count from the database after the crawl;
// ItemsProcessed reports how many staff the crawl pass touched.
func (a *OrchestratorAdapter) CrawlProfessors(ctx context.Context, universityID uuid.UUID) (*orchestrator.CrawlProgress, error) {
	start := time.Now()
	progress := &orchestrator.CrawlProgress{
		Phase:     orchestrator.PhaseProfessors,
		StartedAt: start,
	}
	log.Printf("[OrchestratorAdapter] Professors phase for university %s", universityID)
	// Get university. The lookup error and the not-found case are handled
	// separately (matching DiscoverSampleProfessor): the old combined check
	// wrapped a nil error with %w when the record was simply missing,
	// producing a garbled "%!w(<nil>)" message.
	uni, err := a.repo.GetUniversityByID(ctx, universityID)
	if err != nil {
		progress.Errors = append(progress.Errors, fmt.Sprintf("Failed to get university: %v", err))
		return progress, fmt.Errorf("failed to get university: %w", err)
	}
	if uni == nil {
		progress.Errors = append(progress.Errors, "University not found")
		return progress, fmt.Errorf("university not found: %s", universityID)
	}
	// Perform full crawl
	result, err := a.crawler.CrawlUniversity(ctx, uni)
	if err != nil {
		progress.Errors = append(progress.Errors, fmt.Sprintf("Crawl failed: %v", err))
		return progress, err
	}
	// Count professors specifically (best effort; a count failure is not fatal)
	professorCount := 0
	staffList, err := a.repo.SearchStaff(ctx, database.StaffSearchParams{
		UniversityID: &universityID,
		IsProfessor:  boolPtr(true),
		Limit:        10000,
	})
	if err == nil {
		professorCount = staffList.Total
	}
	progress.ItemsFound = professorCount
	progress.ItemsProcessed = result.StaffFound
	progress.Errors = result.Errors
	now := time.Now()
	progress.CompletedAt = &now
	log.Printf("[OrchestratorAdapter] Professors phase completed for %s: %d professors found", uni.Name, professorCount)
	return progress, nil
}
// CrawlAllStaff crawls all staff members at a university
// This is Phase 3: Get all staff (already done in Phase 2, but we verify/extend).
//
// A crawl failure is not fatal here: staff stored by earlier phases still
// count, so errors are recorded on the progress and the phase completes.
func (a *OrchestratorAdapter) CrawlAllStaff(ctx context.Context, universityID uuid.UUID) (*orchestrator.CrawlProgress, error) {
	start := time.Now()
	progress := &orchestrator.CrawlProgress{
		Phase:     orchestrator.PhaseAllStaff,
		StartedAt: start,
	}
	log.Printf("[OrchestratorAdapter] All Staff phase for university %s", universityID)
	// Get university; distinguish lookup errors from a missing record so we
	// never wrap a nil error with %w (see CrawlProfessors).
	uni, err := a.repo.GetUniversityByID(ctx, universityID)
	if err != nil {
		progress.Errors = append(progress.Errors, fmt.Sprintf("Failed to get university: %v", err))
		return progress, fmt.Errorf("failed to get university: %w", err)
	}
	if uni == nil {
		progress.Errors = append(progress.Errors, "University not found")
		return progress, fmt.Errorf("university not found: %s", universityID)
	}
	// Run another crawl pass to catch any missed staff
	result, err := a.crawler.CrawlUniversity(ctx, uni)
	if err != nil {
		// BUG FIX: result may be nil when CrawlUniversity fails; the old code
		// dereferenced result.Errors here and panicked. Record what we can
		// and continue with whatever staff earlier phases stored.
		if result != nil {
			progress.Errors = append(progress.Errors, result.Errors...)
		}
		progress.Errors = append(progress.Errors, fmt.Sprintf("Crawl failed: %v", err))
		log.Printf("[OrchestratorAdapter] All Staff crawl had errors: %v", err)
	}
	// Get total staff count (Limit 1: we only need the total)
	staffCount := 0
	staffList, countErr := a.repo.SearchStaff(ctx, database.StaffSearchParams{
		UniversityID: &universityID,
		Limit:        1,
	})
	if countErr == nil {
		staffCount = staffList.Total
	}
	progress.ItemsFound = staffCount
	if result != nil {
		progress.ItemsProcessed = result.StaffFound
		if err == nil {
			progress.Errors = result.Errors
		}
	}
	now := time.Now()
	progress.CompletedAt = &now
	log.Printf("[OrchestratorAdapter] All Staff phase completed for %s: %d total staff", uni.Name, staffCount)
	return progress, nil
}
// Helper functions

// stringValue safely dereferences an optional string, mapping nil to "".
func stringValue(s *string) string {
	if s != nil {
		return *s
	}
	return ""
}
// boolPtr returns a pointer to a copy of b (handy for optional bool fields).
func boolPtr(b bool) *bool {
	v := b
	return &v
}

View File

@@ -0,0 +1,342 @@
package staff
import (
"regexp"
"strings"
)
// UniversityPatterns contains URL patterns for specific universities
type UniversityPatterns struct {
	patterns map[string]UniversityConfig // keyed by normalized domain, e.g. "kit.edu" (no "www." prefix)
}
// UniversityConfig contains crawling configuration for a specific university.
// NOTE(review): fields left empty presumably fall back to the crawler's
// generic heuristics — confirm against the StaffCrawler implementation.
type UniversityConfig struct {
	StaffListURLs    []string       // URLs to staff listing pages
	StaffLinkPattern *regexp.Regexp // Pattern to identify staff profile links
	NameSelector     string         // CSS selector for person name
	PositionSelector string         // CSS selector for position
	EmailSelector    string         // CSS selector for email
	PhotoSelector    string         // CSS selector for photo
	Extractors       []string       // List of extractor types to use
}
// NewUniversityPatterns creates a new pattern registry with known patterns
func NewUniversityPatterns() *UniversityPatterns {
	registry := &UniversityPatterns{
		patterns: map[string]UniversityConfig{},
	}
	// Register known university patterns
	registry.registerKnownPatterns()
	return registry
}
// GetConfig returns the crawl configuration for a university domain, or nil
// if no pattern is registered.
//
// Lookup order: exact match on the normalized (lowercased, "www."-stripped)
// domain first, then a substring match in either direction. When several
// registered keys partially match, the lexicographically smallest key wins —
// Go map iteration order is random, so the previous map-range-and-return
// made the result nondeterministic for ambiguous domains.
func (p *UniversityPatterns) GetConfig(domain string) *UniversityConfig {
	// Normalize domain
	domain = strings.ToLower(domain)
	domain = strings.TrimPrefix(domain, "www.")
	if config, ok := p.patterns[domain]; ok {
		return &config
	}
	// Partial match with deterministic tie-breaking.
	bestKey := ""
	for key := range p.patterns {
		if !strings.Contains(domain, key) && !strings.Contains(key, domain) {
			continue
		}
		if bestKey == "" || key < bestKey {
			bestKey = key
		}
	}
	if bestKey != "" {
		config := p.patterns[bestKey]
		return &config
	}
	return nil
}
// registerKnownPatterns registers patterns for known German universities.
// Pure data: one UniversityConfig per domain, keyed by the normalized domain
// (GetConfig also performs substring matching against these keys).
// NOTE(review): selectors and URLs are site-specific snapshots — they will
// need re-verification whenever a university relaunches its website.
func (p *UniversityPatterns) registerKnownPatterns() {
	// KIT - Karlsruher Institut für Technologie
	p.patterns["kit.edu"] = UniversityConfig{
		StaffListURLs: []string{
			"https://www.kit.edu/kit/fakultaeten.php",
		},
		StaffLinkPattern: regexp.MustCompile(`/personen/\d+`),
		NameSelector:     ".person-name, h1.title",
		PositionSelector: ".person-position, .position",
		EmailSelector:    "a[href^='mailto:']",
		PhotoSelector:    ".person-image img, .portrait img",
	}
	// TUM - Technische Universität München
	p.patterns["tum.de"] = UniversityConfig{
		StaffListURLs: []string{
			"https://www.tum.de/die-tum/fakultaeten",
		},
		StaffLinkPattern: regexp.MustCompile(`/person/\w+`),
		NameSelector:     ".person-name, h1",
		PositionSelector: ".person-title, .function",
		EmailSelector:    "a[href^='mailto:']",
		PhotoSelector:    ".person-photo img",
	}
	// LMU - Ludwig-Maximilians-Universität München
	p.patterns["lmu.de"] = UniversityConfig{
		StaffListURLs: []string{
			"https://www.lmu.de/de/die-lmu/struktur/fakultaeten-einrichtungen-zentren-und-weitere-institutionen/",
		},
		NameSelector:     ".person h2, .staff-name",
		PositionSelector: ".person-position, .staff-position",
		EmailSelector:    "a[href^='mailto:']",
	}
	// RWTH Aachen
	p.patterns["rwth-aachen.de"] = UniversityConfig{
		StaffListURLs: []string{
			"https://www.rwth-aachen.de/cms/root/Die-RWTH/Fakultaeten/~ep/Fakultaeten-und-Einrichtungen/",
		},
		NameSelector:     ".person-name, h3.title",
		PositionSelector: ".person-function, .position",
		EmailSelector:    "a[href^='mailto:']",
	}
	// TU Berlin
	p.patterns["tu-berlin.de"] = UniversityConfig{
		StaffListURLs: []string{
			"https://www.tu.berlin/ueber-die-tu-berlin/organisation/fakultaeten-und-einrichtungen",
		},
		NameSelector:     ".person-name, h2",
		PositionSelector: ".position, .function",
		EmailSelector:    "a[href^='mailto:']",
	}
	// FU Berlin
	p.patterns["fu-berlin.de"] = UniversityConfig{
		StaffListURLs: []string{
			"https://www.fu-berlin.de/einrichtungen/fachbereiche/",
		},
		NameSelector:     ".person-fullname, h2",
		PositionSelector: ".person-position",
		EmailSelector:    "a[href^='mailto:']",
	}
	// HU Berlin
	p.patterns["hu-berlin.de"] = UniversityConfig{
		StaffListURLs: []string{
			"https://www.hu-berlin.de/de/einrichtungen-organisation/fakultaeten-und-institute",
		},
		NameSelector:     ".person h2, .name",
		PositionSelector: ".function, .position",
		EmailSelector:    "a[href^='mailto:']",
	}
	// Universität Freiburg
	p.patterns["uni-freiburg.de"] = UniversityConfig{
		StaffListURLs: []string{
			"https://uni-freiburg.de/universitaet/fakultaeten/",
		},
		NameSelector:     ".person-name, h2",
		PositionSelector: ".person-position, .function",
		EmailSelector:    "a[href^='mailto:']",
	}
	// Universität Heidelberg
	p.patterns["uni-heidelberg.de"] = UniversityConfig{
		StaffListURLs: []string{
			"https://www.uni-heidelberg.de/de/fakultaeten",
		},
		NameSelector:     ".person-fullname, h2",
		PositionSelector: ".person-position",
		EmailSelector:    "a[href^='mailto:']",
	}
	// TU Dresden
	p.patterns["tu-dresden.de"] = UniversityConfig{
		StaffListURLs: []string{
			"https://tu-dresden.de/tu-dresden/organisation/bereiche-und-fakultaeten",
		},
		NameSelector:     ".person-name, h2.name",
		PositionSelector: ".person-function, .funktion",
		EmailSelector:    "a[href^='mailto:']",
	}
	// Universität Leipzig
	p.patterns["uni-leipzig.de"] = UniversityConfig{
		StaffListURLs: []string{
			"https://www.uni-leipzig.de/universitaet/struktur/fakultaeten",
		},
		NameSelector:     ".person h2, .name",
		PositionSelector: ".position, .funktion",
		EmailSelector:    "a[href^='mailto:']",
	}
	// Universität Köln
	p.patterns["uni-koeln.de"] = UniversityConfig{
		StaffListURLs: []string{
			"https://www.uni-koeln.de/",
		},
		NameSelector:     ".person-name, h2",
		PositionSelector: ".person-position, .function",
		EmailSelector:    "a[href^='mailto:']",
	}
	// Universität Bonn
	p.patterns["uni-bonn.de"] = UniversityConfig{
		StaffListURLs: []string{
			"https://www.uni-bonn.de/de/universitaet/fakultaeten",
		},
		NameSelector:     ".person-name, h2",
		PositionSelector: ".person-position",
		EmailSelector:    "a[href^='mailto:']",
	}
	// Universität Münster
	p.patterns["uni-muenster.de"] = UniversityConfig{
		StaffListURLs: []string{
			"https://www.uni-muenster.de/de/fakultaeten.html",
		},
		NameSelector:     ".person-name, h2",
		PositionSelector: ".person-function",
		EmailSelector:    "a[href^='mailto:']",
	}
	// Universität Hamburg
	p.patterns["uni-hamburg.de"] = UniversityConfig{
		StaffListURLs: []string{
			"https://www.uni-hamburg.de/einrichtungen/fakultaeten.html",
		},
		NameSelector:     ".person-name, h2",
		PositionSelector: ".position",
		EmailSelector:    "a[href^='mailto:']",
	}
	// Universität Göttingen
	p.patterns["uni-goettingen.de"] = UniversityConfig{
		StaffListURLs: []string{
			"https://www.uni-goettingen.de/de/fakultaeten/27952.html",
		},
		NameSelector:     ".person-name, h2",
		PositionSelector: ".person-position",
		EmailSelector:    "a[href^='mailto:']",
	}
	// TU Darmstadt
	p.patterns["tu-darmstadt.de"] = UniversityConfig{
		StaffListURLs: []string{
			"https://www.tu-darmstadt.de/universitaet/fachbereiche/index.de.jsp",
		},
		NameSelector:     ".person-name, h2",
		PositionSelector: ".person-position, .funktion",
		EmailSelector:    "a[href^='mailto:']",
	}
}
// CommonStaffPagePaths returns common paths where staff listings are found.
// Used as fallback probes when no per-university config exists; mixes German
// and English path conventions.
func CommonStaffPagePaths() []string {
	return []string{
		"/personen",
		"/team",
		"/mitarbeiter",
		"/mitarbeitende",
		"/staff",
		"/people",
		"/ueber-uns/team",
		"/about/team",
		"/fakultaet/personen",
		"/institut/mitarbeiter",
		"/lehrstuhl/team",
		"/personal",
		"/beschaeftigte",
		"/dozenten",
		"/professoren",
	}
}
// CommonPersonSelectors returns common CSS selectors for person elements.
// Covers typical CMS class names plus microformats (vcard/h-card) and
// schema.org Person markup.
func CommonPersonSelectors() []string {
	return []string{
		".person",
		".person-card",
		".staff-member",
		".team-member",
		".mitarbeiter",
		".employee",
		".vcard",
		".h-card",
		"[itemtype='http://schema.org/Person']",
		".person-entry",
		".staff-entry",
		".profile-card",
	}
}
// TitlePrefixes returns common German academic title prefixes.
// The list is ordered most-specific/longest first (e.g. "Prof. Dr. Dr. h.c.
// mult." before "Prof. Dr.") — presumably so prefix matching strips the
// longest applicable title; keep that ordering when adding entries.
func TitlePrefixes() []string {
	return []string{
		"Prof. Dr. Dr. h.c. mult.",
		"Prof. Dr. Dr. h.c.",
		"Prof. Dr. Dr.",
		"Prof. Dr.-Ing.",
		"Prof. Dr. rer. nat.",
		"Prof. Dr. phil.",
		"Prof. Dr. jur.",
		"Prof. Dr. med.",
		"Prof. Dr.",
		"Prof.",
		"PD Dr.",
		"apl. Prof. Dr.",
		"Jun.-Prof. Dr.",
		"Dr.-Ing.",
		"Dr. rer. nat.",
		"Dr. phil.",
		"Dr. jur.",
		"Dr. med.",
		"Dr.",
		"Dipl.-Ing.",
		"Dipl.-Inf.",
		"Dipl.-Phys.",
		"Dipl.-Math.",
		"Dipl.-Kfm.",
		"M.Sc.",
		"M.A.",
		"M.Eng.",
		"B.Sc.",
		"B.A.",
	}
}
// PositionKeywords returns keywords that indicate staff positions.
// Grouped by role category; includes masculine and feminine German forms
// plus a few English equivalents.
func PositionKeywords() []string {
	return []string{
		// Professors
		"Professor", "Professorin",
		"Ordinarius",
		"Lehrstuhlinhaber", "Lehrstuhlinhaberin",
		"Dekan", "Dekanin",
		"Rektor", "Rektorin",
		// Research staff
		"Wissenschaftlicher Mitarbeiter", "Wissenschaftliche Mitarbeiterin",
		"Akademischer Rat", "Akademische Rätin",
		"Postdoktorand", "Postdoktorandin",
		"Doktorand", "Doktorandin",
		"Promovend", "Promovendin",
		"Forscher", "Forscherin",
		"Researcher",
		// Teaching
		"Dozent", "Dozentin",
		"Lektor", "Lektorin",
		"Lehrbeauftragter", "Lehrbeauftragte",
		// Administrative
		"Sekretär", "Sekretärin",
		"Geschäftsführer", "Geschäftsführerin",
		"Verwaltungsleiter", "Verwaltungsleiterin",
		"Referent", "Referentin",
		// Students
		"Studentische Hilfskraft",
		"Wissenschaftliche Hilfskraft",
		"Tutor", "Tutorin",
	}
}

View File

@@ -0,0 +1,78 @@
// Package staff provides university staff and publication crawling functionality
package staff
import (
"context"
"log"
"time"
"github.com/google/uuid"
"github.com/breakpilot/edu-search-service/internal/database"
"github.com/breakpilot/edu-search-service/internal/orchestrator"
)
// PublicationOrchestratorAdapter adapts publication crawling to the orchestrator interface
// Note: This is a stub for now - publication crawling is a future feature
type PublicationOrchestratorAdapter struct {
	repo *database.Repository // used to enumerate staff and their stored publications
}
// NewPublicationOrchestratorAdapter creates a new publication crawler adapter
func NewPublicationOrchestratorAdapter(repo *database.Repository) *PublicationOrchestratorAdapter {
	adapter := &PublicationOrchestratorAdapter{repo: repo}
	return adapter
}
// CrawlPublicationsForUniversity crawls publications for all staff at a university
// This is Phase 4: Publication discovery (future implementation).
// Today it is a no-op that only tallies publications already in the database.
func (a *PublicationOrchestratorAdapter) CrawlPublicationsForUniversity(ctx context.Context, universityID uuid.UUID) (*orchestrator.CrawlProgress, error) {
	startedAt := time.Now()
	progress := &orchestrator.CrawlProgress{
		Phase:     orchestrator.PhasePublications,
		StartedAt: startedAt,
	}
	log.Printf("[PublicationAdapter] Publications phase for university %s", universityID)

	// Every staff member of this university is a candidate for publication lookup.
	searchParams := database.StaffSearchParams{
		UniversityID: &universityID,
		Limit:        10000,
	}
	staffList, err := a.repo.SearchStaff(ctx, searchParams)
	if err != nil {
		progress.Errors = append(progress.Errors, err.Error())
		return progress, err
	}
	log.Printf("[PublicationAdapter] Found %d staff members for publication crawling", staffList.Total)

	// TODO: Implement actual publication crawling
	// - For each staff member with ORCID/Google Scholar ID:
	//   - Fetch publications from ORCID API
	//   - Fetch publications from Google Scholar
	//   - Match and deduplicate
	//   - Store in database
	//
	// Until then, count publications already stored per staff member
	// (lookup failures are silently skipped — best effort).
	totalPublications := 0
	for _, member := range staffList.Staff {
		if pubs, lookupErr := a.repo.GetStaffPublications(ctx, member.ID); lookupErr == nil {
			totalPublications += len(pubs)
		}
	}

	progress.ItemsFound = totalPublications
	progress.ItemsProcessed = staffList.Total
	finishedAt := time.Now()
	progress.CompletedAt = &finishedAt
	log.Printf("[PublicationAdapter] Publications phase completed for university %s: %d existing publications found", universityID, totalPublications)
	return progress, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,348 @@
package staff
import (
"testing"
"github.com/breakpilot/edu-search-service/internal/database"
)
// TestParseName_FullName_WithTitle checks first/last-name splitting and
// academic-title detection for typical German name strings.
func TestParseName_FullName_WithTitle(t *testing.T) {
	c := &StaffCrawler{}
	cases := []struct {
		label     string
		input     string
		wantFirst string
		wantLast  string
		wantTitle bool
	}{
		{"Prof. Dr. with first and last name", "Prof. Dr. Hans Müller", "Hans", "Müller", true},
		{"Dr. with first and last name", "Dr. Maria Schmidt", "Maria", "Schmidt", true},
		{"Simple name without title", "Thomas Weber", "Thomas", "Weber", false},
		{"Multiple first names", "Prof. Dr. Hans-Peter Meier", "Hans-Peter", "Meier", true},
		{"Single name", "Müller", "", "Müller", false},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			person := &database.UniversityStaff{}
			c.parseName(tc.input, person)
			var first string
			if person.FirstName != nil {
				first = *person.FirstName
			}
			if first != tc.wantFirst {
				t.Errorf("First name: expected %q, got %q", tc.wantFirst, first)
			}
			if person.LastName != tc.wantLast {
				t.Errorf("Last name: expected %q, got %q", tc.wantLast, person.LastName)
			}
			gotTitle := person.Title != nil && *person.Title != ""
			if gotTitle != tc.wantTitle {
				t.Errorf("Has title: expected %v, got %v", tc.wantTitle, gotTitle)
			}
		})
	}
}
// TestClassifyPosition_Professor verifies professor titles classify as "professor".
func TestClassifyPosition_Professor(t *testing.T) {
	c := &StaffCrawler{}
	cases := []struct {
		label    string
		position string
		want     string
	}{
		{"Full Professor", "Professor für Informatik", "professor"},
		{"Prof abbreviation", "Prof. Dr. Müller", "professor"},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			got := c.classifyPosition(tc.position)
			switch {
			case got == nil:
				t.Errorf("Expected %q, got nil for position %q", tc.want, tc.position)
			case *got != tc.want:
				t.Errorf("Expected %q, got %q for position %q", tc.want, *got, tc.position)
			}
		})
	}
}
// TestClassifyPosition_Postdoc verifies postdoc variants classify as "postdoc".
func TestClassifyPosition_Postdoc(t *testing.T) {
	c := &StaffCrawler{}
	cases := []struct {
		label    string
		position string
		want     string
	}{
		{"Postdoc", "Postdoc in Machine Learning", "postdoc"},
		{"Post-Doc hyphenated", "Post-Doc", "postdoc"},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			got := c.classifyPosition(tc.position)
			switch {
			case got == nil:
				t.Errorf("Expected %q, got nil for position %q", tc.want, tc.position)
			case *got != tc.want:
				t.Errorf("Expected %q, got %q for position %q", tc.want, *got, tc.position)
			}
		})
	}
}
// TestClassifyPosition_PhDStudent verifies doctoral-candidate labels classify
// as "phd_student".
func TestClassifyPosition_PhDStudent(t *testing.T) {
	c := &StaffCrawler{}
	cases := []struct {
		label    string
		position string
		want     string
	}{
		{"Doktorand", "Doktorand", "phd_student"},
		{"PhD Student", "PhD Student", "phd_student"},
		{"Promovend", "Promovend", "phd_student"},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			got := c.classifyPosition(tc.position)
			switch {
			case got == nil:
				t.Errorf("Expected %q, got nil for position %q", tc.want, tc.position)
			case *got != tc.want:
				t.Errorf("Expected %q, got %q for position %q", tc.want, *got, tc.position)
			}
		})
	}
}
// TestClassifyPosition_Admin verifies administrative roles classify as "admin".
func TestClassifyPosition_Admin(t *testing.T) {
	c := &StaffCrawler{}
	cases := []struct {
		label    string
		position string
		want     string
	}{
		{"Sekretariat", "Sekretärin", "admin"},
		{"Verwaltung", "Verwaltung", "admin"},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			got := c.classifyPosition(tc.position)
			switch {
			case got == nil:
				t.Errorf("Expected %q, got nil for position %q", tc.want, tc.position)
			case *got != tc.want:
				t.Errorf("Expected %q, got %q for position %q", tc.want, *got, tc.position)
			}
		})
	}
}
// TestClassifyPosition_Researcher verifies research-staff roles classify as
// "researcher".
func TestClassifyPosition_Researcher(t *testing.T) {
	c := &StaffCrawler{}
	cases := []struct {
		label    string
		position string
		want     string
	}{
		{"Wissenschaftlicher Mitarbeiter", "Wissenschaftlicher Mitarbeiter", "researcher"},
		{"Researcher", "Senior Researcher", "researcher"},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			got := c.classifyPosition(tc.position)
			switch {
			case got == nil:
				t.Errorf("Expected %q, got nil for position %q", tc.want, tc.position)
			case *got != tc.want:
				t.Errorf("Expected %q, got %q for position %q", tc.want, *got, tc.position)
			}
		})
	}
}
// TestClassifyPosition_Student verifies student-assistant roles classify as
// "student".
func TestClassifyPosition_Student(t *testing.T) {
	c := &StaffCrawler{}
	cases := []struct {
		label    string
		position string
		want     string
	}{
		{"Studentische Hilfskraft", "Studentische Hilfskraft", "student"},
		{"HiWi", "Student (HiWi)", "student"},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			got := c.classifyPosition(tc.position)
			switch {
			case got == nil:
				t.Errorf("Expected %q, got nil for position %q", tc.want, tc.position)
			case *got != tc.want:
				t.Errorf("Expected %q, got %q for position %q", tc.want, *got, tc.position)
			}
		})
	}
}
// TestIsProfessor_True verifies that professor titles are recognized.
func TestIsProfessor_True(t *testing.T) {
	crawler := &StaffCrawler{}
	cases := []struct {
		name, position string
	}{
		{"Professor keyword", "Professor für Mathematik"},
		{"Prof. abbreviation", "Prof. Dr. Müller"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if !crawler.isProfessor(tc.position) {
				t.Errorf("Expected true for position=%q", tc.position)
			}
		})
	}
}
// TestIsProfessor_False verifies that non-professor positions are not
// misclassified as professors.
func TestIsProfessor_False(t *testing.T) {
	crawler := &StaffCrawler{}
	cases := []struct {
		name, position string
	}{
		{"Dr. only", "Dr. Wissenschaftlicher Mitarbeiter"},
		{"Doktorand", "Doktorand"},
		{"Technical staff", "Laboringenieur"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if crawler.isProfessor(tc.position) {
				t.Errorf("Expected false for position=%q", tc.position)
			}
		})
	}
}
// TestLooksLikePosition_True verifies that typical job titles are
// recognized as positions.
func TestLooksLikePosition_True(t *testing.T) {
	crawler := &StaffCrawler{}
	cases := []struct {
		name, text string
	}{
		{"Professor", "Professor für Informatik"},
		{"Wissenschaftlicher Mitarbeiter", "Wissenschaftlicher Mitarbeiter"},
		{"Doktorand", "Doktorand"},
		{"Sekretär", "Sekretärin"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if !crawler.looksLikePosition(tc.text) {
				t.Errorf("Expected true for text=%q", tc.text)
			}
		})
	}
}
// TestLooksLikePosition_False verifies that names, emails and other
// non-title text are not mistaken for positions.
func TestLooksLikePosition_False(t *testing.T) {
	crawler := &StaffCrawler{}
	cases := []struct {
		name, text string
	}{
		{"Name", "Hans Müller"},
		{"Email", "test@example.com"},
		{"Random text", "Room 123"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if crawler.looksLikePosition(tc.text) {
				t.Errorf("Expected false for text=%q", tc.text)
			}
		})
	}
}
// TestResolveURL checks absolute, relative and empty href resolution
// against a base URL.
func TestResolveURL(t *testing.T) {
	cases := []struct {
		name, baseURL, href, expected string
	}{
		{"Absolute URL", "https://example.com", "https://other.com/page", "https://other.com/page"},
		{"Relative path", "https://example.com/team", "/person/123", "https://example.com/person/123"},
		{"Relative no slash", "https://example.com/team/", "member", "https://example.com/team/member"},
		{"Empty href", "https://example.com", "", ""},
		{"Root relative", "https://example.com/a/b/c", "/root", "https://example.com/root"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := resolveURL(tc.baseURL, tc.href); got != tc.expected {
				t.Errorf("resolveURL(%q, %q) = %q, expected %q",
					tc.baseURL, tc.href, got, tc.expected)
			}
		})
	}
}

View File

@@ -0,0 +1,455 @@
package tagger
import (
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"gopkg.in/yaml.v3"
)
// TagResult contains all tags assigned to a document.
type TagResult struct {
	DocType     string   `json:"doc_type"`     // document category, e.g. "Lehrplan"; "Sonstiges" when nothing matches
	Subjects    []string `json:"subjects"`     // matched school subjects, capped at max_subjects
	SchoolLevel string   `json:"school_level"` // matched school level; "NA" when nothing matches
	State       string   `json:"state"`        // German federal-state code derived from the URL; "" for federal/unknown
	TrustScore  float64  `json:"trust_score"`  // trust value clamped to [0, 1]
}

// Tagger applies rules to content and URLs. The four rule sets are loaded
// once from YAML files by NewTagger and are read-only afterwards.
type Tagger struct {
	docTypeRules *DocTypeRules
	subjectRules *SubjectRules
	levelRules   *LevelRules
	trustRules   *TrustRules
}

// DocTypeRules YAML structure (doc_type_rules.yaml).
type DocTypeRules struct {
	DocTypes      map[string]DocTypeRule `yaml:"doc_types"`
	PriorityOrder []string               `yaml:"priority_order"` // tie-break order for equally scored types
}

// DocTypeRule holds the match terms for one document type.
type DocTypeRule struct {
	StrongTerms []string `yaml:"strong_terms"` // +4 per content match
	MediumTerms []string `yaml:"medium_terms"` // +3 per content match
	URLPatterns []string `yaml:"url_patterns"` // +2 per URL match
}

// SubjectRules YAML structure (subject_rules.yaml).
type SubjectRules struct {
	Subjects    map[string]SubjectRule `yaml:"subjects"`
	Threshold   int                    `yaml:"threshold"`    // minimum score to accept a subject (0 = default)
	MaxSubjects int                    `yaml:"max_subjects"` // cap on returned subjects (0 = default)
}

// SubjectRule holds the match terms for one school subject.
type SubjectRule struct {
	Strong   []string `yaml:"strong"`   // +3 per content match
	Weak     []string `yaml:"weak"`     // +1 per content match
	Negative []string `yaml:"negative"` // -2 per content match
}

// LevelRules YAML structure (level_rules.yaml).
type LevelRules struct {
	Levels        map[string]LevelRule `yaml:"levels"`
	PriorityOrder []string             `yaml:"priority_order"` // tie-break order for equally scored levels
}

// LevelRule holds the match terms for one school level.
type LevelRule struct {
	Strong   []string `yaml:"strong"`   // +3 per content match
	Weak     []string `yaml:"weak"`     // +1 per content match
	Negative []string `yaml:"negative"` // -2 per content match
}

// TrustRules YAML structure (trust_rules.yaml).
type TrustRules struct {
	DomainBoosts     []DomainBoost    `yaml:"domain_boosts"`
	TLDBoosts        []TLDBoost       `yaml:"tld_boosts"`
	Penalties        []Penalty        `yaml:"penalties"`
	ContentPenalties []ContentPenalty `yaml:"content_penalties"`
}

// DomainBoost raises the trust score for URLs matching a domain pattern.
type DomainBoost struct {
	Match  string  `yaml:"match"` // wildcard pattern, e.g. "*.kmk.org"
	Add    float64 `yaml:"add"`
	Reason string  `yaml:"reason"`
}

// TLDBoost raises the trust score for URLs with a given top-level domain.
type TLDBoost struct {
	TLD    string  `yaml:"tld"` // e.g. ".gov"
	Add    float64 `yaml:"add"`
	Reason string  `yaml:"reason"`
}

// Penalty lowers the trust score when the URL contains any listed fragment.
type Penalty struct {
	IfURLContains []string `yaml:"if_url_contains"`
	Add           float64  `yaml:"add"` // negative value by convention
	Reason        string   `yaml:"reason"`
}

// ContentPenalty lowers the trust score based on content features. Each
// condition is optional; a nil pointer means the check is skipped.
type ContentPenalty struct {
	IfAdDensityGT     *float64 `yaml:"if_ad_density_gt,omitempty"`
	IfLinkDensityGT   *float64 `yaml:"if_link_density_gt,omitempty"`
	IfContentLengthLT *int     `yaml:"if_content_length_lt,omitempty"`
	Add               float64  `yaml:"add"` // negative value by convention
	Reason            string   `yaml:"reason"`
}

// ContentFeatures for trust scoring, computed by the caller.
type ContentFeatures struct {
	AdDensity   float64 // fraction of ad content — assumed in [0, 1]; TODO confirm at caller
	LinkDensity float64 // fraction of link content — assumed in [0, 1]; TODO confirm at caller
	ContentLength int   // extracted content length — presumably in characters; verify against caller
}
// NewTagger creates a new tagger with rules from the specified directory.
// It expects doc_type_rules.yaml, subject_rules.yaml, level_rules.yaml and
// trust_rules.yaml in rulesDir and returns the first read or parse error.
func NewTagger(rulesDir string) (*Tagger, error) {
	t := &Tagger{
		docTypeRules: &DocTypeRules{},
		subjectRules: &SubjectRules{},
		levelRules:   &LevelRules{},
		trustRules:   &TrustRules{},
	}
	// Every rule set follows the same read-then-unmarshal sequence, so the
	// four previously duplicated load blocks collapse into one loop.
	ruleFiles := []struct {
		name string
		dst  interface{}
	}{
		{"doc_type_rules.yaml", t.docTypeRules},
		{"subject_rules.yaml", t.subjectRules},
		{"level_rules.yaml", t.levelRules},
		{"trust_rules.yaml", t.trustRules},
	}
	for _, f := range ruleFiles {
		if err := loadRuleFile(filepath.Join(rulesDir, f.name), f.dst); err != nil {
			return nil, err
		}
	}
	return t, nil
}

// loadRuleFile reads a YAML file and unmarshals it into dst.
func loadRuleFile(path string, dst interface{}) error {
	data, err := os.ReadFile(path)
	if err != nil {
		return err
	}
	return yaml.Unmarshal(data, dst)
}
// Tag applies all rules to content and returns tags. The URL, title and
// content are lowercased once; title and content are searched as a single
// combined text so keywords in either can match.
func (t *Tagger) Tag(url string, title string, content string, features ContentFeatures) TagResult {
	loweredURL := strings.ToLower(url)
	searchText := strings.ToLower(title) + " " + strings.ToLower(content)
	return TagResult{
		DocType:     t.tagDocType(loweredURL, searchText),
		Subjects:    t.tagSubjects(searchText),
		SchoolLevel: t.tagSchoolLevel(searchText),
		State:       t.detectState(loweredURL),
		TrustScore:  t.calculateTrustScore(loweredURL, features),
	}
}
// tagDocType scores every configured document type against the URL and the
// combined lowercased text and returns the best match, or "Sonstiges" when
// nothing scores. Strong terms count +4, medium terms +3, URL patterns +2;
// ties between equally scored types go to the earlier entry in the
// configured priority order.
func (t *Tagger) tagDocType(url string, content string) string {
	scores := make(map[string]int)
	for docType, rule := range t.docTypeRules.DocTypes {
		score := 0
		for _, term := range rule.StrongTerms {
			if strings.Contains(content, strings.ToLower(term)) {
				score += 4
			}
		}
		for _, term := range rule.MediumTerms {
			if strings.Contains(content, strings.ToLower(term)) {
				score += 3
			}
		}
		for _, pattern := range rule.URLPatterns {
			if strings.Contains(url, strings.ToLower(pattern)) {
				score += 2
			}
		}
		if score > 0 {
			scores[docType] = score
		}
	}
	if len(scores) == 0 {
		return "Sonstiges"
	}
	// Walk the priority order; strict ">" keeps the earlier (higher-priority)
	// type on equal scores. The former `bestType == ""` tie clause was
	// unreachable — the map only holds scores > 0, so the first hit always
	// exceeds the initial bestScore of 0.
	// NOTE(review): a type that scores but is absent from priority_order is
	// silently ignored — confirm all doc types are listed in the YAML.
	bestType := ""
	bestScore := 0
	for _, docType := range t.docTypeRules.PriorityOrder {
		if score, ok := scores[docType]; ok && score > bestScore {
			bestScore = score
			bestType = docType
		}
	}
	if bestType == "" {
		return "Sonstiges"
	}
	return bestType
}
// tagSubjects scores every configured subject against the combined text and
// returns up to max_subjects subjects whose score reaches the threshold,
// ordered by descending score. Strong terms count +3, weak +1, negative -2.
func (t *Tagger) tagSubjects(content string) []string {
	type subjectScore struct {
		name  string
		score int
	}
	// Defaults apply when the YAML leaves the knobs unset (zero). Hoisted
	// out of the loop — previously recomputed on every iteration.
	threshold := t.subjectRules.Threshold
	if threshold == 0 {
		threshold = 4
	}
	maxSubjects := t.subjectRules.MaxSubjects
	if maxSubjects == 0 {
		maxSubjects = 3
	}
	var scores []subjectScore
	for subject, rule := range t.subjectRules.Subjects {
		score := 0
		for _, term := range rule.Strong {
			if strings.Contains(content, strings.ToLower(term)) {
				score += 3
			}
		}
		for _, term := range rule.Weak {
			if strings.Contains(content, strings.ToLower(term)) {
				score++
			}
		}
		for _, term := range rule.Negative {
			if strings.Contains(content, strings.ToLower(term)) {
				score -= 2
			}
		}
		if score >= threshold {
			scores = append(scores, subjectScore{name: subject, score: score})
		}
	}
	// Sort by score descending and break ties by name: the input order comes
	// from map iteration, which is random, so without the tie-break the
	// returned subject order (and the cut at maxSubjects) was nondeterministic.
	sort.Slice(scores, func(i, j int) bool {
		if scores[i].score != scores[j].score {
			return scores[i].score > scores[j].score
		}
		return scores[i].name < scores[j].name
	})
	if len(scores) > maxSubjects {
		scores = scores[:maxSubjects]
	}
	var result []string
	for _, s := range scores {
		result = append(result, s.name)
	}
	return result
}
// tagSchoolLevel scores every configured school level against the combined
// text and returns the best match, or "NA" when nothing scores positively.
// Strong terms count +3, weak +1, negative -2; ties between equally scored
// levels go to the earlier entry in the configured priority order.
func (t *Tagger) tagSchoolLevel(content string) string {
	levelScores := map[string]int{}
	for level, rule := range t.levelRules.Levels {
		total := 0
		for _, term := range rule.Strong {
			if strings.Contains(content, strings.ToLower(term)) {
				total += 3
			}
		}
		for _, term := range rule.Weak {
			if strings.Contains(content, strings.ToLower(term)) {
				total++
			}
		}
		for _, term := range rule.Negative {
			if strings.Contains(content, strings.ToLower(term)) {
				total -= 2
			}
		}
		if total > 0 {
			levelScores[level] = total
		}
	}
	if len(levelScores) == 0 {
		return "NA"
	}
	// Walk the priority order; strict ">" means earlier entries win ties.
	winner, winnerScore := "", 0
	for _, level := range t.levelRules.PriorityOrder {
		if s, ok := levelScores[level]; ok && s > winnerScore {
			winner, winnerScore = level, s
		}
	}
	if winner == "" {
		return "NA"
	}
	return winner
}
// calculateTrustScore derives a trust value in [0, 1] for the URL, starting
// from a 0.50 base and applying configured domain/TLD boosts, URL-fragment
// penalties and content-feature penalties.
func (t *Tagger) calculateTrustScore(url string, features ContentFeatures) float64 {
	const baseScore = 0.50
	score := baseScore
	for _, boost := range t.trustRules.DomainBoosts {
		if matchDomainPattern(url, boost.Match) {
			score += boost.Add
		}
	}
	for _, boost := range t.trustRules.TLDBoosts {
		if strings.HasSuffix(url, boost.TLD) || strings.Contains(url, boost.TLD+"/") {
			score += boost.Add
		}
	}
	for _, penalty := range t.trustRules.Penalties {
		for _, fragment := range penalty.IfURLContains {
			if !strings.Contains(url, strings.ToLower(fragment)) {
				continue
			}
			score += penalty.Add // Add is negative for penalties
			break                // each penalty rule fires at most once
		}
	}
	// Content penalties: each configured condition is checked independently,
	// so one rule can fire multiple times if several conditions hold.
	for _, penalty := range t.trustRules.ContentPenalties {
		if penalty.IfAdDensityGT != nil && features.AdDensity > *penalty.IfAdDensityGT {
			score += penalty.Add
		}
		if penalty.IfLinkDensityGT != nil && features.LinkDensity > *penalty.IfLinkDensityGT {
			score += penalty.Add
		}
		if penalty.IfContentLengthLT != nil && features.ContentLength < *penalty.IfContentLengthLT {
			score += penalty.Add
		}
	}
	// Clamp into [0, 1].
	if score < 0 {
		return 0
	}
	if score > 1 {
		return 1
	}
	return score
}
// matchDomainPattern reports whether url matches the wildcard domain
// pattern, case-insensitively. A leading "*." is optional at match time, so
// "*.example.de" matches both "sub.example.de" and bare "example.de" — the
// previous translation of "*" to ".*" required a literal dot before the
// domain and therefore never matched the bare domain, contradicting the
// documented intent. An invalid resulting regex is treated as no match.
func matchDomainPattern(url string, pattern string) bool {
	optionalSub := false
	if strings.HasPrefix(pattern, "*.") {
		optionalSub = true
		pattern = pattern[2:]
	}
	regexPattern := strings.ReplaceAll(pattern, ".", "\\.")
	regexPattern = strings.ReplaceAll(regexPattern, "*", ".*")
	if optionalSub {
		// Subdomain prefix (anything ending in a dot) is allowed but not required.
		regexPattern = "(.*\\.)?" + regexPattern
	}
	re, err := regexp.Compile("(?i)" + regexPattern) // case insensitive
	if err != nil {
		return false
	}
	return re.MatchString(url)
}
// detectState maps well-known federal-state domains and URL fragments to
// the two-letter German state code. An empty string means the source is
// federal ("bundesweit") or unknown.
func (t *Tagger) detectState(url string) string {
	patternsByState := map[string][]string{
		"BW": {"baden-wuerttemberg", "bw.de", "schule-bw.de", "kultusministerium.baden"},
		"BY": {"bayern.de", "isb.bayern", "km.bayern"},
		"BE": {"berlin.de", "bildungsserver.berlin"},
		"BB": {"brandenburg.de", "bildungsserver.brandenburg"},
		"HB": {"bremen.de", "lis.bremen"},
		"HH": {"hamburg.de", "li.hamburg"},
		"HE": {"hessen.de", "hkm.hessen", "bildung.hessen"},
		"MV": {"mecklenburg-vorpommern", "mv.de", "bildung-mv.de"},
		"NI": {"niedersachsen.de", "nibis.de", "mk.niedersachsen"},
		"NW": {"nrw.de", "learnline.nrw", "schulministerium.nrw"},
		"RP": {"rheinland-pfalz", "rlp.de", "bildung-rp.de"},
		"SL": {"saarland.de", "bildungsserver.saarland"},
		"SN": {"sachsen.de", "schule.sachsen", "smk.sachsen"},
		"ST": {"sachsen-anhalt", "bildung-lsa.de", "mk.sachsen-anhalt"},
		"SH": {"schleswig-holstein", "sh.de", "bildungsserver.schleswig"},
		"TH": {"thueringen.de", "schulportal-thueringen"},
	}
	// NOTE(review): map iteration is unordered, so a URL whose host happens
	// to match fragments of several states returns an arbitrary one —
	// confirm the fragments are mutually exclusive in practice.
	for code, fragments := range patternsByState {
		for _, fragment := range fragments {
			if strings.Contains(url, fragment) {
				return code
			}
		}
	}
	return "" // Bundesweit or unknown
}

View File

@@ -0,0 +1,557 @@
package tagger
import (
"os"
"path/filepath"
"testing"
)
// createTestRulesDir creates temporary test rule files.
// It writes the four YAML rule files the Tagger expects
// (doc_type_rules.yaml, subject_rules.yaml, level_rules.yaml,
// trust_rules.yaml) into a t.TempDir() and returns the directory path.
// The directory is removed automatically when the test finishes.
func createTestRulesDir(t *testing.T) string {
	t.Helper()
	dir := t.TempDir()
	// Create doc_type_rules.yaml: three document types plus a priority order.
	docTypeRules := `doc_types:
  Lehrplan:
    strong_terms:
      - Lehrplan
      - Kernlehrplan
      - Bildungsplan
    medium_terms:
      - Curriculum
    url_patterns:
      - /lehrplan
  Arbeitsblatt:
    strong_terms:
      - Arbeitsblatt
      - Übungsblatt
    medium_terms:
      - Aufgaben
    url_patterns:
      - /arbeitsblatt
  Studie_Bericht:
    strong_terms:
      - Studie
      - PISA
    medium_terms:
      - Ergebnis
    url_patterns:
      - /studie
priority_order:
  - Lehrplan
  - Arbeitsblatt
  - Studie_Bericht
  - Sonstiges
`
	if err := os.WriteFile(filepath.Join(dir, "doc_type_rules.yaml"), []byte(docTypeRules), 0644); err != nil {
		t.Fatal(err)
	}
	// Create subject_rules.yaml: three subjects with strong/weak/negative
	// terms, an acceptance threshold of 4 and a cap of 3 subjects.
	subjectRules := `subjects:
  Mathematik:
    strong:
      - Mathematik
      - Algebra
      - Geometrie
    weak:
      - rechnen
      - Zahlen
    negative:
      - Geschichte der Mathematik
  Deutsch:
    strong:
      - Deutsch
      - Grammatik
      - Rechtschreibung
    weak:
      - Lesen
      - Schreiben
    negative:
      - Deutsch als Fremdsprache
  Geschichte:
    strong:
      - Geschichte
      - Historisch
    weak:
      - Epoche
      - Jahrhundert
    negative:
      - Naturgeschichte
threshold: 4
max_subjects: 3
`
	if err := os.WriteFile(filepath.Join(dir, "subject_rules.yaml"), []byte(subjectRules), 0644); err != nil {
		t.Fatal(err)
	}
	// Create level_rules.yaml: three school levels plus a priority order.
	levelRules := `levels:
  Grundschule:
    strong:
      - Grundschule
      - Primarstufe
      - Klasse 1-4
    weak:
      - Erstklässler
    negative:
      - Sekundarstufe
  Gymnasium:
    strong:
      - Gymnasium
      - Abitur
      - Oberstufe
    weak:
      - Sekundarstufe II
    negative:
      - Realschule
  Sek_I:
    strong:
      - Sekundarstufe I
      - Klasse 5-10
      - Hauptschule
    weak:
      - Mittelstufe
    negative:
      - Grundschule
priority_order:
  - Gymnasium
  - Sek_I
  - Grundschule
  - NA
`
	if err := os.WriteFile(filepath.Join(dir, "level_rules.yaml"), []byte(levelRules), 0644); err != nil {
		t.Fatal(err)
	}
	// Create trust_rules.yaml: domain/TLD boosts plus URL and content
	// penalties used by the trust-score tests.
	trustRules := `domain_boosts:
  - match: "*.kmk.org"
    add: 0.30
    reason: "Kultusministerkonferenz"
  - match: "*.bildungsserver.de"
    add: 0.25
    reason: "Deutscher Bildungsserver"
  - match: "*.bayern.de"
    add: 0.20
    reason: "Bayerische Landesregierung"
tld_boosts:
  - tld: ".gov"
    add: 0.15
    reason: "Government domain"
penalties:
  - if_url_contains:
      - "forum"
      - "blog"
    add: -0.10
    reason: "User generated content"
content_penalties:
  - if_ad_density_gt: 0.3
    add: -0.15
    reason: "High ad density"
  - if_content_length_lt: 200
    add: -0.10
    reason: "Very short content"
`
	if err := os.WriteFile(filepath.Join(dir, "trust_rules.yaml"), []byte(trustRules), 0644); err != nil {
		t.Fatal(err)
	}
	return dir
}
func TestNewTagger_Success(t *testing.T) {
rulesDir := createTestRulesDir(t)
tagger, err := NewTagger(rulesDir)
if err != nil {
t.Fatalf("NewTagger failed: %v", err)
}
if tagger == nil {
t.Fatal("Expected non-nil tagger")
}
if tagger.docTypeRules == nil {
t.Error("docTypeRules not loaded")
}
if tagger.subjectRules == nil {
t.Error("subjectRules not loaded")
}
if tagger.levelRules == nil {
t.Error("levelRules not loaded")
}
if tagger.trustRules == nil {
t.Error("trustRules not loaded")
}
}
// TestNewTagger_MissingFile expects an error when the rules directory does
// not exist.
func TestNewTagger_MissingFile(t *testing.T) {
	if _, err := NewTagger("/nonexistent/path"); err == nil {
		t.Error("Expected error for nonexistent rules directory")
	}
}
func TestTagger_TagDocType_Lehrplan(t *testing.T) {
rulesDir := createTestRulesDir(t)
tagger, err := NewTagger(rulesDir)
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
url string
content string
expected string
}{
{
name: "Strong term in content",
url: "https://example.com/page",
content: "Dies ist der Lehrplan für Mathematik in der Sekundarstufe",
expected: "Lehrplan",
},
{
name: "URL pattern match",
url: "https://example.com/lehrplan/mathe",
content: "Allgemeine Informationen zum Fach",
expected: "Lehrplan",
},
{
name: "Multiple strong terms",
url: "https://example.com/bildung",
content: "Kernlehrplan und Bildungsplan für das Curriculum",
expected: "Lehrplan",
},
{
name: "Arbeitsblatt detection",
url: "https://example.com/material",
content: "Arbeitsblatt zum Thema Rechnen mit Übungsblatt",
expected: "Arbeitsblatt",
},
{
name: "No match returns Sonstiges",
url: "https://example.com/page",
content: "Eine allgemeine Webseite ohne spezifische Bildungsinhalte",
expected: "Sonstiges",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tagger.Tag(tt.url, "", tt.content, ContentFeatures{})
if result.DocType != tt.expected {
t.Errorf("Expected DocType %q, got %q", tt.expected, result.DocType)
}
})
}
}
// TestTagger_TagSubjects exercises subject detection. Each case lists
// subjects that must appear in the result and subjects that must not.
// The expectedMissing field was previously declared but never checked;
// it is now asserted and exercised by the "No subjects detected" case.
func TestTagger_TagSubjects(t *testing.T) {
	rulesDir := createTestRulesDir(t)
	tagger, err := NewTagger(rulesDir)
	if err != nil {
		t.Fatal(err)
	}
	tests := []struct {
		name            string
		content         string
		expectedContain []string
		expectedMissing []string
	}{
		{
			name:            "Mathematik detection",
			content:         "In Mathematik lernen wir Algebra und Geometrie sowie das Rechnen mit Zahlen",
			expectedContain: []string{"Mathematik"},
		},
		{
			name:            "Deutsch detection",
			content:         "Im Fach Deutsch geht es um Grammatik, Rechtschreibung und das Lesen von Texten",
			expectedContain: []string{"Deutsch"},
		},
		{
			name:            "Multiple subjects",
			content:         "Mathematik und Algebra verbinden sich mit Geschichte und historischen Epochen",
			expectedContain: []string{"Mathematik", "Geschichte"},
		},
		{
			name:            "No subjects detected",
			content:         "Ein Text ohne spezifische Fachbegriffe",
			expectedContain: []string{},
			expectedMissing: []string{"Mathematik", "Deutsch", "Geschichte"},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := tagger.Tag("https://example.com", "", tt.content, ContentFeatures{})
			// contains reports whether want is among the tagged subjects.
			contains := func(subjects []string, want string) bool {
				for _, s := range subjects {
					if s == want {
						return true
					}
				}
				return false
			}
			for _, expected := range tt.expectedContain {
				if !contains(result.Subjects, expected) {
					t.Errorf("Expected subject %q not found in %v", expected, result.Subjects)
				}
			}
			for _, unexpected := range tt.expectedMissing {
				if contains(result.Subjects, unexpected) {
					t.Errorf("Unexpected subject %q found in %v", unexpected, result.Subjects)
				}
			}
		})
	}
}
func TestTagger_TagSchoolLevel(t *testing.T) {
rulesDir := createTestRulesDir(t)
tagger, err := NewTagger(rulesDir)
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
content string
expected string
}{
{
name: "Grundschule detection",
content: "Material für die Grundschule und Primarstufe",
expected: "Grundschule",
},
{
name: "Gymnasium detection",
content: "Vorbereitung auf das Abitur am Gymnasium in der Oberstufe",
expected: "Gymnasium",
},
{
name: "Sekundarstufe I detection",
content: "Aufgaben für Sekundarstufe I in Klasse 5-10",
expected: "Sek_I",
},
{
name: "No level detected",
content: "Allgemeine Bildungsinformationen",
expected: "NA",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tagger.Tag("https://example.com", "", tt.content, ContentFeatures{})
if result.SchoolLevel != tt.expected {
t.Errorf("Expected SchoolLevel %q, got %q", tt.expected, result.SchoolLevel)
}
})
}
}
func TestTagger_TrustScore(t *testing.T) {
rulesDir := createTestRulesDir(t)
tagger, err := NewTagger(rulesDir)
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
url string
features ContentFeatures
minExpected float64
maxExpected float64
}{
{
name: "Base score for unknown domain",
url: "https://unknown-domain.com/page",
features: ContentFeatures{ContentLength: 500},
minExpected: 0.40,
maxExpected: 0.60,
},
{
name: "KMK domain boost",
url: "https://www.kmk.org/bildung",
features: ContentFeatures{ContentLength: 500},
minExpected: 0.70,
maxExpected: 0.90,
},
{
name: "Bayern domain boost",
url: "https://www.km.bayern.de/lehrplan",
features: ContentFeatures{ContentLength: 500},
minExpected: 0.60,
maxExpected: 0.80,
},
{
name: "Forum penalty",
url: "https://example.com/forum/thread",
features: ContentFeatures{ContentLength: 500},
minExpected: 0.30,
maxExpected: 0.50,
},
{
name: "High ad density penalty",
url: "https://example.com/page",
features: ContentFeatures{AdDensity: 0.5, ContentLength: 500},
minExpected: 0.25,
maxExpected: 0.50,
},
{
name: "Short content penalty",
url: "https://example.com/page",
features: ContentFeatures{ContentLength: 100},
minExpected: 0.30,
maxExpected: 0.50,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tagger.Tag(tt.url, "", "Some content text", tt.features)
if result.TrustScore < tt.minExpected || result.TrustScore > tt.maxExpected {
t.Errorf("TrustScore %f not in expected range [%f, %f]",
result.TrustScore, tt.minExpected, tt.maxExpected)
}
})
}
}
func TestTagger_DetectState(t *testing.T) {
rulesDir := createTestRulesDir(t)
tagger, err := NewTagger(rulesDir)
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
url string
expected string
}{
{
name: "Bayern detection",
url: "https://www.km.bayern.de/lehrplan",
expected: "BY",
},
{
name: "NRW detection",
url: "https://www.schulministerium.nrw.de/themen",
expected: "NW",
},
{
name: "Berlin detection",
url: "https://www.berlin.de/sen/bildung/schule",
expected: "BE",
},
{
name: "Hessen detection",
url: "https://kultusministerium.hessen.de",
expected: "HE",
},
{
name: "No state (federal)",
url: "https://www.kmk.org/bildung",
expected: "",
},
{
name: "Unknown domain",
url: "https://www.example.com/page",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tagger.Tag(tt.url, "", "Some content", ContentFeatures{})
if result.State != tt.expected {
t.Errorf("Expected State %q, got %q", tt.expected, result.State)
}
})
}
}
// TestMatchDomainPattern exercises wildcard domain matching: exact patterns,
// "*." wildcards, non-matches and case-insensitivity.
func TestMatchDomainPattern(t *testing.T) {
	cases := []struct {
		name, url, pattern string
		expected           bool
	}{
		{"Exact match", "https://kmk.org/page", "kmk.org", true},
		{"Wildcard subdomain", "https://www.kmk.org/page", "*.kmk.org", true},
		{"No match", "https://example.com/page", "*.kmk.org", false},
		{"Case insensitive", "https://WWW.KMK.ORG/page", "*.kmk.org", true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := matchDomainPattern(tc.url, tc.pattern); got != tc.expected {
				t.Errorf("matchDomainPattern(%q, %q) = %v, expected %v",
					tc.url, tc.pattern, got, tc.expected)
			}
		})
	}
}
func TestTagger_CombinedTitleAndContent(t *testing.T) {
rulesDir := createTestRulesDir(t)
tagger, err := NewTagger(rulesDir)
if err != nil {
t.Fatal(err)
}
// Test that title is combined with content for tagging
result := tagger.Tag(
"https://example.com/page",
"Lehrplan Mathematik Bayern", // Title with keywords
"Allgemeiner Text ohne spezifische Begriffe", // Content without keywords
ContentFeatures{ContentLength: 500},
)
if result.DocType != "Lehrplan" {
t.Errorf("Expected DocType 'Lehrplan' from title, got %q", result.DocType)
}
}
func TestTrustScoreClamping(t *testing.T) {
rulesDir := createTestRulesDir(t)
tagger, err := NewTagger(rulesDir)
if err != nil {
t.Fatal(err)
}
// Test that score is clamped to [0, 1]
result := tagger.Tag(
"https://www.kmk.org/page", // High trust domain
"",
"Content",
ContentFeatures{ContentLength: 1000},
)
if result.TrustScore < 0 || result.TrustScore > 1 {
t.Errorf("TrustScore %f should be in range [0, 1]", result.TrustScore)
}
}