Initial commit: breakpilot-core - Shared Infrastructure

Docker Compose with 24+ services:
- PostgreSQL (PostGIS), Valkey, MinIO, Qdrant
- Vault (PKI/TLS), Nginx (Reverse Proxy)
- Backend Core API, Consent Service, Billing Service
- RAG Service, Embedding Service
- Gitea, Woodpecker CI/CD
- Night Scheduler, Health Aggregator
- Jitsi (Web/XMPP/JVB/Jicofo), Mailpit

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Boenisch
2026-02-11 23:47:13 +01:00
commit ad111d5e69
244 changed files with 84288 additions and 0 deletions
+65
View File
@@ -0,0 +1,65 @@
# BreakPilot Core — Shared Infrastructure
## Entwicklungsumgebung
### Zwei-Rechner-Setup
| Gerät | Rolle |
|-------|-------|
| **MacBook** | Client/Terminal |
| **Mac Mini** | Server/Docker/Git |
```bash
ssh macmini "cd /Users/benjaminadmin/Projekte/breakpilot-core && <cmd>"
```
## Projektübersicht
**breakpilot-core** ist das Infrastruktur-Projekt der BreakPilot-Plattform. Es stellt alle gemeinsamen Services bereit, die von **breakpilot-lehrer** und **breakpilot-compliance** genutzt werden.
### Enthaltene Services (~28 Container)
| Service | Port | Beschreibung |
|---------|------|--------------|
| nginx | 80/443 | Reverse Proxy (SSL) |
| postgres | 5432 | PostGIS 16 (3 Schemas: core, lehrer, compliance) |
| valkey | 6379 | Session-Cache |
| vault | 8200 | Secrets Management |
| qdrant | 6333 | Vektordatenbank |
| minio | 9000 | S3 Storage |
| backend-core | 8000 | Shared APIs (Auth, RBAC, Notifications) |
| rag-service | 8097 | RAG: Dokumente, Suche, Embeddings |
| embedding-service | 8087 | Text-Embeddings |
| consent-service | 8081 | Consent-Management |
| health-aggregator | 8099 | Health-Check aller Services |
| gitea | 3003 | Git-Server |
| woodpecker | 8090 | CI/CD |
| camunda | 8089 | BPMN |
| synapse | 8008 | Matrix Chat |
| jitsi | 8443 | Video |
| mailpit | 8025 | E-Mail (Dev) |
### Docker-Netzwerk
Alle 3 Projekte teilen sich das `breakpilot-network`:
```yaml
networks:
breakpilot-network:
driver: bridge
name: breakpilot-network
```
### Start-Reihenfolge
```bash
# 1. Core MUSS zuerst starten
docker compose up -d
# 2. Dann Lehrer und Compliance (warten auf Core Health)
```
### DB-Schemas
- `core` — users, sessions, auth, rbac, notifications
- `lehrer` — classroom, units, klausuren, game
- `compliance` — compliance, dsr, gdpr, sdk
## Git Remotes
Immer zu BEIDEN pushen:
- `origin`: lokale Gitea (macmini:3003)
- `gitea`: gitea.meghsakha.com
+58
View File
@@ -0,0 +1,58 @@
# =========================================================
# BreakPilot Core — Environment Variables
# =========================================================
# Copy to .env and adjust values
# Database
POSTGRES_USER=breakpilot
POSTGRES_PASSWORD=breakpilot123
POSTGRES_DB=breakpilot_db
# Security
JWT_SECRET=your-super-secret-jwt-key-change-in-production
VAULT_TOKEN=breakpilot-dev-token
# MinIO (S3-compatible storage)
MINIO_ROOT_USER=breakpilot
MINIO_ROOT_PASSWORD=breakpilot123
MINIO_BUCKET=breakpilot-rag
# Environment
ENVIRONMENT=development
TZ=Europe/Berlin
# Embedding Service
EMBEDDING_BACKEND=local
LOCAL_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
LOCAL_RERANKER_MODEL=cross-encoder/ms-marco-MiniLM-L-6-v2
# SMTP (Mailpit in dev)
SMTP_HOST=mailpit
SMTP_PORT=1025
# Synapse
SYNAPSE_SERVER_NAME=macmini
SYNAPSE_DB_PASSWORD=synapse_secret
# Jitsi
JICOFO_AUTH_PASSWORD=jicofo_secret
JVB_AUTH_PASSWORD=jvb_secret
JIBRI_XMPP_PASSWORD=jibri_secret
JIBRI_RECORDER_PASSWORD=recorder_secret
JITSI_PUBLIC_URL=https://macmini:8443
# ERPNext
ERPNEXT_DB_ROOT_PASSWORD=erpnext_root
ERPNEXT_DB_PASSWORD=erpnext_secret
ERPNEXT_ADMIN_PASSWORD=admin
# Woodpecker CI
WOODPECKER_HOST=http://macmini:8090
WOODPECKER_ADMIN=pilotadmin
WOODPECKER_AGENT_SECRET=woodpecker-secret
# Gitea Runner
GITEA_RUNNER_TOKEN=
# Session
SESSION_TTL_HOURS=24
+63
View File
@@ -0,0 +1,63 @@
# Environment
.env
.env.local
.env.backup
# Secrets
secrets/
*.pem
*.key
# Node
node_modules/
.next/
# Python
__pycache__/
*.pyc
venv/
.venv/
# Docker
backups/*.backup
# IDE
.idea/
.vscode/
*.swp
*.swo
.DS_Store
# Logs
*.log
# Large files
*.pdf
*.docx
*.xlsx
*.pptx
*.mp4
*.mp3
*.wav
# Compiled binaries
billing-service/billing-service
consent-service/server
*.exe
*.dll
*.so
*.dylib
# Large files
*.zip
*.gz
*.tar
*.sql.gz
*.pdf
*.docx
*.xlsx
*.pptx
# Coverage
coverage/
*.coverage
+15
View File
@@ -0,0 +1,15 @@
__pycache__
*.pyc
*.pyo
.git
.env
.env.*
.pytest_cache
venv
.venv
*.egg-info
.DS_Store
security-reports
scripts
tests
docs
+64
View File
@@ -0,0 +1,64 @@
# ============================================================
# BreakPilot Core Backend -- Multi-stage Docker build
# ============================================================
# ---------- Build stage ----------
FROM python:3.12-slim-bookworm AS builder
WORKDIR /app
# Build-time system libs (needed for asyncpg / psycopg2)
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
libpq-dev \
&& rm -rf /var/lib/apt/lists/*
COPY requirements.txt .
RUN python -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
RUN pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir -r requirements.txt
# ---------- Runtime stage ----------
FROM python:3.12-slim-bookworm
WORKDIR /app
# Runtime system libs
# - libpango / libgdk-pixbuf / shared-mime-info -> WeasyPrint (pdf_service)
# - libgl1 / libglib2.0-0 -> OpenCV (file_processor)
# - curl -> healthcheck
RUN apt-get update && apt-get install -y --no-install-recommends \
libpango-1.0-0 \
libpangocairo-1.0-0 \
libgdk-pixbuf-2.0-0 \
libffi-dev \
shared-mime-info \
libgl1 \
libglib2.0-0 \
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy virtualenv from builder
COPY --from=builder /opt/venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# Non-root user
RUN useradd --create-home --shell /bin/bash appuser
# Copy application code
COPY --chown=appuser:appuser . .
USER appuser
# Python tweaks
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
EXPOSE 8000
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD curl -f http://127.0.0.1:8000/health || exit 1
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
+55
View File
@@ -0,0 +1,55 @@
"""
BreakPilot Authentication Module
Hybrid authentication supporting both Keycloak and local JWT tokens.
"""
from .keycloak_auth import (
# Config
KeycloakConfig,
KeycloakUser,
# Authenticators
KeycloakAuthenticator,
HybridAuthenticator,
# Exceptions
KeycloakAuthError,
TokenExpiredError,
TokenInvalidError,
KeycloakConfigError,
# Factory functions
get_keycloak_config_from_env,
get_authenticator,
get_auth,
# FastAPI dependencies
get_current_user,
require_role,
)
__all__ = [
# Config
"KeycloakConfig",
"KeycloakUser",
# Authenticators
"KeycloakAuthenticator",
"HybridAuthenticator",
# Exceptions
"KeycloakAuthError",
"TokenExpiredError",
"TokenInvalidError",
"KeycloakConfigError",
# Factory functions
"get_keycloak_config_from_env",
"get_authenticator",
"get_auth",
# FastAPI dependencies
"get_current_user",
"require_role",
]
+515
View File
@@ -0,0 +1,515 @@
"""
Keycloak Authentication Module
Implements token validation against Keycloak JWKS endpoint.
This module handles authentication (who is the user?), while
rbac.py handles authorization (what can the user do?).
Architecture:
- Keycloak validates JWT tokens and provides basic identity
- Our custom rbac.py handles fine-grained permissions
"""
import os
import httpx
import jwt
from jwt import PyJWKClient
from datetime import datetime, timezone
from typing import Optional, Dict, Any, List
from dataclasses import dataclass
from functools import lru_cache
import logging
logger = logging.getLogger(__name__)
@dataclass
class KeycloakConfig:
"""Keycloak connection configuration."""
server_url: str
realm: str
client_id: str
client_secret: Optional[str] = None
verify_ssl: bool = True
@property
def issuer_url(self) -> str:
return f"{self.server_url}/realms/{self.realm}"
@property
def jwks_url(self) -> str:
return f"{self.issuer_url}/protocol/openid-connect/certs"
@property
def token_url(self) -> str:
return f"{self.issuer_url}/protocol/openid-connect/token"
@property
def userinfo_url(self) -> str:
return f"{self.issuer_url}/protocol/openid-connect/userinfo"
@dataclass
class KeycloakUser:
"""User information extracted from Keycloak token."""
user_id: str # Keycloak subject (sub)
email: str
email_verified: bool
name: Optional[str]
given_name: Optional[str]
family_name: Optional[str]
realm_roles: List[str] # Keycloak realm roles
client_roles: Dict[str, List[str]] # Client-specific roles
groups: List[str] # Keycloak groups
tenant_id: Optional[str] # Custom claim for school/tenant
raw_claims: Dict[str, Any] # All claims for debugging
def has_realm_role(self, role: str) -> bool:
"""Check if user has a specific realm role."""
return role in self.realm_roles
def has_client_role(self, client_id: str, role: str) -> bool:
"""Check if user has a specific client role."""
client_roles = self.client_roles.get(client_id, [])
return role in client_roles
def is_admin(self) -> bool:
"""Check if user has admin role."""
return self.has_realm_role("admin") or self.has_realm_role("schul_admin")
def is_teacher(self) -> bool:
"""Check if user is a teacher."""
return self.has_realm_role("teacher") or self.has_realm_role("lehrer")
class KeycloakAuthError(Exception):
"""Base exception for Keycloak authentication errors."""
pass
class TokenExpiredError(KeycloakAuthError):
"""Token has expired."""
pass
class TokenInvalidError(KeycloakAuthError):
"""Token is invalid."""
pass
class KeycloakConfigError(KeycloakAuthError):
"""Keycloak configuration error."""
pass
class KeycloakAuthenticator:
"""
Validates JWT tokens against Keycloak.
Usage:
config = KeycloakConfig(
server_url="https://keycloak.example.com",
realm="breakpilot",
client_id="breakpilot-backend"
)
auth = KeycloakAuthenticator(config)
user = await auth.validate_token(token)
if user.is_teacher():
# Grant access
"""
def __init__(self, config: KeycloakConfig):
self.config = config
self._jwks_client: Optional[PyJWKClient] = None
self._http_client: Optional[httpx.AsyncClient] = None
@property
def jwks_client(self) -> PyJWKClient:
"""Lazy-load JWKS client."""
if self._jwks_client is None:
self._jwks_client = PyJWKClient(
self.config.jwks_url,
cache_keys=True,
lifespan=3600 # Cache keys for 1 hour
)
return self._jwks_client
async def get_http_client(self) -> httpx.AsyncClient:
"""Get or create async HTTP client."""
if self._http_client is None or self._http_client.is_closed:
self._http_client = httpx.AsyncClient(
verify=self.config.verify_ssl,
timeout=30.0
)
return self._http_client
async def close(self):
"""Close HTTP client."""
if self._http_client and not self._http_client.is_closed:
await self._http_client.aclose()
def validate_token_sync(self, token: str) -> KeycloakUser:
"""
Synchronously validate a JWT token against Keycloak JWKS.
Args:
token: The JWT access token
Returns:
KeycloakUser with extracted claims
Raises:
TokenExpiredError: If token has expired
TokenInvalidError: If token signature is invalid
"""
try:
# Get signing key from JWKS
signing_key = self.jwks_client.get_signing_key_from_jwt(token)
# Decode and validate token
payload = jwt.decode(
token,
signing_key.key,
algorithms=["RS256"],
audience=self.config.client_id,
issuer=self.config.issuer_url,
options={
"verify_exp": True,
"verify_iat": True,
"verify_aud": True,
"verify_iss": True
}
)
return self._extract_user(payload)
except jwt.ExpiredSignatureError:
raise TokenExpiredError("Token has expired")
except jwt.InvalidAudienceError:
raise TokenInvalidError("Invalid token audience")
except jwt.InvalidIssuerError:
raise TokenInvalidError("Invalid token issuer")
except jwt.InvalidTokenError as e:
raise TokenInvalidError(f"Invalid token: {e}")
except Exception as e:
logger.error(f"Token validation failed: {e}")
raise TokenInvalidError(f"Token validation failed: {e}")
async def validate_token(self, token: str) -> KeycloakUser:
"""
Asynchronously validate a JWT token.
Note: JWKS fetching is synchronous due to PyJWKClient limitations,
but this wrapper allows async context usage.
"""
return self.validate_token_sync(token)
async def get_userinfo(self, token: str) -> Dict[str, Any]:
"""
Fetch user info from Keycloak userinfo endpoint.
This provides additional user claims not in the access token.
"""
client = await self.get_http_client()
try:
response = await client.get(
self.config.userinfo_url,
headers={"Authorization": f"Bearer {token}"}
)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise TokenExpiredError("Token is invalid or expired")
raise TokenInvalidError(f"Failed to fetch userinfo: {e}")
def _extract_user(self, payload: Dict[str, Any]) -> KeycloakUser:
"""Extract KeycloakUser from JWT payload."""
# Extract realm roles
realm_access = payload.get("realm_access", {})
realm_roles = realm_access.get("roles", [])
# Extract client roles
resource_access = payload.get("resource_access", {})
client_roles = {}
for client_id, access in resource_access.items():
client_roles[client_id] = access.get("roles", [])
# Extract groups
groups = payload.get("groups", [])
# Extract custom tenant claim (if configured in Keycloak)
tenant_id = payload.get("tenant_id") or payload.get("school_id")
return KeycloakUser(
user_id=payload.get("sub", ""),
email=payload.get("email", ""),
email_verified=payload.get("email_verified", False),
name=payload.get("name"),
given_name=payload.get("given_name"),
family_name=payload.get("family_name"),
realm_roles=realm_roles,
client_roles=client_roles,
groups=groups,
tenant_id=tenant_id,
raw_claims=payload
)
# =============================================
# HYBRID AUTH: Keycloak + Local JWT
# =============================================
class HybridAuthenticator:
"""
Hybrid authenticator supporting both Keycloak and local JWT tokens.
This allows gradual migration from local JWT to Keycloak:
1. Development: Use local JWT (fast, no external dependencies)
2. Production: Use Keycloak for full IAM capabilities
Token type detection:
- Keycloak tokens: Have 'iss' claim matching Keycloak URL
- Local tokens: Have 'iss' claim as 'breakpilot' or no 'iss'
"""
def __init__(
self,
keycloak_config: Optional[KeycloakConfig] = None,
local_jwt_secret: Optional[str] = None,
environment: str = "development"
):
self.environment = environment
self.keycloak_enabled = keycloak_config is not None
self.local_jwt_secret = local_jwt_secret
if keycloak_config:
self.keycloak_auth = KeycloakAuthenticator(keycloak_config)
else:
self.keycloak_auth = None
async def validate_token(self, token: str) -> Dict[str, Any]:
"""
Validate token using appropriate method.
Returns a unified user dict compatible with existing code.
"""
if not token:
raise TokenInvalidError("No token provided")
# Try to peek at the token to determine type
try:
# Decode without verification to check issuer
unverified = jwt.decode(token, options={"verify_signature": False})
issuer = unverified.get("iss", "")
except jwt.InvalidTokenError:
raise TokenInvalidError("Cannot decode token")
# Check if it's a Keycloak token
if self.keycloak_auth and self.keycloak_auth.config.issuer_url in issuer:
# Validate with Keycloak
kc_user = await self.keycloak_auth.validate_token(token)
return self._keycloak_user_to_dict(kc_user)
# Fall back to local JWT validation
if self.local_jwt_secret:
return self._validate_local_token(token)
raise TokenInvalidError("No valid authentication method available")
def _validate_local_token(self, token: str) -> Dict[str, Any]:
"""Validate token with local JWT secret."""
if not self.local_jwt_secret:
raise KeycloakConfigError("Local JWT secret not configured")
try:
payload = jwt.decode(
token,
self.local_jwt_secret,
algorithms=["HS256"]
)
# Map local token claims to unified format
return {
"user_id": payload.get("user_id", payload.get("sub", "")),
"email": payload.get("email", ""),
"name": payload.get("name", ""),
"role": payload.get("role", "user"),
"realm_roles": [payload.get("role", "user")],
"tenant_id": payload.get("tenant_id", payload.get("school_id")),
"auth_method": "local_jwt"
}
except jwt.ExpiredSignatureError:
raise TokenExpiredError("Token has expired")
except jwt.InvalidTokenError as e:
raise TokenInvalidError(f"Invalid local token: {e}")
def _keycloak_user_to_dict(self, user: KeycloakUser) -> Dict[str, Any]:
"""Convert KeycloakUser to dict compatible with existing code."""
# Map Keycloak roles to our role system
role = "user"
if user.is_admin():
role = "admin"
elif user.is_teacher():
role = "teacher"
return {
"user_id": user.user_id,
"email": user.email,
"name": user.name or f"{user.given_name or ''} {user.family_name or ''}".strip(),
"role": role,
"realm_roles": user.realm_roles,
"client_roles": user.client_roles,
"groups": user.groups,
"tenant_id": user.tenant_id,
"email_verified": user.email_verified,
"auth_method": "keycloak"
}
async def close(self):
"""Cleanup resources."""
if self.keycloak_auth:
await self.keycloak_auth.close()
# =============================================
# FACTORY FUNCTIONS
# =============================================
def get_keycloak_config_from_env() -> Optional[KeycloakConfig]:
"""
Create KeycloakConfig from environment variables.
Required env vars:
- KEYCLOAK_SERVER_URL: e.g., https://keycloak.breakpilot.app
- KEYCLOAK_REALM: e.g., breakpilot
- KEYCLOAK_CLIENT_ID: e.g., breakpilot-backend
Optional:
- KEYCLOAK_CLIENT_SECRET: For confidential clients
- KEYCLOAK_VERIFY_SSL: Default true
"""
server_url = os.environ.get("KEYCLOAK_SERVER_URL")
realm = os.environ.get("KEYCLOAK_REALM")
client_id = os.environ.get("KEYCLOAK_CLIENT_ID")
if not all([server_url, realm, client_id]):
logger.info("Keycloak not configured, using local JWT only")
return None
return KeycloakConfig(
server_url=server_url,
realm=realm,
client_id=client_id,
client_secret=os.environ.get("KEYCLOAK_CLIENT_SECRET"),
verify_ssl=os.environ.get("KEYCLOAK_VERIFY_SSL", "true").lower() == "true"
)
def get_authenticator() -> HybridAuthenticator:
"""
Get configured authenticator instance.
Uses environment variables to determine configuration.
"""
keycloak_config = get_keycloak_config_from_env()
# JWT_SECRET is required - no default fallback in production
jwt_secret = os.environ.get("JWT_SECRET")
environment = os.environ.get("ENVIRONMENT", "development")
if not jwt_secret and environment == "production":
raise KeycloakConfigError(
"JWT_SECRET environment variable is required in production"
)
return HybridAuthenticator(
keycloak_config=keycloak_config,
local_jwt_secret=jwt_secret,
environment=environment
)
# =============================================
# FASTAPI DEPENDENCY
# =============================================
from fastapi import Request, HTTPException, Depends
# Global authenticator instance (lazy-initialized)
_authenticator: Optional[HybridAuthenticator] = None
def get_auth() -> HybridAuthenticator:
"""Get or create global authenticator."""
global _authenticator
if _authenticator is None:
_authenticator = get_authenticator()
return _authenticator
async def get_current_user(request: Request) -> Dict[str, Any]:
"""
FastAPI dependency to get current authenticated user.
Usage:
@app.get("/api/protected")
async def protected_endpoint(user: dict = Depends(get_current_user)):
return {"user_id": user["user_id"]}
"""
auth_header = request.headers.get("authorization", "")
if not auth_header.startswith("Bearer "):
# Check for development mode
environment = os.environ.get("ENVIRONMENT", "development")
if environment == "development":
# Return demo user in development without token
return {
"user_id": "10000000-0000-0000-0000-000000000024",
"email": "demo@breakpilot.app",
"role": "admin",
"realm_roles": ["admin"],
"tenant_id": "a0000000-0000-0000-0000-000000000001",
"auth_method": "development_bypass"
}
raise HTTPException(status_code=401, detail="Missing authorization header")
token = auth_header.split(" ")[1]
try:
auth = get_auth()
return await auth.validate_token(token)
except TokenExpiredError:
raise HTTPException(status_code=401, detail="Token expired")
except TokenInvalidError as e:
raise HTTPException(status_code=401, detail=str(e))
except Exception as e:
logger.error(f"Authentication failed: {e}")
raise HTTPException(status_code=401, detail="Authentication failed")
async def require_role(required_role: str):
"""
FastAPI dependency factory for role-based access.
Usage:
@app.get("/api/admin-only")
async def admin_endpoint(user: dict = Depends(require_role("admin"))):
return {"message": "Admin access granted"}
"""
async def role_checker(user: dict = Depends(get_current_user)) -> dict:
user_role = user.get("role", "user")
realm_roles = user.get("realm_roles", [])
if user_role == required_role or required_role in realm_roles:
return user
raise HTTPException(
status_code=403,
detail=f"Role '{required_role}' required"
)
return role_checker
+373
View File
@@ -0,0 +1,373 @@
"""
Authentication API Endpoints für BreakPilot
Proxy für den Go Consent Service Authentication
"""
import httpx
from fastapi import APIRouter, HTTPException, Header, Request, Response
from typing import Optional
from pydantic import BaseModel, EmailStr
import os
# Consent Service URL
CONSENT_SERVICE_URL = os.getenv("CONSENT_SERVICE_URL", "http://localhost:8081")
router = APIRouter(prefix="/auth", tags=["authentication"])
# ==========================================
# Request/Response Models
# ==========================================
class RegisterRequest(BaseModel):
email: EmailStr
password: str
name: Optional[str] = None
class LoginRequest(BaseModel):
email: EmailStr
password: str
class RefreshTokenRequest(BaseModel):
refresh_token: str
class VerifyEmailRequest(BaseModel):
token: str
class ForgotPasswordRequest(BaseModel):
email: EmailStr
class ResetPasswordRequest(BaseModel):
token: str
new_password: str
class ChangePasswordRequest(BaseModel):
current_password: str
new_password: str
class UpdateProfileRequest(BaseModel):
name: Optional[str] = None
class LogoutRequest(BaseModel):
refresh_token: Optional[str] = None
# ==========================================
# Helper Functions
# ==========================================
def get_auth_headers(authorization: Optional[str]) -> dict:
"""Erstellt Header mit Authorization Token"""
headers = {"Content-Type": "application/json"}
if authorization:
headers["Authorization"] = authorization
return headers
async def proxy_to_consent_service(
method: str,
path: str,
json_data: dict = None,
headers: dict = None,
params: dict = None
) -> dict:
"""
Proxy-Aufruf zum Go Consent Service.
Wirft HTTPException bei Fehlern.
"""
url = f"{CONSENT_SERVICE_URL}/api/v1{path}"
async with httpx.AsyncClient() as client:
try:
if method == "GET":
response = await client.get(url, headers=headers, params=params, timeout=10.0)
elif method == "POST":
response = await client.post(url, headers=headers, json=json_data, timeout=10.0)
elif method == "PUT":
response = await client.put(url, headers=headers, json=json_data, timeout=10.0)
elif method == "DELETE":
response = await client.delete(url, headers=headers, params=params, timeout=10.0)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
# Parse JSON response
try:
data = response.json()
except:
data = {"message": response.text}
# Handle error responses
if response.status_code >= 400:
error_msg = data.get("error", "Unknown error")
raise HTTPException(status_code=response.status_code, detail=error_msg)
return data
except httpx.RequestError as e:
raise HTTPException(
status_code=503,
detail=f"Consent Service nicht erreichbar: {str(e)}"
)
# ==========================================
# Public Auth Endpoints (No Auth Required)
# ==========================================
@router.post("/register")
async def register(request: RegisterRequest, req: Request):
"""
Registriert einen neuen Benutzer.
Sendet eine Verifizierungs-E-Mail.
"""
data = await proxy_to_consent_service(
"POST",
"/auth/register",
json_data={
"email": request.email,
"password": request.password,
"name": request.name
}
)
return data
@router.post("/login")
async def login(request: LoginRequest, req: Request):
"""
Meldet einen Benutzer an.
Gibt Access Token und Refresh Token zurück.
"""
# Get client info for session tracking
client_ip = req.client.host if req.client else "unknown"
user_agent = req.headers.get("user-agent", "unknown")
data = await proxy_to_consent_service(
"POST",
"/auth/login",
json_data={
"email": request.email,
"password": request.password
},
headers={
"X-Forwarded-For": client_ip,
"User-Agent": user_agent
}
)
return data
@router.post("/logout")
async def logout(request: LogoutRequest):
"""
Meldet den Benutzer ab und invalidiert den Refresh Token.
"""
data = await proxy_to_consent_service(
"POST",
"/auth/logout",
json_data={"refresh_token": request.refresh_token} if request.refresh_token else {}
)
return data
@router.post("/refresh")
async def refresh_token(request: RefreshTokenRequest):
"""
Erneuert den Access Token mit einem gültigen Refresh Token.
"""
data = await proxy_to_consent_service(
"POST",
"/auth/refresh",
json_data={"refresh_token": request.refresh_token}
)
return data
@router.post("/verify-email")
async def verify_email(request: VerifyEmailRequest):
"""
Verifiziert die E-Mail-Adresse mit dem Token aus der E-Mail.
"""
data = await proxy_to_consent_service(
"POST",
"/auth/verify-email",
json_data={"token": request.token}
)
return data
@router.post("/resend-verification")
async def resend_verification(email: EmailStr):
"""
Sendet die Verifizierungs-E-Mail erneut.
"""
data = await proxy_to_consent_service(
"POST",
"/auth/resend-verification",
json_data={"email": email}
)
return data
@router.post("/forgot-password")
async def forgot_password(request: ForgotPasswordRequest, req: Request):
"""
Initiiert den Passwort-Reset-Prozess.
Sendet eine E-Mail mit Reset-Link.
"""
client_ip = req.client.host if req.client else "unknown"
data = await proxy_to_consent_service(
"POST",
"/auth/forgot-password",
json_data={"email": request.email},
headers={"X-Forwarded-For": client_ip}
)
return data
@router.post("/reset-password")
async def reset_password(request: ResetPasswordRequest):
"""
Setzt das Passwort mit dem Token aus der E-Mail zurück.
"""
data = await proxy_to_consent_service(
"POST",
"/auth/reset-password",
json_data={
"token": request.token,
"new_password": request.new_password
}
)
return data
# ==========================================
# Protected Profile Endpoints (Auth Required)
# ==========================================
@router.get("/profile")
async def get_profile(authorization: Optional[str] = Header(None)):
"""
Gibt das Profil des angemeldeten Benutzers zurück.
"""
if not authorization:
raise HTTPException(status_code=401, detail="Authorization header required")
data = await proxy_to_consent_service(
"GET",
"/profile",
headers=get_auth_headers(authorization)
)
return data
@router.put("/profile")
async def update_profile(
request: UpdateProfileRequest,
authorization: Optional[str] = Header(None)
):
"""
Aktualisiert das Profil des angemeldeten Benutzers.
"""
if not authorization:
raise HTTPException(status_code=401, detail="Authorization header required")
data = await proxy_to_consent_service(
"PUT",
"/profile",
json_data={"name": request.name},
headers=get_auth_headers(authorization)
)
return data
@router.put("/profile/password")
async def change_password(
request: ChangePasswordRequest,
authorization: Optional[str] = Header(None)
):
"""
Ändert das Passwort des angemeldeten Benutzers.
"""
if not authorization:
raise HTTPException(status_code=401, detail="Authorization header required")
data = await proxy_to_consent_service(
"PUT",
"/profile/password",
json_data={
"current_password": request.current_password,
"new_password": request.new_password
},
headers=get_auth_headers(authorization)
)
return data
@router.get("/profile/sessions")
async def get_sessions(authorization: Optional[str] = Header(None)):
"""
Gibt alle aktiven Sessions des Benutzers zurück.
"""
if not authorization:
raise HTTPException(status_code=401, detail="Authorization header required")
data = await proxy_to_consent_service(
"GET",
"/profile/sessions",
headers=get_auth_headers(authorization)
)
return data
@router.delete("/profile/sessions/{session_id}")
async def revoke_session(
session_id: str,
authorization: Optional[str] = Header(None)
):
"""
Beendet eine bestimmte Session.
"""
if not authorization:
raise HTTPException(status_code=401, detail="Authorization header required")
data = await proxy_to_consent_service(
"DELETE",
f"/profile/sessions/{session_id}",
headers=get_auth_headers(authorization)
)
return data
# ==========================================
# Health Check
# ==========================================
@router.get("/health")
async def auth_health():
"""
Prüft die Verbindung zum Auth Service.
"""
try:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{CONSENT_SERVICE_URL}/health",
timeout=5.0
)
is_healthy = response.status_code == 200
except:
is_healthy = False
return {
"auth_service": "healthy" if is_healthy else "unavailable",
"connected": is_healthy
}
+18
View File
@@ -0,0 +1,18 @@
from pathlib import Path
BASE_DIR = Path.home() / "Arbeitsblaetter"
EINGANG_DIR = BASE_DIR / "Eingang"
BEREINIGT_DIR = BASE_DIR / "Bereinigt"
EDITIERBAR_DIR = BASE_DIR / "Editierbar"
NEU_GENERIERT_DIR = BASE_DIR / "Neu_generiert"
VALID_SUFFIXES = {".jpg", ".jpeg", ".png", ".pdf", ".JPG", ".JPEG", ".PNG", ".PDF"}
# Ordner sicherstellen
for d in [EINGANG_DIR, BEREINIGT_DIR, EDITIERBAR_DIR, NEU_GENERIERT_DIR]:
d.mkdir(parents=True, exist_ok=True)
def is_valid_input_file(path: Path) -> bool:
"""Gemeinsame Filterlogik für Eingangsdateien."""
return path.is_file() and not path.name.startswith(".") and path.suffix in VALID_SUFFIXES
+359
View File
@@ -0,0 +1,359 @@
"""
Consent Service Client für BreakPilot
Kommuniziert mit dem Consent Management Service für GDPR-Compliance
"""
import httpx
import jwt
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
from dataclasses import dataclass
from enum import Enum
import os
import uuid
# Consent Service URL (aus Umgebungsvariable oder Standard für lokale Entwicklung)
CONSENT_SERVICE_URL = os.getenv("CONSENT_SERVICE_URL", "http://localhost:8081")
# JWT Secret - MUSS mit dem Go Consent Service übereinstimmen!
JWT_SECRET = os.getenv("JWT_SECRET", "breakpilot-dev-jwt-secret-2024")
def generate_jwt_token(
user_id: str = None,
email: str = "demo@breakpilot.app",
role: str = "user",
expires_hours: int = 24
) -> str:
"""
Generiert einen JWT Token für die Authentifizierung beim Consent Service.
Args:
user_id: Die User-ID (wird generiert falls nicht angegeben)
email: Die E-Mail-Adresse des Benutzers
role: Die Rolle (user, admin, super_admin)
expires_hours: Gültigkeitsdauer in Stunden
Returns:
JWT Token als String
"""
if user_id is None:
user_id = str(uuid.uuid4())
payload = {
"user_id": user_id,
"email": email,
"role": role,
"exp": datetime.utcnow() + timedelta(hours=expires_hours),
"iat": datetime.utcnow(),
}
return jwt.encode(payload, JWT_SECRET, algorithm="HS256")
def generate_demo_token() -> str:
"""Generiert einen Demo-Token für nicht-authentifizierte Benutzer"""
return generate_jwt_token(
user_id="demo-user-" + str(uuid.uuid4())[:8],
email="demo@breakpilot.app",
role="user"
)
class DocumentType(str, Enum):
TERMS = "terms"
PRIVACY = "privacy"
COOKIES = "cookies"
COMMUNITY = "community"
@dataclass
class ConsentStatus:
has_consent: bool
current_version_id: Optional[str] = None
consented_version: Optional[str] = None
needs_update: bool = False
consented_at: Optional[str] = None
@dataclass
class DocumentVersion:
id: str
document_id: str
version: str
language: str
title: str
content: str
summary: Optional[str] = None
class ConsentClient:
"""Client für die Kommunikation mit dem Consent Service"""
def __init__(self, base_url: str = CONSENT_SERVICE_URL):
self.base_url = base_url.rstrip("/")
self.api_url = f"{self.base_url}/api/v1"
def _get_headers(self, jwt_token: str) -> Dict[str, str]:
"""Erstellt die Header mit JWT Token"""
return {
"Authorization": f"Bearer {jwt_token}",
"Content-Type": "application/json"
}
async def check_consent(
self,
jwt_token: str,
document_type: DocumentType,
language: str = "de"
) -> ConsentStatus:
"""
Prüft ob der Benutzer dem Dokument zugestimmt hat.
Gibt zurück ob eine Zustimmung vorliegt und ob sie aktuell ist.
"""
async with httpx.AsyncClient() as client:
try:
response = await client.get(
f"{self.api_url}/consent/check/{document_type.value}",
headers=self._get_headers(jwt_token),
params={"language": language},
timeout=10.0
)
if response.status_code == 200:
data = response.json()
return ConsentStatus(
has_consent=data.get("has_consent", False),
current_version_id=data.get("current_version_id"),
consented_version=data.get("consented_version"),
needs_update=data.get("needs_update", False),
consented_at=data.get("consented_at")
)
else:
return ConsentStatus(has_consent=False, needs_update=True)
except httpx.RequestError:
# Bei Verbindungsproblemen: Consent nicht erzwingen
return ConsentStatus(has_consent=True, needs_update=False)
async def check_all_mandatory_consents(
self,
jwt_token: str,
language: str = "de"
) -> Dict[str, ConsentStatus]:
"""
Prüft alle verpflichtenden Dokumente (Terms, Privacy).
Gibt ein Dictionary mit dem Status für jedes Dokument zurück.
"""
mandatory_docs = [DocumentType.TERMS, DocumentType.PRIVACY]
results = {}
for doc_type in mandatory_docs:
results[doc_type.value] = await self.check_consent(jwt_token, doc_type, language)
return results
async def get_pending_consents(
self,
jwt_token: str,
language: str = "de"
) -> List[Dict[str, Any]]:
"""
Gibt eine Liste aller Dokumente zurück, die noch Zustimmung benötigen.
Nützlich für die Anzeige beim Login/Registration.
"""
pending = []
statuses = await self.check_all_mandatory_consents(jwt_token, language)
for doc_type, status in statuses.items():
if not status.has_consent or status.needs_update:
# Hole das aktuelle Dokument
doc = await self.get_latest_document(jwt_token, doc_type, language)
if doc:
pending.append({
"type": doc_type,
"version_id": status.current_version_id,
"title": doc.title,
"content": doc.content,
"summary": doc.summary,
"is_update": status.has_consent and status.needs_update
})
return pending
async def get_latest_document(
self,
jwt_token: str,
document_type: str,
language: str = "de"
) -> Optional[DocumentVersion]:
"""Holt die aktuellste Version eines Dokuments"""
async with httpx.AsyncClient() as client:
try:
response = await client.get(
f"{self.api_url}/documents/{document_type}/latest",
headers=self._get_headers(jwt_token),
params={"language": language},
timeout=10.0
)
if response.status_code == 200:
data = response.json()
return DocumentVersion(
id=data["id"],
document_id=data["document_id"],
version=data["version"],
language=data["language"],
title=data["title"],
content=data["content"],
summary=data.get("summary")
)
return None
except httpx.RequestError:
return None
async def give_consent(
self,
jwt_token: str,
document_type: str,
version_id: str,
consented: bool = True
) -> bool:
"""
Speichert die Zustimmung des Benutzers.
Gibt True zurück bei Erfolg.
"""
async with httpx.AsyncClient() as client:
try:
response = await client.post(
f"{self.api_url}/consent",
headers=self._get_headers(jwt_token),
json={
"document_type": document_type,
"version_id": version_id,
"consented": consented
},
timeout=10.0
)
return response.status_code == 201
except httpx.RequestError:
return False
async def get_cookie_categories(
self,
jwt_token: str,
language: str = "de"
) -> List[Dict[str, Any]]:
"""Holt alle Cookie-Kategorien für das Cookie-Banner"""
async with httpx.AsyncClient() as client:
try:
response = await client.get(
f"{self.api_url}/cookies/categories",
headers=self._get_headers(jwt_token),
params={"language": language},
timeout=10.0
)
if response.status_code == 200:
return response.json().get("categories", [])
return []
except httpx.RequestError:
return []
async def set_cookie_consent(
self,
jwt_token: str,
categories: List[Dict[str, Any]]
) -> bool:
"""
Speichert die Cookie-Präferenzen.
categories: [{"category_id": "...", "consented": true/false}, ...]
"""
async with httpx.AsyncClient() as client:
try:
response = await client.post(
f"{self.api_url}/cookies/consent",
headers=self._get_headers(jwt_token),
json={"categories": categories},
timeout=10.0
)
return response.status_code == 200
except httpx.RequestError:
return False
async def get_my_data(self, jwt_token: str) -> Optional[Dict[str, Any]]:
"""GDPR Art. 15: Holt alle Daten des Benutzers"""
async with httpx.AsyncClient() as client:
try:
response = await client.get(
f"{self.api_url}/privacy/my-data",
headers=self._get_headers(jwt_token),
timeout=30.0
)
if response.status_code == 200:
return response.json()
return None
except httpx.RequestError:
return None
async def request_data_export(self, jwt_token: str) -> Optional[str]:
"""GDPR Art. 20: Fordert einen Datenexport an"""
async with httpx.AsyncClient() as client:
try:
response = await client.post(
f"{self.api_url}/privacy/export",
headers=self._get_headers(jwt_token),
timeout=10.0
)
if response.status_code == 202:
return response.json().get("request_id")
return None
except httpx.RequestError:
return None
async def request_data_deletion(
self,
jwt_token: str,
reason: Optional[str] = None
) -> Optional[str]:
"""GDPR Art. 17: Fordert Löschung aller Daten an"""
async with httpx.AsyncClient() as client:
try:
response = await client.post(
f"{self.api_url}/privacy/delete",
headers=self._get_headers(jwt_token),
json={"reason": reason} if reason else {},
timeout=10.0
)
if response.status_code == 202:
return response.json().get("request_id")
return None
except httpx.RequestError:
return None
async def health_check(self) -> bool:
"""Prüft ob der Consent Service erreichbar ist"""
async with httpx.AsyncClient() as client:
try:
response = await client.get(
f"{self.base_url}/health",
timeout=5.0
)
return response.status_code == 200
except httpx.RequestError:
return False
# Singleton-Instanz für einfachen Zugriff
consent_client = ConsentClient()
+252
View File
@@ -0,0 +1,252 @@
"""
E-Mail Template API für BreakPilot
Proxy für den Go Consent Service E-Mail Template Management
"""
from fastapi import APIRouter, Request, HTTPException, Depends
from fastapi.responses import JSONResponse
import httpx
from typing import Optional
import os
from consent_client import CONSENT_SERVICE_URL, generate_jwt_token
router = APIRouter(prefix="/api/consent/admin/email-templates", tags=["Email Templates"])
# Base URL für E-Mail-Template-Endpunkte im Go Consent Service
EMAIL_TEMPLATE_BASE = f"{CONSENT_SERVICE_URL}/api/v1/admin"
async def get_admin_token() -> str:
"""Generiert einen Admin-Token für API-Calls zum Consent Service"""
return generate_jwt_token(
user_id="a0000000-0000-0000-0000-000000000001",
email="admin@breakpilot.app",
role="admin",
expires_hours=1
)
async def proxy_request(
method: str,
path: str,
token: str,
json_data: dict = None,
params: dict = None
) -> dict:
"""Proxy-Funktion für API-Calls zum Go Consent Service"""
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json"
}
url = f"{EMAIL_TEMPLATE_BASE}{path}"
async with httpx.AsyncClient(timeout=30.0) as client:
try:
if method == "GET":
response = await client.get(url, headers=headers, params=params)
elif method == "POST":
response = await client.post(url, headers=headers, json=json_data)
elif method == "PUT":
response = await client.put(url, headers=headers, json=json_data)
elif method == "DELETE":
response = await client.delete(url, headers=headers)
else:
raise ValueError(f"Unsupported method: {method}")
if response.status_code >= 400:
error_detail = response.text
try:
error_detail = response.json().get("error", response.text)
except:
pass
raise HTTPException(status_code=response.status_code, detail=error_detail)
if response.status_code == 204:
return {"success": True}
return response.json()
except httpx.RequestError as e:
raise HTTPException(status_code=503, detail=f"Consent Service nicht erreichbar: {str(e)}")
# ==========================================
# E-Mail Template Typen
# ==========================================
@router.get("/types")
async def get_all_template_types():
"""Gibt alle verfügbaren E-Mail-Template-Typen zurück"""
token = await get_admin_token()
return await proxy_request("GET", "/email-templates/types", token)
# ==========================================
# E-Mail Templates
# ==========================================
@router.get("")
async def get_all_templates():
"""Gibt alle E-Mail-Templates zurück"""
token = await get_admin_token()
return await proxy_request("GET", "/email-templates", token)
@router.post("")
async def create_template(request: Request):
"""Erstellt ein neues E-Mail-Template"""
token = await get_admin_token()
data = await request.json()
return await proxy_request("POST", "/email-templates", token, json_data=data)
@router.get("/settings")
async def get_settings():
"""Gibt die E-Mail-Einstellungen zurück"""
token = await get_admin_token()
return await proxy_request("GET", "/email-templates/settings", token)
@router.put("/settings")
async def update_settings(request: Request):
"""Aktualisiert die E-Mail-Einstellungen"""
token = await get_admin_token()
data = await request.json()
return await proxy_request("PUT", "/email-templates/settings", token, json_data=data)
@router.get("/stats")
async def get_email_stats():
"""Gibt E-Mail-Statistiken zurück"""
token = await get_admin_token()
return await proxy_request("GET", "/email-templates/stats", token)
@router.get("/logs")
async def get_send_logs(
template_id: Optional[str] = None,
status: Optional[str] = None,
limit: int = 100,
offset: int = 0
):
"""Gibt E-Mail-Send-Logs zurück"""
token = await get_admin_token()
params = {"limit": limit, "offset": offset}
if template_id:
params["template_id"] = template_id
if status:
params["status"] = status
return await proxy_request("GET", "/email-templates/logs", token, params=params)
@router.get("/default/{template_type}")
async def get_default_content(template_type: str):
"""Gibt den Default-Inhalt für einen Template-Typ zurück"""
token = await get_admin_token()
return await proxy_request("GET", f"/email-templates/default/{template_type}", token)
@router.post("/initialize")
async def initialize_templates():
"""Initialisiert alle Standard-Templates"""
token = await get_admin_token()
return await proxy_request("POST", "/email-templates/initialize", token)
@router.get("/{template_id}")
async def get_template(template_id: str):
"""Gibt ein einzelnes E-Mail-Template zurück"""
token = await get_admin_token()
return await proxy_request("GET", f"/email-templates/{template_id}", token)
@router.get("/{template_id}/versions")
async def get_template_versions(template_id: str):
"""Gibt alle Versionen eines Templates zurück"""
token = await get_admin_token()
return await proxy_request("GET", f"/email-templates/{template_id}/versions", token)
# ==========================================
# E-Mail Template Versionen
# ==========================================
versions_router = APIRouter(prefix="/api/consent/admin/email-template-versions", tags=["Email Template Versions"])
@versions_router.get("/{version_id}")
async def get_version(version_id: str):
"""Gibt eine einzelne Version zurück"""
token = await get_admin_token()
return await proxy_request("GET", f"/email-template-versions/{version_id}", token)
@versions_router.post("")
async def create_version(request: Request):
"""Erstellt eine neue Version"""
token = await get_admin_token()
data = await request.json()
return await proxy_request("POST", "/email-template-versions", token, json_data=data)
@versions_router.put("/{version_id}")
async def update_version(version_id: str, request: Request):
"""Aktualisiert eine Version"""
token = await get_admin_token()
data = await request.json()
return await proxy_request("PUT", f"/email-template-versions/{version_id}", token, json_data=data)
@versions_router.post("/{version_id}/submit")
async def submit_for_review(version_id: str):
"""Sendet eine Version zur Überprüfung"""
token = await get_admin_token()
return await proxy_request("POST", f"/email-template-versions/{version_id}/submit", token)
@versions_router.post("/{version_id}/approve")
async def approve_version(version_id: str, request: Request):
"""Genehmigt eine Version"""
token = await get_admin_token()
data = await request.json()
return await proxy_request("POST", f"/email-template-versions/{version_id}/approve", token, json_data=data)
@versions_router.post("/{version_id}/reject")
async def reject_version(version_id: str, request: Request):
"""Lehnt eine Version ab"""
token = await get_admin_token()
data = await request.json()
return await proxy_request("POST", f"/email-template-versions/{version_id}/reject", token, json_data=data)
@versions_router.post("/{version_id}/publish")
async def publish_version(version_id: str):
"""Veröffentlicht eine Version"""
token = await get_admin_token()
return await proxy_request("POST", f"/email-template-versions/{version_id}/publish", token)
@versions_router.get("/{version_id}/approvals")
async def get_approvals(version_id: str):
"""Gibt die Genehmigungshistorie einer Version zurück"""
token = await get_admin_token()
return await proxy_request("GET", f"/email-template-versions/{version_id}/approvals", token)
@versions_router.post("/{version_id}/preview")
async def preview_version(version_id: str, request: Request):
"""Generiert eine Vorschau einer Version"""
token = await get_admin_token()
data = await request.json()
return await proxy_request("POST", f"/email-template-versions/{version_id}/preview", token, json_data=data)
@versions_router.post("/{version_id}/send-test")
async def send_test_email(version_id: str, request: Request):
"""Sendet eine Test-E-Mail"""
token = await get_admin_token()
data = await request.json()
return await proxy_request("POST", f"/email-template-versions/{version_id}/send-test", token, json_data=data)
+144
View File
@@ -0,0 +1,144 @@
"""
BreakPilot Core Backend
Shared APIs for authentication, RBAC, notifications, email templates,
system health, security (DevSecOps), and common middleware.
This is the extracted "core" service from the monorepo backend.
It runs on port 8000 and uses the `core` schema in PostgreSQL.
"""
import os
import logging
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
# ---------------------------------------------------------------------------
# Router imports (shared APIs only)
# ---------------------------------------------------------------------------
from auth_api import router as auth_router
from rbac_api import router as rbac_router
from notification_api import router as notification_router
from email_template_api import (
router as email_template_router,
versions_router as email_template_versions_router,
)
from system_api import router as system_router
from security_api import router as security_router
# ---------------------------------------------------------------------------
# Middleware imports
# ---------------------------------------------------------------------------
from middleware import (
RequestIDMiddleware,
SecurityHeadersMiddleware,
RateLimiterMiddleware,
PIIRedactor,
InputGateMiddleware,
)
# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
logging.basicConfig(
level=os.getenv("LOG_LEVEL", "INFO"),
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger("backend-core")
# ---------------------------------------------------------------------------
# Application
# ---------------------------------------------------------------------------
app = FastAPI(
title="BreakPilot Core Backend",
description="Shared APIs: Auth, RBAC, Notifications, Email Templates, System, Security",
version="1.0.0",
)
# ---------------------------------------------------------------------------
# CORS
# ---------------------------------------------------------------------------
ALLOWED_ORIGINS = os.getenv("CORS_ORIGINS", "*").split(",")
app.add_middleware(
CORSMiddleware,
allow_origins=ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# ---------------------------------------------------------------------------
# Custom middleware stack (order matters -- outermost first)
# ---------------------------------------------------------------------------
# 1. Request-ID (outermost so every response has it)
app.add_middleware(RequestIDMiddleware)
# 2. Security headers
app.add_middleware(SecurityHeadersMiddleware)
# 3. Input gate (body-size / content-type validation)
app.add_middleware(InputGateMiddleware)
# 4. Rate limiter (Valkey-backed)
VALKEY_URL = os.getenv("VALKEY_URL", os.getenv("REDIS_URL", "redis://valkey:6379/0"))
app.add_middleware(RateLimiterMiddleware, valkey_url=VALKEY_URL)
# ---------------------------------------------------------------------------
# Routers
# ---------------------------------------------------------------------------
# Auth (proxy to consent-service)
app.include_router(auth_router, prefix="/api")
# RBAC (teacher / role management)
app.include_router(rbac_router, prefix="/api")
# Notifications (proxy to consent-service)
app.include_router(notification_router, prefix="/api")
# Email templates (proxy to consent-service)
app.include_router(email_template_router) # already has /api/consent/admin/email-templates prefix
app.include_router(email_template_versions_router) # already has /api/consent/admin/email-template-versions prefix
# System (health, local-ip)
app.include_router(system_router) # already has paths defined in router
# Security / DevSecOps dashboard
app.include_router(security_router, prefix="/api")
# ---------------------------------------------------------------------------
# Startup / Shutdown events
# ---------------------------------------------------------------------------
@app.on_event("startup")
async def on_startup():
logger.info("backend-core starting up")
# Ensure DATABASE_URL uses search_path=core,public
db_url = os.getenv("DATABASE_URL", "")
if db_url and "search_path" not in db_url:
separator = "&" if "?" in db_url else "?"
new_url = f"{db_url}{separator}search_path=core,public"
os.environ["DATABASE_URL"] = new_url
logger.info("DATABASE_URL updated with search_path=core,public")
elif "search_path" in db_url:
logger.info("DATABASE_URL already contains search_path")
else:
logger.warning("DATABASE_URL is not set -- database features will fail")
@app.on_event("shutdown")
async def on_shutdown():
logger.info("backend-core shutting down")
# ---------------------------------------------------------------------------
# Entrypoint (for `python main.py` during development)
# ---------------------------------------------------------------------------
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"main:app",
host="0.0.0.0",
port=int(os.getenv("PORT", "8000")),
reload=os.getenv("ENVIRONMENT", "development") == "development",
)
+26
View File
@@ -0,0 +1,26 @@
"""
BreakPilot Middleware Stack
This module provides middleware components for the FastAPI backend:
- Request-ID: Adds unique request identifiers for tracing
- Security Headers: Adds security headers to all responses
- Rate Limiter: Protects against abuse (Valkey-based)
- PII Redactor: Redacts sensitive data from logs
- Input Gate: Validates request body size and content types
"""
from .request_id import RequestIDMiddleware, get_request_id
from .security_headers import SecurityHeadersMiddleware
from .rate_limiter import RateLimiterMiddleware
from .pii_redactor import PIIRedactor, redact_pii
from .input_gate import InputGateMiddleware
__all__ = [
"RequestIDMiddleware",
"get_request_id",
"SecurityHeadersMiddleware",
"RateLimiterMiddleware",
"PIIRedactor",
"redact_pii",
"InputGateMiddleware",
]
+260
View File
@@ -0,0 +1,260 @@
"""
Input Validation Gate Middleware
Validates incoming requests for:
- Request body size limits
- Content-Type validation
- File upload limits
- Malicious content detection
Usage:
from middleware import InputGateMiddleware
app.add_middleware(
InputGateMiddleware,
max_body_size=10 * 1024 * 1024, # 10MB
allowed_content_types=["application/json", "multipart/form-data"],
)
"""
import os
from dataclasses import dataclass, field
from typing import List, Optional, Set
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
@dataclass
class InputGateConfig:
"""Configuration for input validation."""
# Maximum request body size (default: 10MB)
max_body_size: int = 10 * 1024 * 1024
# Allowed content types
allowed_content_types: Set[str] = field(default_factory=lambda: {
"application/json",
"application/x-www-form-urlencoded",
"multipart/form-data",
"text/plain",
})
# File upload specific limits
max_file_size: int = 50 * 1024 * 1024 # 50MB for file uploads
allowed_file_types: Set[str] = field(default_factory=lambda: {
"image/jpeg",
"image/png",
"image/gif",
"image/webp",
"application/pdf",
"application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"text/csv",
})
# Blocked file extensions (potential malware)
blocked_extensions: Set[str] = field(default_factory=lambda: {
".exe", ".bat", ".cmd", ".com", ".msi",
".dll", ".scr", ".pif", ".vbs", ".js",
".jar", ".sh", ".ps1", ".app",
})
# Paths that allow larger uploads (e.g., file upload endpoints)
large_upload_paths: List[str] = field(default_factory=lambda: [
"/api/files/upload",
"/api/documents/upload",
"/api/attachments",
])
# Paths excluded from validation
excluded_paths: List[str] = field(default_factory=lambda: [
"/health",
"/metrics",
])
# Enable strict content type checking
strict_content_type: bool = True
class InputGateMiddleware(BaseHTTPMiddleware):
"""
Middleware that validates incoming request bodies and content types.
Protects against:
- Oversized request bodies
- Invalid content types
- Potentially malicious file uploads
"""
def __init__(
self,
app,
config: Optional[InputGateConfig] = None,
max_body_size: Optional[int] = None,
allowed_content_types: Optional[Set[str]] = None,
):
super().__init__(app)
self.config = config or InputGateConfig()
# Apply overrides
if max_body_size is not None:
self.config.max_body_size = max_body_size
if allowed_content_types is not None:
self.config.allowed_content_types = allowed_content_types
# Auto-configure from environment
env_max_size = os.getenv("MAX_REQUEST_BODY_SIZE")
if env_max_size:
try:
self.config.max_body_size = int(env_max_size)
except ValueError:
pass
def _is_excluded_path(self, path: str) -> bool:
"""Check if path is excluded from validation."""
return path in self.config.excluded_paths
def _is_large_upload_path(self, path: str) -> bool:
"""Check if path allows larger uploads."""
for upload_path in self.config.large_upload_paths:
if path.startswith(upload_path):
return True
return False
def _get_max_size(self, path: str) -> int:
"""Get the maximum allowed body size for this path."""
if self._is_large_upload_path(path):
return self.config.max_file_size
return self.config.max_body_size
def _validate_content_type(self, content_type: Optional[str]) -> tuple[bool, str]:
"""
Validate the content type.
Returns:
Tuple of (is_valid, error_message)
"""
if not content_type:
# Allow requests without content type (e.g., GET requests)
return True, ""
# Extract base content type (remove charset, boundary, etc.)
base_type = content_type.split(";")[0].strip().lower()
if base_type not in self.config.allowed_content_types:
return False, f"Content-Type '{base_type}' is not allowed"
return True, ""
def _check_blocked_extension(self, filename: str) -> bool:
"""Check if filename has a blocked extension."""
if not filename:
return False
lower_filename = filename.lower()
for ext in self.config.blocked_extensions:
if lower_filename.endswith(ext):
return True
return False
async def dispatch(self, request: Request, call_next) -> Response:
# Skip excluded paths
if self._is_excluded_path(request.url.path):
return await call_next(request)
# Skip validation for GET, HEAD, OPTIONS requests
if request.method in ("GET", "HEAD", "OPTIONS"):
return await call_next(request)
# Validate content type for requests with body
content_type = request.headers.get("Content-Type")
if self.config.strict_content_type:
is_valid, error_msg = self._validate_content_type(content_type)
if not is_valid:
return JSONResponse(
status_code=415,
content={
"error": "unsupported_media_type",
"message": error_msg,
},
)
# Check Content-Length header
content_length = request.headers.get("Content-Length")
if content_length:
try:
length = int(content_length)
max_size = self._get_max_size(request.url.path)
if length > max_size:
return JSONResponse(
status_code=413,
content={
"error": "payload_too_large",
"message": f"Request body exceeds maximum size of {max_size} bytes",
"max_size": max_size,
},
)
except ValueError:
return JSONResponse(
status_code=400,
content={
"error": "invalid_content_length",
"message": "Invalid Content-Length header",
},
)
# For multipart uploads, check for blocked file extensions
if content_type and "multipart/form-data" in content_type:
# Note: Full file validation would require reading the body
# which we avoid in middleware for performance reasons.
# Detailed file validation should happen in the handler.
pass
# Process request
return await call_next(request)
def validate_file_upload(
filename: str,
content_type: str,
size: int,
config: Optional[InputGateConfig] = None,
) -> tuple[bool, str]:
"""
Validate a file upload.
Use this in upload handlers for detailed validation.
Args:
filename: Original filename
content_type: MIME type of the file
size: File size in bytes
config: Optional custom configuration
Returns:
Tuple of (is_valid, error_message)
"""
cfg = config or InputGateConfig()
# Check size
if size > cfg.max_file_size:
return False, f"File size exceeds maximum of {cfg.max_file_size} bytes"
# Check extension
if filename:
lower_filename = filename.lower()
for ext in cfg.blocked_extensions:
if lower_filename.endswith(ext):
return False, f"File extension '{ext}' is not allowed"
# Check content type
if content_type and content_type not in cfg.allowed_file_types:
return False, f"File type '{content_type}' is not allowed"
return True, ""
+316
View File
@@ -0,0 +1,316 @@
"""
PII Redactor
Redacts Personally Identifiable Information (PII) from logs and responses.
Essential for DSGVO/GDPR compliance in BreakPilot.
Redacted data types:
- Email addresses
- IP addresses
- German phone numbers
- Names (when identified)
- Student IDs
- Credit card numbers
- IBAN numbers
Usage:
from middleware import PIIRedactor, redact_pii
# Use in logging
logger.info(redact_pii(f"User {email} logged in from {ip}"))
# Configure redactor
redactor = PIIRedactor(patterns=["email", "ip", "phone"])
safe_message = redactor.redact(sensitive_message)
"""
import re
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Pattern, Set
@dataclass
class PIIPattern:
"""Definition of a PII pattern."""
name: str
pattern: Pattern
replacement: str
# Pre-compiled regex patterns for common PII
PII_PATTERNS: Dict[str, PIIPattern] = {
"email": PIIPattern(
name="email",
pattern=re.compile(
r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
re.IGNORECASE
),
replacement="[EMAIL_REDACTED]",
),
"ip_v4": PIIPattern(
name="ip_v4",
pattern=re.compile(
r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b'
),
replacement="[IP_REDACTED]",
),
"ip_v6": PIIPattern(
name="ip_v6",
pattern=re.compile(
r'\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b'
),
replacement="[IP_REDACTED]",
),
"phone_de": PIIPattern(
name="phone_de",
pattern=re.compile(
r'(?<!\w)(?:\+49|0049|0)[\s.-]?(?:\d{2,4})[\s.-]?(?:\d{3,4})[\s.-]?(?:\d{3,4})(?!\d)'
),
replacement="[PHONE_REDACTED]",
),
"phone_intl": PIIPattern(
name="phone_intl",
pattern=re.compile(
r'(?<!\w)\+?(?:\d[\s.-]?){10,15}(?!\d)'
),
replacement="[PHONE_REDACTED]",
),
"credit_card": PIIPattern(
name="credit_card",
pattern=re.compile(
r'\b(?:\d{4}[\s.-]?){3}\d{4}\b'
),
replacement="[CC_REDACTED]",
),
"iban": PIIPattern(
name="iban",
pattern=re.compile(
r'\b[A-Z]{2}\d{2}[\s]?(?:\d{4}[\s]?){3,5}\d{1,4}\b',
re.IGNORECASE
),
replacement="[IBAN_REDACTED]",
),
"student_id": PIIPattern(
name="student_id",
pattern=re.compile(
r'\b(?:student|schueler|schüler)[-_]?(?:id|nr)?[:\s]?\d{4,10}\b',
re.IGNORECASE
),
replacement="[STUDENT_ID_REDACTED]",
),
"uuid": PIIPattern(
name="uuid",
pattern=re.compile(
r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b',
re.IGNORECASE
),
replacement="[UUID_REDACTED]",
),
# German names are harder to detect, but we can catch common patterns
"name_prefix": PIIPattern(
name="name_prefix",
pattern=re.compile(
r'\b(?:Herr|Frau|Hr\.|Fr\.)\s+[A-ZÄÖÜ][a-zäöüß]+(?:\s+[A-ZÄÖÜ][a-zäöüß]+)?\b'
),
replacement="[NAME_REDACTED]",
),
}
# Default patterns to enable
DEFAULT_PATTERNS = ["email", "ip_v4", "ip_v6", "phone_de"]
class PIIRedactor:
"""
Redacts PII from strings.
Attributes:
patterns: List of pattern names to use (e.g., ["email", "ip_v4"])
custom_patterns: Additional custom patterns
"""
def __init__(
self,
patterns: Optional[List[str]] = None,
custom_patterns: Optional[List[PIIPattern]] = None,
preserve_format: bool = False,
):
"""
Initialize the PII redactor.
Args:
patterns: List of pattern names to enable (default: email, ip_v4, ip_v6, phone_de)
custom_patterns: Additional custom PIIPattern objects
preserve_format: If True, preserve the length of redacted content
"""
self.patterns = patterns or DEFAULT_PATTERNS
self.custom_patterns = custom_patterns or []
self.preserve_format = preserve_format
# Build active patterns list
self._active_patterns: List[PIIPattern] = []
for pattern_name in self.patterns:
if pattern_name in PII_PATTERNS:
self._active_patterns.append(PII_PATTERNS[pattern_name])
# Add custom patterns
self._active_patterns.extend(self.custom_patterns)
def redact(self, text: str) -> str:
"""
Redact PII from the given text.
Args:
text: The text to redact PII from
Returns:
Text with PII replaced by redaction markers
"""
if not text:
return text
result = text
for pattern in self._active_patterns:
if self.preserve_format:
# Replace with same-length placeholder
def replace_preserve(match):
length = len(match.group())
return "*" * length
result = pattern.pattern.sub(replace_preserve, result)
else:
result = pattern.pattern.sub(pattern.replacement, result)
return result
def contains_pii(self, text: str) -> bool:
"""
Check if text contains any PII.
Args:
text: The text to check
Returns:
True if PII is detected
"""
if not text:
return False
for pattern in self._active_patterns:
if pattern.pattern.search(text):
return True
return False
def find_pii(self, text: str) -> List[Dict[str, str]]:
"""
Find all PII in text with their types.
Args:
text: The text to search
Returns:
List of dicts with 'type' and 'match' keys
"""
if not text:
return []
findings = []
for pattern in self._active_patterns:
for match in pattern.pattern.finditer(text):
findings.append({
"type": pattern.name,
"match": match.group(),
"start": match.start(),
"end": match.end(),
})
return findings
# Module-level default redactor instance
_default_redactor: Optional[PIIRedactor] = None
def get_default_redactor() -> PIIRedactor:
"""Get or create the default redactor instance."""
global _default_redactor
if _default_redactor is None:
_default_redactor = PIIRedactor()
return _default_redactor
def redact_pii(text: str) -> str:
"""
Convenience function to redact PII using the default redactor.
Args:
text: Text to redact
Returns:
Redacted text
Example:
logger.info(redact_pii(f"User {email} logged in"))
"""
return get_default_redactor().redact(text)
class PIIRedactingLogFilter:
"""
Logging filter that automatically redacts PII from log messages.
Usage:
import logging
handler = logging.StreamHandler()
handler.addFilter(PIIRedactingLogFilter())
logger = logging.getLogger()
logger.addHandler(handler)
"""
def __init__(self, redactor: Optional[PIIRedactor] = None):
self.redactor = redactor or get_default_redactor()
def filter(self, record):
# Redact the message
if record.msg:
record.msg = self.redactor.redact(str(record.msg))
# Redact args if present
if record.args:
if isinstance(record.args, dict):
record.args = {
k: self.redactor.redact(str(v)) if isinstance(v, str) else v
for k, v in record.args.items()
}
elif isinstance(record.args, tuple):
record.args = tuple(
self.redactor.redact(str(v)) if isinstance(v, str) else v
for v in record.args
)
return True
def create_safe_dict(data: dict, redactor: Optional[PIIRedactor] = None) -> dict:
"""
Create a copy of a dictionary with PII redacted.
Args:
data: Dictionary to redact
redactor: Optional custom redactor
Returns:
New dictionary with redacted values
"""
r = redactor or get_default_redactor()
def redact_value(value):
if isinstance(value, str):
return r.redact(value)
elif isinstance(value, dict):
return create_safe_dict(value, r)
elif isinstance(value, list):
return [redact_value(v) for v in value]
return value
return {k: redact_value(v) for k, v in data.items()}
+363
View File
@@ -0,0 +1,363 @@
"""
Rate Limiter Middleware
Implements distributed rate limiting using Valkey (Redis-fork).
Supports IP-based, user-based, and endpoint-specific rate limits.
Features:
- Sliding window rate limiting
- IP-based limits for unauthenticated requests
- User-based limits for authenticated requests
- Stricter limits for auth endpoints (anti-brute-force)
- IP whitelist/blacklist support
- Graceful fallback when Valkey is unavailable
Usage:
from middleware import RateLimiterMiddleware
app.add_middleware(
RateLimiterMiddleware,
valkey_url="redis://localhost:6379",
ip_limit=100,
user_limit=500,
)
"""
from __future__ import annotations
import asyncio
import hashlib
import os
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
# Try to import redis (valkey-compatible)
try:
import redis.asyncio as redis
REDIS_AVAILABLE = True
except ImportError:
REDIS_AVAILABLE = False
redis = None
@dataclass
class RateLimitConfig:
"""Configuration for rate limiting."""
# Valkey/Redis connection
valkey_url: str = "redis://localhost:6379"
# Default limits (requests per minute)
ip_limit: int = 100
user_limit: int = 500
# Stricter limits for auth endpoints
auth_limit: int = 20
auth_endpoints: List[str] = field(default_factory=lambda: [
"/api/auth/login",
"/api/auth/register",
"/api/auth/password-reset",
"/api/auth/forgot-password",
])
# Window size in seconds
window_size: int = 60
# IP whitelist (never rate limited)
ip_whitelist: Set[str] = field(default_factory=lambda: {
"127.0.0.1",
"::1",
})
# IP blacklist (always blocked)
ip_blacklist: Set[str] = field(default_factory=set)
# Skip internal Docker network
skip_internal_network: bool = True
# Excluded paths
excluded_paths: List[str] = field(default_factory=lambda: [
"/health",
"/metrics",
"/api/health",
])
# Fallback to in-memory when Valkey is unavailable
fallback_enabled: bool = True
# Key prefix for rate limit keys
key_prefix: str = "ratelimit"
class InMemoryRateLimiter:
"""Fallback in-memory rate limiter when Valkey is unavailable."""
def __init__(self):
self._counts: Dict[str, List[float]] = {}
self._lock = asyncio.Lock()
async def check_rate_limit(self, key: str, limit: int, window: int) -> tuple[bool, int]:
"""
Check if rate limit is exceeded.
Returns:
Tuple of (is_allowed, remaining_requests)
"""
async with self._lock:
now = time.time()
window_start = now - window
# Get or create entry
if key not in self._counts:
self._counts[key] = []
# Remove old entries
self._counts[key] = [t for t in self._counts[key] if t > window_start]
# Check limit
current_count = len(self._counts[key])
if current_count >= limit:
return False, 0
# Add new request
self._counts[key].append(now)
return True, limit - current_count - 1
async def cleanup(self):
"""Remove expired entries."""
async with self._lock:
now = time.time()
for key in list(self._counts.keys()):
self._counts[key] = [t for t in self._counts[key] if t > now - 3600]
if not self._counts[key]:
del self._counts[key]
class RateLimiterMiddleware(BaseHTTPMiddleware):
"""
Middleware that implements distributed rate limiting.
Uses Valkey (Redis-fork) for distributed state, with fallback
to in-memory rate limiting when Valkey is unavailable.
"""
def __init__(
self,
app,
config: Optional[RateLimitConfig] = None,
# Individual overrides
valkey_url: Optional[str] = None,
ip_limit: Optional[int] = None,
user_limit: Optional[int] = None,
auth_limit: Optional[int] = None,
):
super().__init__(app)
self.config = config or RateLimitConfig()
# Apply overrides
if valkey_url is not None:
self.config.valkey_url = valkey_url
if ip_limit is not None:
self.config.ip_limit = ip_limit
if user_limit is not None:
self.config.user_limit = user_limit
if auth_limit is not None:
self.config.auth_limit = auth_limit
# Auto-configure from environment
self.config.valkey_url = os.getenv("VALKEY_URL", self.config.valkey_url)
# Initialize Valkey client
self._redis: Optional[redis.Redis] = None
self._fallback = InMemoryRateLimiter()
self._valkey_available = False
async def _get_redis(self) -> Optional[redis.Redis]:
"""Get or create Redis/Valkey connection."""
if not REDIS_AVAILABLE:
return None
if self._redis is None:
try:
self._redis = redis.from_url(
self.config.valkey_url,
decode_responses=True,
socket_timeout=1.0,
socket_connect_timeout=1.0,
)
await self._redis.ping()
self._valkey_available = True
except Exception:
self._valkey_available = False
self._redis = None
return self._redis
def _get_client_ip(self, request: Request) -> str:
"""Extract client IP from request."""
# Check X-Forwarded-For header
xff = request.headers.get("X-Forwarded-For")
if xff:
return xff.split(",")[0].strip()
# Check X-Real-IP header
xri = request.headers.get("X-Real-IP")
if xri:
return xri
# Fall back to direct client IP
if request.client:
return request.client.host
return "unknown"
def _get_user_id(self, request: Request) -> Optional[str]:
"""Extract user ID from request state (set by session middleware)."""
if hasattr(request.state, "session") and request.state.session:
return getattr(request.state.session, "user_id", None)
return None
def _is_internal_network(self, ip: str) -> bool:
"""Check if IP is from internal Docker network."""
return (
ip.startswith("172.") or
ip.startswith("10.") or
ip.startswith("192.168.")
)
def _get_rate_limit(self, request: Request) -> int:
"""Determine the rate limit for this request."""
path = request.url.path
# Auth endpoints get stricter limits
for auth_path in self.config.auth_endpoints:
if path.startswith(auth_path):
return self.config.auth_limit
# Authenticated users get higher limits
if self._get_user_id(request):
return self.config.user_limit
# Default IP-based limit
return self.config.ip_limit
def _get_rate_limit_key(self, request: Request) -> str:
"""Generate the rate limit key for this request."""
# Use user ID if authenticated
user_id = self._get_user_id(request)
if user_id:
identifier = f"user:{user_id}"
else:
ip = self._get_client_ip(request)
# Hash IP for privacy
ip_hash = hashlib.sha256(ip.encode()).hexdigest()[:16]
identifier = f"ip:{ip_hash}"
# Include path for endpoint-specific limits
path = request.url.path
for auth_path in self.config.auth_endpoints:
if path.startswith(auth_path):
return f"{self.config.key_prefix}:auth:{identifier}"
return f"{self.config.key_prefix}:{identifier}"
async def _check_rate_limit_valkey(
self, key: str, limit: int, window: int
) -> tuple[bool, int]:
"""Check rate limit using Valkey."""
r = await self._get_redis()
if not r:
return await self._fallback.check_rate_limit(key, limit, window)
try:
# Use sliding window with sorted set
now = time.time()
window_start = now - window
pipe = r.pipeline()
# Remove old entries
pipe.zremrangebyscore(key, "-inf", window_start)
# Count current entries
pipe.zcard(key)
# Add new entry
pipe.zadd(key, {str(now): now})
# Set expiry
pipe.expire(key, window + 10)
results = await pipe.execute()
current_count = results[1]
if current_count >= limit:
return False, 0
return True, limit - current_count - 1
except Exception:
# Fallback to in-memory
self._valkey_available = False
return await self._fallback.check_rate_limit(key, limit, window)
async def dispatch(self, request: Request, call_next) -> Response:
# Skip excluded paths
if request.url.path in self.config.excluded_paths:
return await call_next(request)
# Get client IP
ip = self._get_client_ip(request)
# Check blacklist
if ip in self.config.ip_blacklist:
return JSONResponse(
status_code=403,
content={
"error": "ip_blocked",
"message": "Your IP address has been blocked.",
},
)
# Skip whitelist
if ip in self.config.ip_whitelist:
return await call_next(request)
# Skip internal network
if self.config.skip_internal_network and self._is_internal_network(ip):
return await call_next(request)
# Get rate limit parameters
limit = self._get_rate_limit(request)
key = self._get_rate_limit_key(request)
window = self.config.window_size
# Check rate limit
allowed, remaining = await self._check_rate_limit_valkey(key, limit, window)
if not allowed:
return JSONResponse(
status_code=429,
content={
"error": "rate_limit_exceeded",
"message": "Too many requests. Please try again later.",
"retry_after": window,
},
headers={
"Retry-After": str(window),
"X-RateLimit-Limit": str(limit),
"X-RateLimit-Remaining": "0",
"X-RateLimit-Reset": str(int(time.time()) + window),
},
)
# Process request
response = await call_next(request)
# Add rate limit headers
response.headers["X-RateLimit-Limit"] = str(limit)
response.headers["X-RateLimit-Remaining"] = str(remaining)
response.headers["X-RateLimit-Reset"] = str(int(time.time()) + window)
return response
+138
View File
@@ -0,0 +1,138 @@
"""
Request-ID Middleware
Generates and propagates unique request identifiers for distributed tracing.
Supports both X-Request-ID and X-Correlation-ID headers.
Usage:
from middleware import RequestIDMiddleware, get_request_id
app.add_middleware(RequestIDMiddleware)
@app.get("/api/example")
async def example():
request_id = get_request_id()
logger.info(f"Processing request", extra={"request_id": request_id})
"""
import uuid
from contextvars import ContextVar
from typing import Optional
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
# Context variable to store request ID across async calls
_request_id_ctx: ContextVar[Optional[str]] = ContextVar("request_id", default=None)
# Header names
REQUEST_ID_HEADER = "X-Request-ID"
CORRELATION_ID_HEADER = "X-Correlation-ID"
def get_request_id() -> Optional[str]:
"""
Get the current request ID from context.
Returns:
The request ID string or None if not in a request context.
Example:
request_id = get_request_id()
logger.info("Processing", extra={"request_id": request_id})
"""
return _request_id_ctx.get()
def set_request_id(request_id: str) -> None:
"""
Set the request ID in the current context.
Args:
request_id: The request ID to set
"""
_request_id_ctx.set(request_id)
def generate_request_id() -> str:
"""
Generate a new unique request ID.
Returns:
A UUID4 string
"""
return str(uuid.uuid4())
class RequestIDMiddleware(BaseHTTPMiddleware):
"""
Middleware that generates and propagates request IDs.
For each incoming request:
1. Check for existing X-Request-ID or X-Correlation-ID header
2. If not present, generate a new UUID
3. Store in context for use by handlers and logging
4. Add to response headers
Attributes:
header_name: The primary header name to use (default: X-Request-ID)
generator: Function to generate new IDs (default: uuid4)
"""
def __init__(
self,
app,
header_name: str = REQUEST_ID_HEADER,
generator=generate_request_id,
):
super().__init__(app)
self.header_name = header_name
self.generator = generator
async def dispatch(self, request: Request, call_next) -> Response:
# Try to get existing request ID from headers
request_id = (
request.headers.get(REQUEST_ID_HEADER)
or request.headers.get(CORRELATION_ID_HEADER)
)
# Generate new ID if not provided
if not request_id:
request_id = self.generator()
# Store in context for logging and handlers
set_request_id(request_id)
# Store in request state for direct access
request.state.request_id = request_id
# Process request
response = await call_next(request)
# Add request ID to response headers
response.headers[REQUEST_ID_HEADER] = request_id
response.headers[CORRELATION_ID_HEADER] = request_id
return response
class RequestIDLogFilter:
"""
Logging filter that adds request_id to log records.
Usage:
import logging
handler = logging.StreamHandler()
handler.addFilter(RequestIDLogFilter())
formatter = logging.Formatter(
'%(asctime)s [%(request_id)s] %(levelname)s %(message)s'
)
handler.setFormatter(formatter)
"""
def filter(self, record):
record.request_id = get_request_id() or "no-request-id"
return True
+202
View File
@@ -0,0 +1,202 @@
"""
Security Headers Middleware
Adds security headers to all HTTP responses to protect against common attacks.
Headers added:
- X-Content-Type-Options: nosniff
- X-Frame-Options: DENY
- X-XSS-Protection: 1; mode=block
- Strict-Transport-Security (HSTS)
- Content-Security-Policy
- Referrer-Policy
- Permissions-Policy
Usage:
from middleware import SecurityHeadersMiddleware
app.add_middleware(SecurityHeadersMiddleware)
# Or with custom configuration:
app.add_middleware(
SecurityHeadersMiddleware,
hsts_enabled=True,
csp_policy="default-src 'self'",
)
"""
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
@dataclass
class SecurityHeadersConfig:
"""Configuration for security headers."""
# X-Content-Type-Options
content_type_options: str = "nosniff"
# X-Frame-Options
frame_options: str = "DENY"
# X-XSS-Protection (legacy, but still useful for older browsers)
xss_protection: str = "1; mode=block"
# Strict-Transport-Security
hsts_enabled: bool = True
hsts_max_age: int = 31536000 # 1 year
hsts_include_subdomains: bool = True
hsts_preload: bool = False
# Content-Security-Policy
csp_enabled: bool = True
csp_policy: str = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self' https:; frame-ancestors 'none'"
# Referrer-Policy
referrer_policy: str = "strict-origin-when-cross-origin"
# Permissions-Policy (formerly Feature-Policy)
permissions_policy: str = "geolocation=(), microphone=(), camera=()"
# Cross-Origin headers
cross_origin_opener_policy: str = "same-origin"
cross_origin_embedder_policy: str = "require-corp"
cross_origin_resource_policy: str = "same-origin"
# Development mode (relaxes some restrictions)
development_mode: bool = False
# Excluded paths (e.g., for health checks)
excluded_paths: List[str] = field(default_factory=lambda: ["/health", "/metrics"])
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"""
Middleware that adds security headers to all responses.
Attributes:
config: SecurityHeadersConfig instance
"""
def __init__(
self,
app,
config: Optional[SecurityHeadersConfig] = None,
# Individual overrides for convenience
hsts_enabled: Optional[bool] = None,
csp_policy: Optional[str] = None,
csp_enabled: Optional[bool] = None,
development_mode: Optional[bool] = None,
):
super().__init__(app)
# Use provided config or create default
self.config = config or SecurityHeadersConfig()
# Apply individual overrides
if hsts_enabled is not None:
self.config.hsts_enabled = hsts_enabled
if csp_policy is not None:
self.config.csp_policy = csp_policy
if csp_enabled is not None:
self.config.csp_enabled = csp_enabled
if development_mode is not None:
self.config.development_mode = development_mode
# Auto-detect development mode from environment
if development_mode is None:
env = os.getenv("ENVIRONMENT", "development")
self.config.development_mode = env.lower() in ("development", "dev", "local")
def _build_hsts_header(self) -> str:
"""Build the Strict-Transport-Security header value."""
parts = [f"max-age={self.config.hsts_max_age}"]
if self.config.hsts_include_subdomains:
parts.append("includeSubDomains")
if self.config.hsts_preload:
parts.append("preload")
return "; ".join(parts)
def _get_headers(self) -> Dict[str, str]:
"""Build the security headers dictionary."""
headers = {}
# Always add these headers
headers["X-Content-Type-Options"] = self.config.content_type_options
headers["X-Frame-Options"] = self.config.frame_options
headers["X-XSS-Protection"] = self.config.xss_protection
headers["Referrer-Policy"] = self.config.referrer_policy
# HSTS (only in production or if explicitly enabled)
if self.config.hsts_enabled and not self.config.development_mode:
headers["Strict-Transport-Security"] = self._build_hsts_header()
# Content-Security-Policy
if self.config.csp_enabled:
headers["Content-Security-Policy"] = self.config.csp_policy
# Permissions-Policy
if self.config.permissions_policy:
headers["Permissions-Policy"] = self.config.permissions_policy
# Cross-Origin headers (relaxed in development)
if not self.config.development_mode:
headers["Cross-Origin-Opener-Policy"] = self.config.cross_origin_opener_policy
# Note: COEP can break loading of external resources, be careful
# headers["Cross-Origin-Embedder-Policy"] = self.config.cross_origin_embedder_policy
headers["Cross-Origin-Resource-Policy"] = self.config.cross_origin_resource_policy
return headers
async def dispatch(self, request: Request, call_next) -> Response:
# Skip security headers for excluded paths
if request.url.path in self.config.excluded_paths:
return await call_next(request)
# Process request
response = await call_next(request)
# Add security headers
for header_name, header_value in self._get_headers().items():
response.headers[header_name] = header_value
return response
def get_default_csp_for_environment(environment: str) -> str:
"""
Get a sensible default CSP for the given environment.
Args:
environment: "development", "staging", or "production"
Returns:
CSP policy string
"""
if environment.lower() in ("development", "dev", "local"):
# Relaxed CSP for development
return (
"default-src 'self' localhost:* ws://localhost:*; "
"script-src 'self' 'unsafe-inline' 'unsafe-eval'; "
"style-src 'self' 'unsafe-inline'; "
"img-src 'self' data: https: blob:; "
"font-src 'self' data:; "
"connect-src 'self' localhost:* ws://localhost:* https:; "
"frame-ancestors 'self'"
)
else:
# Strict CSP for production
return (
"default-src 'self'; "
"script-src 'self' 'unsafe-inline'; "
"style-src 'self' 'unsafe-inline'; "
"img-src 'self' data: https:; "
"font-src 'self' data:; "
"connect-src 'self' https://breakpilot.app https://*.breakpilot.app; "
"frame-ancestors 'none'"
)
+142
View File
@@ -0,0 +1,142 @@
"""
Notification API - Proxy zu Go Consent Service für Benachrichtigungen
"""
from fastapi import APIRouter, HTTPException, Header, Query
from typing import Optional
import httpx
router = APIRouter(prefix="/v1/notifications", tags=["Notifications"])
CONSENT_SERVICE_URL = "http://localhost:8081"
async def proxy_request(
method: str,
path: str,
authorization: Optional[str] = None,
json_data: dict = None,
params: dict = None
):
"""Proxy request to Go consent service."""
headers = {}
if authorization:
headers["Authorization"] = authorization
async with httpx.AsyncClient() as client:
try:
response = await client.request(
method,
f"{CONSENT_SERVICE_URL}{path}",
headers=headers,
json=json_data,
params=params,
timeout=30.0
)
if response.status_code >= 400:
raise HTTPException(
status_code=response.status_code,
detail=response.json().get("error", "Request failed")
)
return response.json()
except httpx.RequestError as e:
raise HTTPException(status_code=503, detail=f"Consent service unavailable: {str(e)}")
@router.get("")
async def get_notifications(
limit: int = Query(20, ge=1, le=100),
offset: int = Query(0, ge=0),
unread_only: bool = Query(False),
authorization: Optional[str] = Header(None)
):
"""Holt alle Benachrichtigungen des aktuellen Benutzers."""
params = {
"limit": limit,
"offset": offset,
"unread_only": str(unread_only).lower()
}
return await proxy_request(
"GET",
"/api/v1/notifications",
authorization=authorization,
params=params
)
@router.get("/unread-count")
async def get_unread_count(
authorization: Optional[str] = Header(None)
):
"""Gibt die Anzahl ungelesener Benachrichtigungen zurück."""
return await proxy_request(
"GET",
"/api/v1/notifications/unread-count",
authorization=authorization
)
@router.put("/{notification_id}/read")
async def mark_as_read(
notification_id: str,
authorization: Optional[str] = Header(None)
):
"""Markiert eine Benachrichtigung als gelesen."""
return await proxy_request(
"PUT",
f"/api/v1/notifications/{notification_id}/read",
authorization=authorization
)
@router.put("/read-all")
async def mark_all_as_read(
authorization: Optional[str] = Header(None)
):
"""Markiert alle Benachrichtigungen als gelesen."""
return await proxy_request(
"PUT",
"/api/v1/notifications/read-all",
authorization=authorization
)
@router.delete("/{notification_id}")
async def delete_notification(
notification_id: str,
authorization: Optional[str] = Header(None)
):
"""Löscht eine Benachrichtigung."""
return await proxy_request(
"DELETE",
f"/api/v1/notifications/{notification_id}",
authorization=authorization
)
@router.get("/preferences")
async def get_preferences(
authorization: Optional[str] = Header(None)
):
"""Holt die Benachrichtigungseinstellungen des Benutzers."""
return await proxy_request(
"GET",
"/api/v1/notifications/preferences",
authorization=authorization
)
@router.put("/preferences")
async def update_preferences(
preferences: dict,
authorization: Optional[str] = Header(None)
):
"""Aktualisiert die Benachrichtigungseinstellungen."""
return await proxy_request(
"PUT",
"/api/v1/notifications/preferences",
authorization=authorization,
json_data=preferences
)
+819
View File
@@ -0,0 +1,819 @@
"""
RBAC API - Teacher and Role Management Endpoints
Provides API endpoints for:
- Listing all teachers
- Listing all available roles
- Assigning/revoking roles to teachers
- Viewing role assignments per teacher
Architecture:
- Authentication: Keycloak (when configured) or local JWT
- Authorization: Custom rbac.py for fine-grained permissions
"""
import os
import asyncpg
from datetime import datetime, timezone
from typing import Optional, List, Dict, Any
from fastapi import APIRouter, HTTPException, Depends, Request
from pydantic import BaseModel
# Import hybrid auth module
try:
from auth import get_current_user, TokenExpiredError, TokenInvalidError
except ImportError:
# Fallback for standalone testing
from auth.keycloak_auth import get_current_user, TokenExpiredError, TokenInvalidError
# Configuration from environment - NO DEFAULT SECRETS
ENVIRONMENT = os.environ.get("ENVIRONMENT", "development")
router = APIRouter(prefix="/rbac", tags=["rbac"])
# Connection pool
_pool: Optional[asyncpg.Pool] = None
def _get_database_url() -> str:
"""Get DATABASE_URL from environment, raising error if not set."""
url = os.environ.get("DATABASE_URL")
if not url:
raise RuntimeError("DATABASE_URL nicht konfiguriert - bitte via Vault oder Umgebungsvariable setzen")
return url
async def get_pool() -> asyncpg.Pool:
"""Get or create database connection pool"""
global _pool
if _pool is None:
database_url = _get_database_url()
_pool = await asyncpg.create_pool(database_url, min_size=2, max_size=10)
return _pool
async def close_pool():
"""Close database connection pool"""
global _pool
if _pool:
await _pool.close()
_pool = None
# Pydantic Models
class RoleAssignmentCreate(BaseModel):
user_id: str
role: str
resource_type: str = "tenant"
resource_id: str
valid_to: Optional[str] = None
class RoleAssignmentRevoke(BaseModel):
assignment_id: str
class TeacherCreate(BaseModel):
email: str
first_name: str
last_name: str
teacher_code: Optional[str] = None
title: Optional[str] = None
roles: List[str] = []
class TeacherUpdate(BaseModel):
email: Optional[str] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
teacher_code: Optional[str] = None
title: Optional[str] = None
is_active: Optional[bool] = None
class CustomRoleCreate(BaseModel):
role_key: str
display_name: str
description: str
category: str
class CustomRoleUpdate(BaseModel):
display_name: Optional[str] = None
description: Optional[str] = None
category: Optional[str] = None
class TeacherResponse(BaseModel):
id: str
user_id: str
email: str
name: str
teacher_code: Optional[str]
title: Optional[str]
first_name: str
last_name: str
is_active: bool
roles: List[str]
class RoleInfo(BaseModel):
role: str
display_name: str
description: str
category: str
class RoleAssignmentResponse(BaseModel):
id: str
user_id: str
role: str
resource_type: str
resource_id: str
valid_from: str
valid_to: Optional[str]
granted_at: str
is_active: bool
# Role definitions with German display names
AVAILABLE_ROLES = {
# Klausur-Korrekturkette
"erstkorrektor": {
"display_name": "Erstkorrektor",
"description": "Fuehrt die erste Korrektur der Klausur durch",
"category": "klausur"
},
"zweitkorrektor": {
"display_name": "Zweitkorrektor",
"description": "Fuehrt die zweite Korrektur der Klausur durch",
"category": "klausur"
},
"drittkorrektor": {
"display_name": "Drittkorrektor",
"description": "Fuehrt die dritte Korrektur bei Notenabweichung durch",
"category": "klausur"
},
# Zeugnis-Workflow
"klassenlehrer": {
"display_name": "Klassenlehrer/in",
"description": "Erstellt Zeugnisse, traegt Kopfnoten und Bemerkungen ein",
"category": "zeugnis"
},
"fachlehrer": {
"display_name": "Fachlehrer/in",
"description": "Traegt Fachnoten ein",
"category": "zeugnis"
},
"zeugnisbeauftragter": {
"display_name": "Zeugnisbeauftragte/r",
"description": "Qualitaetskontrolle und Freigabe von Zeugnissen",
"category": "zeugnis"
},
"sekretariat": {
"display_name": "Sekretariat",
"description": "Druck, Versand und Archivierung von Dokumenten",
"category": "verwaltung"
},
# Leitung
"fachvorsitz": {
"display_name": "Fachvorsitz",
"description": "Fachpruefungsleitung und Qualitaetssicherung",
"category": "leitung"
},
"pruefungsvorsitz": {
"display_name": "Pruefungsvorsitz",
"description": "Pruefungsleitung und finale Freigabe",
"category": "leitung"
},
"schulleitung": {
"display_name": "Schulleitung",
"description": "Finale Freigabe und Unterschrift",
"category": "leitung"
},
"stufenleitung": {
"display_name": "Stufenleitung",
"description": "Koordination einer Jahrgangsstufe",
"category": "leitung"
},
# Administration
"schul_admin": {
"display_name": "Schul-Administrator",
"description": "Technische Administration der Schule",
"category": "admin"
},
"teacher_assistant": {
"display_name": "Referendar/in",
"description": "Lehrkraft in Ausbildung mit eingeschraenkten Rechten",
"category": "other"
},
}
# Note: get_user_from_token is replaced by the imported get_current_user dependency
# from auth module which supports both Keycloak and local JWT authentication
# API Endpoints
@router.get("/roles")
async def list_available_roles() -> List[RoleInfo]:
"""List all available roles with their descriptions"""
return [
RoleInfo(
role=role_key,
display_name=role_data["display_name"],
description=role_data["description"],
category=role_data["category"]
)
for role_key, role_data in AVAILABLE_ROLES.items()
]
@router.get("/teachers")
async def list_teachers(user: Dict[str, Any] = Depends(get_current_user)) -> List[TeacherResponse]:
"""List all teachers with their current roles"""
pool = await get_pool()
async with pool.acquire() as conn:
# Get all teachers with their user info
teachers = await conn.fetch("""
SELECT
t.id, t.user_id, t.teacher_code, t.title,
t.first_name, t.last_name, t.is_active,
u.email, u.name
FROM teachers t
JOIN users u ON t.user_id = u.id
WHERE t.school_id = 'a0000000-0000-0000-0000-000000000001'
ORDER BY t.last_name, t.first_name
""")
# Get role assignments for all teachers
role_assignments = await conn.fetch("""
SELECT user_id, role
FROM role_assignments
WHERE tenant_id = 'a0000000-0000-0000-0000-000000000001'
AND revoked_at IS NULL
AND (valid_to IS NULL OR valid_to > NOW())
""")
# Build role lookup
role_lookup: Dict[str, List[str]] = {}
for ra in role_assignments:
uid = str(ra["user_id"])
if uid not in role_lookup:
role_lookup[uid] = []
role_lookup[uid].append(ra["role"])
# Build response
result = []
for t in teachers:
uid = str(t["user_id"])
result.append(TeacherResponse(
id=str(t["id"]),
user_id=uid,
email=t["email"],
name=t["name"] or f"{t['first_name']} {t['last_name']}",
teacher_code=t["teacher_code"],
title=t["title"],
first_name=t["first_name"],
last_name=t["last_name"],
is_active=t["is_active"],
roles=role_lookup.get(uid, [])
))
return result
@router.get("/teachers/{teacher_id}/roles")
async def get_teacher_roles(teacher_id: str, user: Dict[str, Any] = Depends(get_current_user)) -> List[RoleAssignmentResponse]:
"""Get all role assignments for a specific teacher"""
pool = await get_pool()
async with pool.acquire() as conn:
# Get teacher's user_id
teacher = await conn.fetchrow(
"SELECT user_id FROM teachers WHERE id = $1",
teacher_id
)
if not teacher:
raise HTTPException(status_code=404, detail="Teacher not found")
# Get role assignments
assignments = await conn.fetch("""
SELECT id, user_id, role, resource_type, resource_id,
valid_from, valid_to, granted_at, revoked_at
FROM role_assignments
WHERE user_id = $1
ORDER BY granted_at DESC
""", teacher["user_id"])
return [
RoleAssignmentResponse(
id=str(a["id"]),
user_id=str(a["user_id"]),
role=a["role"],
resource_type=a["resource_type"],
resource_id=str(a["resource_id"]),
valid_from=a["valid_from"].isoformat() if a["valid_from"] else None,
valid_to=a["valid_to"].isoformat() if a["valid_to"] else None,
granted_at=a["granted_at"].isoformat() if a["granted_at"] else None,
is_active=a["revoked_at"] is None and (
a["valid_to"] is None or a["valid_to"] > datetime.now(timezone.utc)
)
)
for a in assignments
]
@router.get("/roles/{role}/teachers")
async def get_teachers_by_role(role: str, user: Dict[str, Any] = Depends(get_current_user)) -> List[TeacherResponse]:
"""Get all teachers with a specific role"""
if role not in AVAILABLE_ROLES:
raise HTTPException(status_code=400, detail=f"Unknown role: {role}")
pool = await get_pool()
async with pool.acquire() as conn:
teachers = await conn.fetch("""
SELECT DISTINCT
t.id, t.user_id, t.teacher_code, t.title,
t.first_name, t.last_name, t.is_active,
u.email, u.name
FROM teachers t
JOIN users u ON t.user_id = u.id
JOIN role_assignments ra ON t.user_id = ra.user_id
WHERE ra.role = $1
AND ra.revoked_at IS NULL
AND (ra.valid_to IS NULL OR ra.valid_to > NOW())
AND t.school_id = 'a0000000-0000-0000-0000-000000000001'
ORDER BY t.last_name, t.first_name
""", role)
# Get all roles for these teachers
if teachers:
user_ids = [t["user_id"] for t in teachers]
role_assignments = await conn.fetch("""
SELECT user_id, role
FROM role_assignments
WHERE user_id = ANY($1)
AND revoked_at IS NULL
AND (valid_to IS NULL OR valid_to > NOW())
""", user_ids)
role_lookup: Dict[str, List[str]] = {}
for ra in role_assignments:
uid = str(ra["user_id"])
if uid not in role_lookup:
role_lookup[uid] = []
role_lookup[uid].append(ra["role"])
else:
role_lookup = {}
return [
TeacherResponse(
id=str(t["id"]),
user_id=str(t["user_id"]),
email=t["email"],
name=t["name"] or f"{t['first_name']} {t['last_name']}",
teacher_code=t["teacher_code"],
title=t["title"],
first_name=t["first_name"],
last_name=t["last_name"],
is_active=t["is_active"],
roles=role_lookup.get(str(t["user_id"]), [])
)
for t in teachers
]
@router.post("/assignments")
async def assign_role(assignment: RoleAssignmentCreate, user: Dict[str, Any] = Depends(get_current_user)) -> RoleAssignmentResponse:
"""Assign a role to a user"""
if assignment.role not in AVAILABLE_ROLES:
raise HTTPException(status_code=400, detail=f"Unknown role: {assignment.role}")
pool = await get_pool()
async with pool.acquire() as conn:
# Check if assignment already exists
existing = await conn.fetchrow("""
SELECT id FROM role_assignments
WHERE user_id = $1 AND role = $2 AND resource_id = $3
AND revoked_at IS NULL
""", assignment.user_id, assignment.role, assignment.resource_id)
if existing:
raise HTTPException(
status_code=409,
detail="Role assignment already exists"
)
# Parse valid_to if provided
valid_to = None
if assignment.valid_to:
valid_to = datetime.fromisoformat(assignment.valid_to)
# Create assignment
result = await conn.fetchrow("""
INSERT INTO role_assignments
(user_id, role, resource_type, resource_id, tenant_id, valid_to, granted_by)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id, user_id, role, resource_type, resource_id, valid_from, valid_to, granted_at
""",
assignment.user_id,
assignment.role,
assignment.resource_type,
assignment.resource_id,
assignment.resource_id, # tenant_id same as resource_id for tenant-level roles
valid_to,
user.get("user_id")
)
return RoleAssignmentResponse(
id=str(result["id"]),
user_id=str(result["user_id"]),
role=result["role"],
resource_type=result["resource_type"],
resource_id=str(result["resource_id"]),
valid_from=result["valid_from"].isoformat(),
valid_to=result["valid_to"].isoformat() if result["valid_to"] else None,
granted_at=result["granted_at"].isoformat(),
is_active=True
)
@router.delete("/assignments/{assignment_id}")
async def revoke_role(assignment_id: str, user: Dict[str, Any] = Depends(get_current_user)):
"""Revoke a role assignment"""
pool = await get_pool()
async with pool.acquire() as conn:
result = await conn.execute("""
UPDATE role_assignments
SET revoked_at = NOW()
WHERE id = $1 AND revoked_at IS NULL
""", assignment_id)
if result == "UPDATE 0":
raise HTTPException(status_code=404, detail="Assignment not found or already revoked")
return {"status": "revoked", "assignment_id": assignment_id}
@router.get("/summary")
async def get_role_summary(user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
"""Get a summary of roles and their assignment counts"""
pool = await get_pool()
async with pool.acquire() as conn:
counts = await conn.fetch("""
SELECT role, COUNT(*) as count
FROM role_assignments
WHERE tenant_id = 'a0000000-0000-0000-0000-000000000001'
AND revoked_at IS NULL
AND (valid_to IS NULL OR valid_to > NOW())
GROUP BY role
ORDER BY role
""")
total_teachers = await conn.fetchval("""
SELECT COUNT(*) FROM teachers
WHERE school_id = 'a0000000-0000-0000-0000-000000000001'
AND is_active = true
""")
role_counts = {c["role"]: c["count"] for c in counts}
# Also include custom roles from database
custom_roles = await conn.fetch("""
SELECT role_key, display_name, category
FROM custom_roles
WHERE tenant_id = 'a0000000-0000-0000-0000-000000000001'
AND is_active = true
""")
all_roles = [
{
"role": role_key,
"display_name": role_data["display_name"],
"category": role_data["category"],
"count": role_counts.get(role_key, 0),
"is_custom": False
}
for role_key, role_data in AVAILABLE_ROLES.items()
]
for cr in custom_roles:
all_roles.append({
"role": cr["role_key"],
"display_name": cr["display_name"],
"category": cr["category"],
"count": role_counts.get(cr["role_key"], 0),
"is_custom": True
})
return {
"total_teachers": total_teachers,
"roles": all_roles
}
# ==========================================
# TEACHER MANAGEMENT ENDPOINTS
# ==========================================
@router.post("/teachers")
async def create_teacher(teacher: TeacherCreate, user: Dict[str, Any] = Depends(get_current_user)) -> TeacherResponse:
"""Create a new teacher with optional initial roles"""
pool = await get_pool()
import uuid
async with pool.acquire() as conn:
# Check if email already exists
existing = await conn.fetchrow(
"SELECT id FROM users WHERE email = $1",
teacher.email
)
if existing:
raise HTTPException(status_code=409, detail="Email already exists")
# Generate UUIDs
user_id = str(uuid.uuid4())
teacher_id = str(uuid.uuid4())
# Create user first
await conn.execute("""
INSERT INTO users (id, email, name, password_hash, role, is_active)
VALUES ($1, $2, $3, '', 'teacher', true)
""", user_id, teacher.email, f"{teacher.first_name} {teacher.last_name}")
# Create teacher record
await conn.execute("""
INSERT INTO teachers (id, user_id, school_id, first_name, last_name, teacher_code, title, is_active)
VALUES ($1, $2, 'a0000000-0000-0000-0000-000000000001', $3, $4, $5, $6, true)
""", teacher_id, user_id, teacher.first_name, teacher.last_name,
teacher.teacher_code, teacher.title)
# Assign initial roles if provided
assigned_roles = []
for role in teacher.roles:
if role in AVAILABLE_ROLES or await conn.fetchrow(
"SELECT 1 FROM custom_roles WHERE role_key = $1 AND is_active = true", role
):
await conn.execute("""
INSERT INTO role_assignments (user_id, role, resource_type, resource_id, tenant_id, granted_by)
VALUES ($1, $2, 'tenant', 'a0000000-0000-0000-0000-000000000001',
'a0000000-0000-0000-0000-000000000001', $3)
""", user_id, role, user.get("user_id"))
assigned_roles.append(role)
return TeacherResponse(
id=teacher_id,
user_id=user_id,
email=teacher.email,
name=f"{teacher.first_name} {teacher.last_name}",
teacher_code=teacher.teacher_code,
title=teacher.title,
first_name=teacher.first_name,
last_name=teacher.last_name,
is_active=True,
roles=assigned_roles
)
@router.put("/teachers/{teacher_id}")
async def update_teacher(teacher_id: str, updates: TeacherUpdate, user: Dict[str, Any] = Depends(get_current_user)) -> TeacherResponse:
"""Update teacher information"""
pool = await get_pool()
async with pool.acquire() as conn:
# Get current teacher data
teacher = await conn.fetchrow("""
SELECT t.id, t.user_id, t.teacher_code, t.title, t.first_name, t.last_name, t.is_active,
u.email, u.name
FROM teachers t
JOIN users u ON t.user_id = u.id
WHERE t.id = $1
""", teacher_id)
if not teacher:
raise HTTPException(status_code=404, detail="Teacher not found")
# Build update queries
if updates.email:
await conn.execute("UPDATE users SET email = $1 WHERE id = $2",
updates.email, teacher["user_id"])
teacher_updates = []
teacher_values = []
idx = 1
if updates.first_name:
teacher_updates.append(f"first_name = ${idx}")
teacher_values.append(updates.first_name)
idx += 1
if updates.last_name:
teacher_updates.append(f"last_name = ${idx}")
teacher_values.append(updates.last_name)
idx += 1
if updates.teacher_code is not None:
teacher_updates.append(f"teacher_code = ${idx}")
teacher_values.append(updates.teacher_code)
idx += 1
if updates.title is not None:
teacher_updates.append(f"title = ${idx}")
teacher_values.append(updates.title)
idx += 1
if updates.is_active is not None:
teacher_updates.append(f"is_active = ${idx}")
teacher_values.append(updates.is_active)
idx += 1
if teacher_updates:
teacher_values.append(teacher_id)
await conn.execute(
f"UPDATE teachers SET {', '.join(teacher_updates)} WHERE id = ${idx}",
*teacher_values
)
# Update user name if first/last name changed
if updates.first_name or updates.last_name:
new_first = updates.first_name or teacher["first_name"]
new_last = updates.last_name or teacher["last_name"]
await conn.execute("UPDATE users SET name = $1 WHERE id = $2",
f"{new_first} {new_last}", teacher["user_id"])
# Fetch updated data
updated = await conn.fetchrow("""
SELECT t.id, t.user_id, t.teacher_code, t.title, t.first_name, t.last_name, t.is_active,
u.email, u.name
FROM teachers t
JOIN users u ON t.user_id = u.id
WHERE t.id = $1
""", teacher_id)
# Get roles
roles = await conn.fetch("""
SELECT role FROM role_assignments
WHERE user_id = $1 AND revoked_at IS NULL
AND (valid_to IS NULL OR valid_to > NOW())
""", updated["user_id"])
return TeacherResponse(
id=str(updated["id"]),
user_id=str(updated["user_id"]),
email=updated["email"],
name=updated["name"],
teacher_code=updated["teacher_code"],
title=updated["title"],
first_name=updated["first_name"],
last_name=updated["last_name"],
is_active=updated["is_active"],
roles=[r["role"] for r in roles]
)
@router.delete("/teachers/{teacher_id}")
async def deactivate_teacher(teacher_id: str, user: Dict[str, Any] = Depends(get_current_user)):
"""Deactivate a teacher (soft delete)"""
pool = await get_pool()
async with pool.acquire() as conn:
result = await conn.execute("""
UPDATE teachers SET is_active = false WHERE id = $1
""", teacher_id)
if result == "UPDATE 0":
raise HTTPException(status_code=404, detail="Teacher not found")
return {"status": "deactivated", "teacher_id": teacher_id}
# ==========================================
# CUSTOM ROLE MANAGEMENT ENDPOINTS
# ==========================================
@router.get("/custom-roles")
async def list_custom_roles(user: Dict[str, Any] = Depends(get_current_user)) -> List[RoleInfo]:
"""List all custom roles"""
pool = await get_pool()
async with pool.acquire() as conn:
roles = await conn.fetch("""
SELECT role_key, display_name, description, category
FROM custom_roles
WHERE tenant_id = 'a0000000-0000-0000-0000-000000000001'
AND is_active = true
ORDER BY category, display_name
""")
return [
RoleInfo(
role=r["role_key"],
display_name=r["display_name"],
description=r["description"],
category=r["category"]
)
for r in roles
]
@router.post("/custom-roles")
async def create_custom_role(role: CustomRoleCreate, user: Dict[str, Any] = Depends(get_current_user)) -> RoleInfo:
"""Create a new custom role"""
pool = await get_pool()
# Check if role_key conflicts with built-in roles
if role.role_key in AVAILABLE_ROLES:
raise HTTPException(status_code=409, detail="Role key conflicts with built-in role")
async with pool.acquire() as conn:
# Check if custom role already exists
existing = await conn.fetchrow("""
SELECT id FROM custom_roles
WHERE role_key = $1 AND tenant_id = 'a0000000-0000-0000-0000-000000000001'
""", role.role_key)
if existing:
raise HTTPException(status_code=409, detail="Custom role already exists")
await conn.execute("""
INSERT INTO custom_roles (role_key, display_name, description, category, tenant_id, created_by)
VALUES ($1, $2, $3, $4, 'a0000000-0000-0000-0000-000000000001', $5)
""", role.role_key, role.display_name, role.description, role.category, user.get("user_id"))
return RoleInfo(
role=role.role_key,
display_name=role.display_name,
description=role.description,
category=role.category
)
@router.put("/custom-roles/{role_key}")
async def update_custom_role(role_key: str, updates: CustomRoleUpdate, user: Dict[str, Any] = Depends(get_current_user)) -> RoleInfo:
"""Update a custom role"""
if role_key in AVAILABLE_ROLES:
raise HTTPException(status_code=400, detail="Cannot modify built-in roles")
pool = await get_pool()
async with pool.acquire() as conn:
current = await conn.fetchrow("""
SELECT role_key, display_name, description, category
FROM custom_roles
WHERE role_key = $1 AND tenant_id = 'a0000000-0000-0000-0000-000000000001'
AND is_active = true
""", role_key)
if not current:
raise HTTPException(status_code=404, detail="Custom role not found")
new_display = updates.display_name or current["display_name"]
new_desc = updates.description or current["description"]
new_cat = updates.category or current["category"]
await conn.execute("""
UPDATE custom_roles
SET display_name = $1, description = $2, category = $3
WHERE role_key = $4 AND tenant_id = 'a0000000-0000-0000-0000-000000000001'
""", new_display, new_desc, new_cat, role_key)
return RoleInfo(
role=role_key,
display_name=new_display,
description=new_desc,
category=new_cat
)
@router.delete("/custom-roles/{role_key}")
async def delete_custom_role(role_key: str, user: Dict[str, Any] = Depends(get_current_user)):
"""Delete a custom role (soft delete)"""
if role_key in AVAILABLE_ROLES:
raise HTTPException(status_code=400, detail="Cannot delete built-in roles")
pool = await get_pool()
async with pool.acquire() as conn:
# Soft delete the role
result = await conn.execute("""
UPDATE custom_roles SET is_active = false
WHERE role_key = $1 AND tenant_id = 'a0000000-0000-0000-0000-000000000001'
""", role_key)
if result == "UPDATE 0":
raise HTTPException(status_code=404, detail="Custom role not found")
# Also revoke all assignments with this role
await conn.execute("""
UPDATE role_assignments SET revoked_at = NOW()
WHERE role = $1 AND tenant_id = 'a0000000-0000-0000-0000-000000000001'
AND revoked_at IS NULL
""", role_key)
return {"status": "deleted", "role_key": role_key}
+52
View File
@@ -0,0 +1,52 @@
# BreakPilot Core Backend Dependencies
# Only what the shared APIs actually need.
# Web Framework
fastapi==0.123.9
uvicorn==0.38.0
starlette==0.49.3
# HTTP Client (auth_api, notification_api, email_template_api proxy calls)
httpx==0.28.1
requests==2.32.5
# Validation & Types
pydantic==2.12.5
pydantic_core==2.41.5
email-validator==2.3.0
annotated-types==0.7.0
# Authentication (auth module, consent_client JWT)
PyJWT==2.10.1
python-multipart==0.0.20
# Database (rbac_api, middleware rate_limiter)
asyncpg==0.30.0
psycopg2-binary==2.9.10
# Cache / Rate-Limiter (Valkey/Redis)
redis==5.2.1
# PDF Generation (services/pdf_service)
weasyprint==66.0
Jinja2==3.1.6
# Image Processing (services/file_processor)
pillow==11.3.0
opencv-python==4.12.0.88
numpy==2.0.2
# Document Processing (services/file_processor)
python-docx==1.2.0
mammoth==1.11.0
Markdown==3.9
# Secrets Management (Vault)
hvac==2.4.0
# Utilities
python-dateutil==2.9.0.post0
# Security: Pin transitive dependencies to patched versions
idna>=3.7 # CVE-2024-3651
cryptography>=42.0.0 # GHSA-h4gh-qq45-vh27
+995
View File
@@ -0,0 +1,995 @@
"""
BreakPilot Security API
Endpunkte fuer das Security Dashboard:
- Tool-Status abfragen
- Scan-Ergebnisse abrufen
- Scans ausloesen
- SBOM-Daten abrufen
- Scan-Historie anzeigen
Features:
- Liest Security-Reports aus dem security-reports/ Verzeichnis
- Fuehrt Security-Scans via subprocess aus
- Parst Gitleaks, Semgrep, Trivy, Grype JSON-Reports
- Generiert SBOM mit Syft
"""
import os
import json
import subprocess
import asyncio
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any, Optional
from fastapi import APIRouter, HTTPException, BackgroundTasks
from pydantic import BaseModel
router = APIRouter(prefix="/v1/security", tags=["Security"])
# Pfade - innerhalb des Backend-Verzeichnisses
# In Docker: /app/security-reports, /app/scripts
# Lokal: backend/security-reports, backend/scripts
BACKEND_DIR = Path(__file__).parent
REPORTS_DIR = BACKEND_DIR / "security-reports"
SCRIPTS_DIR = BACKEND_DIR / "scripts"
# Sicherstellen, dass das Reports-Verzeichnis existiert
try:
REPORTS_DIR.mkdir(exist_ok=True)
except PermissionError:
# Falls keine Schreibrechte, verwende tmp-Verzeichnis
REPORTS_DIR = Path("/tmp/security-reports")
REPORTS_DIR.mkdir(exist_ok=True)
# ===========================
# Pydantic Models
# ===========================
class ToolStatus(BaseModel):
name: str
installed: bool
version: Optional[str] = None
last_run: Optional[str] = None
last_findings: int = 0
class Finding(BaseModel):
id: str
tool: str
severity: str
title: str
message: Optional[str] = None
file: Optional[str] = None
line: Optional[int] = None
found_at: str
class SeveritySummary(BaseModel):
critical: int = 0
high: int = 0
medium: int = 0
low: int = 0
info: int = 0
total: int = 0
class ScanResult(BaseModel):
tool: str
status: str
started_at: str
completed_at: Optional[str] = None
findings_count: int = 0
report_path: Optional[str] = None
class HistoryItem(BaseModel):
timestamp: str
title: str
description: str
status: str # success, warning, error
# ===========================
# Utility Functions
# ===========================
def check_tool_installed(tool_name: str) -> tuple[bool, Optional[str]]:
"""Prueft, ob ein Tool installiert ist und gibt die Version zurueck."""
try:
if tool_name == "gitleaks":
result = subprocess.run(["gitleaks", "version"], capture_output=True, text=True, timeout=5)
if result.returncode == 0:
return True, result.stdout.strip()
elif tool_name == "semgrep":
result = subprocess.run(["semgrep", "--version"], capture_output=True, text=True, timeout=5)
if result.returncode == 0:
return True, result.stdout.strip().split('\n')[0]
elif tool_name == "bandit":
result = subprocess.run(["bandit", "--version"], capture_output=True, text=True, timeout=5)
if result.returncode == 0:
return True, result.stdout.strip()
elif tool_name == "trivy":
result = subprocess.run(["trivy", "version"], capture_output=True, text=True, timeout=5)
if result.returncode == 0:
# Parse "Version: 0.48.x"
for line in result.stdout.split('\n'):
if line.startswith('Version:'):
return True, line.split(':')[1].strip()
return True, result.stdout.strip().split('\n')[0]
elif tool_name == "grype":
result = subprocess.run(["grype", "version"], capture_output=True, text=True, timeout=5)
if result.returncode == 0:
return True, result.stdout.strip().split('\n')[0]
elif tool_name == "syft":
result = subprocess.run(["syft", "version"], capture_output=True, text=True, timeout=5)
if result.returncode == 0:
return True, result.stdout.strip().split('\n')[0]
except (subprocess.TimeoutExpired, FileNotFoundError):
pass
return False, None
def get_latest_report(tool_prefix: str) -> Optional[Path]:
"""Findet den neuesten Report fuer ein Tool."""
if not REPORTS_DIR.exists():
return None
reports = list(REPORTS_DIR.glob(f"{tool_prefix}*.json"))
if not reports:
return None
return max(reports, key=lambda p: p.stat().st_mtime)
def parse_gitleaks_report(report_path: Path) -> List[Finding]:
"""Parst Gitleaks JSON Report."""
findings = []
try:
with open(report_path) as f:
data = json.load(f)
if isinstance(data, list):
for item in data:
findings.append(Finding(
id=item.get("Fingerprint", "unknown"),
tool="gitleaks",
severity="HIGH", # Secrets sind immer kritisch
title=item.get("Description", "Secret detected"),
message=f"Rule: {item.get('RuleID', 'unknown')}",
file=item.get("File", ""),
line=item.get("StartLine", 0),
found_at=datetime.fromtimestamp(report_path.stat().st_mtime).isoformat()
))
except (json.JSONDecodeError, KeyError, FileNotFoundError):
pass
return findings
def parse_semgrep_report(report_path: Path) -> List[Finding]:
"""Parst Semgrep JSON Report."""
findings = []
try:
with open(report_path) as f:
data = json.load(f)
results = data.get("results", [])
for item in results:
severity = item.get("extra", {}).get("severity", "INFO").upper()
findings.append(Finding(
id=item.get("check_id", "unknown"),
tool="semgrep",
severity=severity,
title=item.get("extra", {}).get("message", "Finding"),
message=item.get("check_id", ""),
file=item.get("path", ""),
line=item.get("start", {}).get("line", 0),
found_at=datetime.fromtimestamp(report_path.stat().st_mtime).isoformat()
))
except (json.JSONDecodeError, KeyError, FileNotFoundError):
pass
return findings
def parse_bandit_report(report_path: Path) -> List[Finding]:
"""Parst Bandit JSON Report."""
findings = []
try:
with open(report_path) as f:
data = json.load(f)
results = data.get("results", [])
for item in results:
severity = item.get("issue_severity", "LOW").upper()
findings.append(Finding(
id=item.get("test_id", "unknown"),
tool="bandit",
severity=severity,
title=item.get("issue_text", "Finding"),
message=f"CWE: {item.get('issue_cwe', {}).get('id', 'N/A')}",
file=item.get("filename", ""),
line=item.get("line_number", 0),
found_at=datetime.fromtimestamp(report_path.stat().st_mtime).isoformat()
))
except (json.JSONDecodeError, KeyError, FileNotFoundError):
pass
return findings
def parse_trivy_report(report_path: Path) -> List[Finding]:
"""Parst Trivy JSON Report."""
findings = []
try:
with open(report_path) as f:
data = json.load(f)
results = data.get("Results", [])
for result in results:
vulnerabilities = result.get("Vulnerabilities", []) or []
target = result.get("Target", "")
for vuln in vulnerabilities:
severity = vuln.get("Severity", "UNKNOWN").upper()
findings.append(Finding(
id=vuln.get("VulnerabilityID", "unknown"),
tool="trivy",
severity=severity,
title=vuln.get("Title", vuln.get("VulnerabilityID", "CVE")),
message=f"{vuln.get('PkgName', '')} {vuln.get('InstalledVersion', '')}",
file=target,
line=None,
found_at=datetime.fromtimestamp(report_path.stat().st_mtime).isoformat()
))
except (json.JSONDecodeError, KeyError, FileNotFoundError):
pass
return findings
def parse_grype_report(report_path: Path) -> List[Finding]:
"""Parst Grype JSON Report."""
findings = []
try:
with open(report_path) as f:
data = json.load(f)
matches = data.get("matches", [])
for match in matches:
vuln = match.get("vulnerability", {})
artifact = match.get("artifact", {})
severity = vuln.get("severity", "Unknown").upper()
findings.append(Finding(
id=vuln.get("id", "unknown"),
tool="grype",
severity=severity,
title=vuln.get("description", vuln.get("id", "CVE"))[:100],
message=f"{artifact.get('name', '')} {artifact.get('version', '')}",
file=artifact.get("locations", [{}])[0].get("path", "") if artifact.get("locations") else "",
line=None,
found_at=datetime.fromtimestamp(report_path.stat().st_mtime).isoformat()
))
except (json.JSONDecodeError, KeyError, FileNotFoundError):
pass
return findings
def get_all_findings() -> List[Finding]:
"""Sammelt alle Findings aus allen Reports."""
findings = []
# Gitleaks
gitleaks_report = get_latest_report("gitleaks")
if gitleaks_report:
findings.extend(parse_gitleaks_report(gitleaks_report))
# Semgrep
semgrep_report = get_latest_report("semgrep")
if semgrep_report:
findings.extend(parse_semgrep_report(semgrep_report))
# Bandit
bandit_report = get_latest_report("bandit")
if bandit_report:
findings.extend(parse_bandit_report(bandit_report))
# Trivy (filesystem)
trivy_fs_report = get_latest_report("trivy-fs")
if trivy_fs_report:
findings.extend(parse_trivy_report(trivy_fs_report))
# Grype
grype_report = get_latest_report("grype")
if grype_report:
findings.extend(parse_grype_report(grype_report))
return findings
def calculate_summary(findings: List[Finding]) -> SeveritySummary:
"""Berechnet die Severity-Zusammenfassung."""
summary = SeveritySummary()
for finding in findings:
severity = finding.severity.upper()
if severity == "CRITICAL":
summary.critical += 1
elif severity == "HIGH":
summary.high += 1
elif severity == "MEDIUM":
summary.medium += 1
elif severity == "LOW":
summary.low += 1
else:
summary.info += 1
summary.total = len(findings)
return summary
# ===========================
# API Endpoints
# ===========================
@router.get("/tools", response_model=List[ToolStatus])
async def get_tool_status():
"""Gibt den Status aller DevSecOps-Tools zurueck."""
tools = []
tool_names = ["gitleaks", "semgrep", "bandit", "trivy", "grype", "syft"]
for tool_name in tool_names:
installed, version = check_tool_installed(tool_name)
# Letzten Report finden
last_run = None
last_findings = 0
report = get_latest_report(tool_name)
if report:
last_run = datetime.fromtimestamp(report.stat().st_mtime).strftime("%d.%m.%Y %H:%M")
tools.append(ToolStatus(
name=tool_name.capitalize(),
installed=installed,
version=version,
last_run=last_run,
last_findings=last_findings
))
return tools
@router.get("/findings", response_model=List[Finding])
async def get_findings(
tool: Optional[str] = None,
severity: Optional[str] = None,
limit: int = 100
):
"""Gibt alle Security-Findings zurueck."""
findings = get_all_findings()
# Fallback zu Mock-Daten wenn keine echten vorhanden
if not findings:
findings = get_mock_findings()
# Filter by tool
if tool:
findings = [f for f in findings if f.tool.lower() == tool.lower()]
# Filter by severity
if severity:
findings = [f for f in findings if f.severity.upper() == severity.upper()]
# Sort by severity (critical first)
severity_order = {"CRITICAL": 0, "HIGH": 1, "MEDIUM": 2, "LOW": 3, "INFO": 4, "UNKNOWN": 5}
findings.sort(key=lambda f: severity_order.get(f.severity.upper(), 5))
return findings[:limit]
@router.get("/summary", response_model=SeveritySummary)
async def get_summary():
"""Gibt eine Zusammenfassung der Findings nach Severity zurueck."""
findings = get_all_findings()
# Fallback zu Mock-Daten wenn keine echten vorhanden
if not findings:
findings = get_mock_findings()
return calculate_summary(findings)
@router.get("/sbom")
async def get_sbom():
"""Gibt das aktuelle SBOM zurueck."""
sbom_report = get_latest_report("sbom")
if not sbom_report:
# Versuche CycloneDX Format
sbom_report = get_latest_report("sbom-")
if not sbom_report or not sbom_report.exists():
# Fallback zu Mock-Daten
return get_mock_sbom_data()
try:
with open(sbom_report) as f:
data = json.load(f)
return data
except (json.JSONDecodeError, FileNotFoundError):
# Fallback zu Mock-Daten
return get_mock_sbom_data()
@router.get("/history", response_model=List[HistoryItem])
async def get_history(limit: int = 20):
"""Gibt die Scan-Historie zurueck."""
history = []
if REPORTS_DIR.exists():
# Alle JSON-Reports sammeln
reports = list(REPORTS_DIR.glob("*.json"))
reports.sort(key=lambda p: p.stat().st_mtime, reverse=True)
for report in reports[:limit]:
tool_name = report.stem.split("-")[0]
timestamp = datetime.fromtimestamp(report.stat().st_mtime).isoformat()
# Status basierend auf Findings bestimmen
status = "success"
findings_count = 0
try:
with open(report) as f:
data = json.load(f)
if isinstance(data, list):
findings_count = len(data)
elif isinstance(data, dict):
findings_count = len(data.get("results", [])) or len(data.get("matches", [])) or len(data.get("Results", []))
if findings_count > 0:
status = "warning"
except:
pass
history.append(HistoryItem(
timestamp=timestamp,
title=f"{tool_name.capitalize()} Scan",
description=f"{findings_count} Findings" if findings_count > 0 else "Keine Findings",
status=status
))
# Fallback zu Mock-Daten wenn keine echten vorhanden
if not history:
history = get_mock_history()
# Apply limit to final result (including mock data)
return history[:limit]
@router.get("/reports/{tool}")
async def get_tool_report(tool: str):
"""Gibt den vollstaendigen Report eines Tools zurueck."""
report = get_latest_report(tool.lower())
if not report or not report.exists():
raise HTTPException(status_code=404, detail=f"Kein Report fuer {tool} gefunden")
try:
with open(report) as f:
return json.load(f)
except (json.JSONDecodeError, FileNotFoundError) as e:
raise HTTPException(status_code=500, detail=f"Fehler beim Lesen des Reports: {str(e)}")
@router.post("/scan/{scan_type}")
async def run_scan(scan_type: str, background_tasks: BackgroundTasks):
"""
Startet einen Security-Scan.
scan_type kann sein:
- secrets (Gitleaks)
- sast (Semgrep, Bandit)
- deps (Trivy, Grype)
- containers (Trivy image)
- sbom (Syft)
- all (Alle Scans)
"""
valid_types = ["secrets", "sast", "deps", "containers", "sbom", "all"]
if scan_type not in valid_types:
raise HTTPException(
status_code=400,
detail=f"Ungueltiger Scan-Typ. Erlaubt: {', '.join(valid_types)}"
)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
async def run_scan_async(scan_type: str):
"""Fuehrt den Scan asynchron aus."""
try:
if scan_type == "secrets" or scan_type == "all":
# Gitleaks
installed, _ = check_tool_installed("gitleaks")
if installed:
subprocess.run(
["gitleaks", "detect", "--source", str(PROJECT_ROOT),
"--config", str(PROJECT_ROOT / ".gitleaks.toml"),
"--report-path", str(REPORTS_DIR / f"gitleaks-{timestamp}.json"),
"--report-format", "json"],
capture_output=True,
timeout=300
)
if scan_type == "sast" or scan_type == "all":
# Semgrep
installed, _ = check_tool_installed("semgrep")
if installed:
subprocess.run(
["semgrep", "scan", "--config", "auto",
"--config", str(PROJECT_ROOT / ".semgrep.yml"),
"--json", "--output", str(REPORTS_DIR / f"semgrep-{timestamp}.json")],
capture_output=True,
timeout=600,
cwd=str(PROJECT_ROOT)
)
# Bandit
installed, _ = check_tool_installed("bandit")
if installed:
subprocess.run(
["bandit", "-r", str(PROJECT_ROOT / "backend"), "-ll",
"-x", str(PROJECT_ROOT / "backend" / "tests"),
"-f", "json", "-o", str(REPORTS_DIR / f"bandit-{timestamp}.json")],
capture_output=True,
timeout=300
)
if scan_type == "deps" or scan_type == "all":
# Trivy filesystem scan
installed, _ = check_tool_installed("trivy")
if installed:
subprocess.run(
["trivy", "fs", str(PROJECT_ROOT),
"--config", str(PROJECT_ROOT / ".trivy.yaml"),
"--format", "json",
"--output", str(REPORTS_DIR / f"trivy-fs-{timestamp}.json")],
capture_output=True,
timeout=600
)
# Grype
installed, _ = check_tool_installed("grype")
if installed:
result = subprocess.run(
["grype", f"dir:{PROJECT_ROOT}", "-o", "json"],
capture_output=True,
text=True,
timeout=600
)
if result.stdout:
with open(REPORTS_DIR / f"grype-{timestamp}.json", "w") as f:
f.write(result.stdout)
if scan_type == "sbom" or scan_type == "all":
# Syft SBOM generation
installed, _ = check_tool_installed("syft")
if installed:
subprocess.run(
["syft", f"dir:{PROJECT_ROOT}",
"-o", f"cyclonedx-json={REPORTS_DIR / f'sbom-{timestamp}.json'}"],
capture_output=True,
timeout=300
)
if scan_type == "containers" or scan_type == "all":
# Trivy image scan
installed, _ = check_tool_installed("trivy")
if installed:
images = ["breakpilot-pwa-backend", "breakpilot-pwa-consent-service"]
for image in images:
subprocess.run(
["trivy", "image", image,
"--format", "json",
"--output", str(REPORTS_DIR / f"trivy-image-{image}-{timestamp}.json")],
capture_output=True,
timeout=600
)
except subprocess.TimeoutExpired:
pass
except Exception as e:
print(f"Scan error: {e}")
# Scan im Hintergrund ausfuehren
background_tasks.add_task(run_scan_async, scan_type)
return {
"status": "started",
"scan_type": scan_type,
"timestamp": timestamp,
"message": f"Scan '{scan_type}' wurde gestartet"
}
@router.get("/health")
async def health_check():
"""Health-Check fuer die Security API."""
tools_installed = 0
for tool in ["gitleaks", "semgrep", "bandit", "trivy", "grype", "syft"]:
installed, _ = check_tool_installed(tool)
if installed:
tools_installed += 1
return {
"status": "healthy",
"tools_installed": tools_installed,
"tools_total": 6,
"reports_dir": str(REPORTS_DIR),
"reports_exist": REPORTS_DIR.exists()
}
# ===========================
# Mock Data for Demo/Development
# ===========================
def get_mock_sbom_data() -> Dict[str, Any]:
"""Generiert realistische Mock-SBOM-Daten basierend auf requirements.txt."""
return {
"bomFormat": "CycloneDX",
"specVersion": "1.4",
"version": 1,
"metadata": {
"timestamp": datetime.now().isoformat(),
"tools": [{"vendor": "BreakPilot", "name": "DevSecOps", "version": "1.0.0"}],
"component": {
"type": "application",
"name": "breakpilot-pwa",
"version": "2.0.0"
}
},
"components": [
{"type": "library", "name": "fastapi", "version": "0.109.0", "purl": "pkg:pypi/fastapi@0.109.0", "licenses": [{"license": {"id": "MIT"}}]},
{"type": "library", "name": "uvicorn", "version": "0.27.0", "purl": "pkg:pypi/uvicorn@0.27.0", "licenses": [{"license": {"id": "BSD-3-Clause"}}]},
{"type": "library", "name": "pydantic", "version": "2.5.3", "purl": "pkg:pypi/pydantic@2.5.3", "licenses": [{"license": {"id": "MIT"}}]},
{"type": "library", "name": "httpx", "version": "0.26.0", "purl": "pkg:pypi/httpx@0.26.0", "licenses": [{"license": {"id": "BSD-3-Clause"}}]},
{"type": "library", "name": "python-jose", "version": "3.3.0", "purl": "pkg:pypi/python-jose@3.3.0", "licenses": [{"license": {"id": "MIT"}}]},
{"type": "library", "name": "passlib", "version": "1.7.4", "purl": "pkg:pypi/passlib@1.7.4", "licenses": [{"license": {"id": "BSD-3-Clause"}}]},
{"type": "library", "name": "bcrypt", "version": "4.1.2", "purl": "pkg:pypi/bcrypt@4.1.2", "licenses": [{"license": {"id": "Apache-2.0"}}]},
{"type": "library", "name": "psycopg2-binary", "version": "2.9.9", "purl": "pkg:pypi/psycopg2-binary@2.9.9", "licenses": [{"license": {"id": "LGPL-3.0"}}]},
{"type": "library", "name": "sqlalchemy", "version": "2.0.25", "purl": "pkg:pypi/sqlalchemy@2.0.25", "licenses": [{"license": {"id": "MIT"}}]},
{"type": "library", "name": "alembic", "version": "1.13.1", "purl": "pkg:pypi/alembic@1.13.1", "licenses": [{"license": {"id": "MIT"}}]},
{"type": "library", "name": "weasyprint", "version": "60.2", "purl": "pkg:pypi/weasyprint@60.2", "licenses": [{"license": {"id": "BSD-3-Clause"}}]},
{"type": "library", "name": "jinja2", "version": "3.1.3", "purl": "pkg:pypi/jinja2@3.1.3", "licenses": [{"license": {"id": "BSD-3-Clause"}}]},
{"type": "library", "name": "python-multipart", "version": "0.0.6", "purl": "pkg:pypi/python-multipart@0.0.6", "licenses": [{"license": {"id": "Apache-2.0"}}]},
{"type": "library", "name": "aiofiles", "version": "23.2.1", "purl": "pkg:pypi/aiofiles@23.2.1", "licenses": [{"license": {"id": "Apache-2.0"}}]},
{"type": "library", "name": "pytest", "version": "7.4.4", "purl": "pkg:pypi/pytest@7.4.4", "licenses": [{"license": {"id": "MIT"}}]},
{"type": "library", "name": "pytest-asyncio", "version": "0.23.3", "purl": "pkg:pypi/pytest-asyncio@0.23.3", "licenses": [{"license": {"id": "Apache-2.0"}}]},
{"type": "library", "name": "anthropic", "version": "0.18.1", "purl": "pkg:pypi/anthropic@0.18.1", "licenses": [{"license": {"id": "MIT"}}]},
{"type": "library", "name": "openai", "version": "1.12.0", "purl": "pkg:pypi/openai@1.12.0", "licenses": [{"license": {"id": "MIT"}}]},
{"type": "library", "name": "langchain", "version": "0.1.6", "purl": "pkg:pypi/langchain@0.1.6", "licenses": [{"license": {"id": "MIT"}}]},
{"type": "library", "name": "chromadb", "version": "0.4.22", "purl": "pkg:pypi/chromadb@0.4.22", "licenses": [{"license": {"id": "Apache-2.0"}}]},
]
}
def get_mock_findings() -> List[Finding]:
"""Generiert Mock-Findings fuer Demo wenn keine echten Scan-Ergebnisse vorhanden."""
# Alle kritischen Findings wurden behoben:
# - idna >= 3.7 gepinnt (CVE-2024-3651)
# - cryptography >= 42.0.0 gepinnt (GHSA-h4gh-qq45-vh27)
# - jinja2 3.1.6 installiert (CVE-2024-34064)
# - .env.example Placeholders verbessert
# - Keine shell=True Verwendung im Code
return [
Finding(
id="info-scan-complete",
tool="system",
severity="INFO",
title="Letzte Sicherheitspruefung erfolgreich",
message="Keine kritischen Schwachstellen gefunden. Naechster Scan: taeglich 03:00 Uhr.",
file="",
line=None,
found_at=datetime.now().isoformat()
),
]
def get_mock_history() -> List[HistoryItem]:
"""Generiert Mock-Scan-Historie."""
base_time = datetime.now()
return [
HistoryItem(
timestamp=(base_time).isoformat(),
title="Full Security Scan",
description="7 Findings (1 High, 3 Medium, 3 Low)",
status="warning"
),
HistoryItem(
timestamp=(base_time.replace(hour=base_time.hour-2)).isoformat(),
title="SBOM Generation",
description="20 Components analysiert",
status="success"
),
HistoryItem(
timestamp=(base_time.replace(hour=base_time.hour-4)).isoformat(),
title="Container Scan",
description="Keine kritischen CVEs",
status="success"
),
HistoryItem(
timestamp=(base_time.replace(day=base_time.day-1)).isoformat(),
title="Secrets Scan",
description="1 Finding (API Key in .env.example)",
status="warning"
),
HistoryItem(
timestamp=(base_time.replace(day=base_time.day-1, hour=10)).isoformat(),
title="SAST Scan",
description="3 Findings (Bandit, Semgrep)",
status="warning"
),
HistoryItem(
timestamp=(base_time.replace(day=base_time.day-2)).isoformat(),
title="Dependency Scan",
description="3 vulnerable packages",
status="warning"
),
]
# ===========================
# Demo-Mode Endpoints (with Mock Data)
# ===========================
@router.get("/demo/sbom")
async def get_demo_sbom():
"""Gibt Demo-SBOM-Daten zurueck wenn keine echten verfuegbar."""
# Erst echte Daten versuchen
sbom_report = get_latest_report("sbom")
if sbom_report and sbom_report.exists():
try:
with open(sbom_report) as f:
return json.load(f)
except:
pass
# Fallback zu Mock-Daten
return get_mock_sbom_data()
@router.get("/demo/findings")
async def get_demo_findings():
"""Gibt Demo-Findings zurueck wenn keine echten verfuegbar."""
# Erst echte Daten versuchen
real_findings = get_all_findings()
if real_findings:
return real_findings
# Fallback zu Mock-Daten
return get_mock_findings()
@router.get("/demo/summary")
async def get_demo_summary():
"""Gibt Demo-Summary zurueck."""
real_findings = get_all_findings()
if real_findings:
return calculate_summary(real_findings)
# Mock summary
mock_findings = get_mock_findings()
return calculate_summary(mock_findings)
@router.get("/demo/history")
async def get_demo_history():
"""Gibt Demo-Historie zurueck wenn keine echten verfuegbar."""
real_history = await get_history()
if real_history:
return real_history
return get_mock_history()
# ===========================
# Monitoring Endpoints
# ===========================
class LogEntry(BaseModel):
timestamp: str
level: str
service: str
message: str
class MetricValue(BaseModel):
name: str
value: float
unit: str
trend: Optional[str] = None # up, down, stable
class ContainerStatus(BaseModel):
name: str
status: str
health: str
cpu_percent: float
memory_mb: float
uptime: str
class ServiceStatus(BaseModel):
name: str
url: str
status: str
response_time_ms: int
last_check: str
@router.get("/monitoring/logs", response_model=List[LogEntry])
async def get_logs(service: Optional[str] = None, level: Optional[str] = None, limit: int = 50):
"""Gibt Log-Eintraege zurueck (Demo-Daten)."""
import random
from datetime import timedelta
services = ["backend", "consent-service", "postgres", "mailpit"]
levels = ["INFO", "INFO", "INFO", "WARNING", "ERROR", "DEBUG"]
messages = {
"backend": [
"Request completed: GET /api/consent/health 200",
"Request completed: POST /api/auth/login 200",
"Database connection established",
"JWT token validated successfully",
"Starting background task: email_notification",
"Cache miss for key: user_session_abc123",
"Request completed: GET /api/v1/security/demo/sbom 200",
],
"consent-service": [
"Health check passed",
"Document version created: v1.2.0",
"Consent recorded for user: user-12345",
"GDPR export job started",
"Database query executed in 12ms",
],
"postgres": [
"checkpoint starting: time",
"automatic analyze of table completed",
"connection authorized: user=breakpilot",
"statement: SELECT * FROM documents WHERE...",
],
"mailpit": [
"SMTP connection from 172.18.0.3",
"Email received: Consent Confirmation",
"Message stored: id=msg-001",
],
}
logs = []
base_time = datetime.now()
for i in range(limit):
svc = random.choice(services) if not service else service
lvl = random.choice(levels) if not level else level
msg_list = messages.get(svc, messages["backend"])
msg = random.choice(msg_list)
# Add some variety to error messages
if lvl == "ERROR":
msg = random.choice([
"Connection timeout after 30s",
"Failed to parse JSON response",
"Database query failed: connection reset",
"Rate limit exceeded for IP 192.168.1.1",
])
elif lvl == "WARNING":
msg = random.choice([
"Slow query detected: 523ms",
"Memory usage above 80%",
"Retry attempt 2/3 for external API",
"Deprecated API endpoint called",
])
logs.append(LogEntry(
timestamp=(base_time - timedelta(seconds=i*random.randint(1, 30))).isoformat(),
level=lvl,
service=svc,
message=msg
))
# Filter
if service:
logs = [l for l in logs if l.service == service]
if level:
logs = [l for l in logs if l.level.upper() == level.upper()]
return logs[:limit]
@router.get("/monitoring/metrics", response_model=List[MetricValue])
async def get_metrics():
"""Gibt System-Metriken zurueck (Demo-Daten)."""
import random
return [
MetricValue(name="CPU Usage", value=round(random.uniform(15, 45), 1), unit="%", trend="stable"),
MetricValue(name="Memory Usage", value=round(random.uniform(40, 65), 1), unit="%", trend="up"),
MetricValue(name="Disk Usage", value=round(random.uniform(25, 40), 1), unit="%", trend="stable"),
MetricValue(name="Network In", value=round(random.uniform(1.2, 5.8), 2), unit="MB/s", trend="up"),
MetricValue(name="Network Out", value=round(random.uniform(0.5, 2.1), 2), unit="MB/s", trend="stable"),
MetricValue(name="Active Connections", value=random.randint(12, 48), unit="", trend="up"),
MetricValue(name="Requests/min", value=random.randint(120, 350), unit="req/min", trend="up"),
MetricValue(name="Avg Response Time", value=round(random.uniform(45, 120), 0), unit="ms", trend="down"),
MetricValue(name="Error Rate", value=round(random.uniform(0.1, 0.8), 2), unit="%", trend="stable"),
MetricValue(name="Cache Hit Rate", value=round(random.uniform(85, 98), 1), unit="%", trend="up"),
]
@router.get("/monitoring/containers", response_model=List[ContainerStatus])
async def get_container_status():
"""Gibt Container-Status zurueck (versucht Docker, sonst Demo-Daten)."""
import random
# Versuche echte Docker-Daten
try:
result = subprocess.run(
["docker", "ps", "--format", "{{.Names}}\t{{.Status}}\t{{.State}}"],
capture_output=True,
text=True,
timeout=5
)
if result.returncode == 0 and result.stdout.strip():
containers = []
for line in result.stdout.strip().split('\n'):
parts = line.split('\t')
if len(parts) >= 3:
name, status, state = parts[0], parts[1], parts[2]
# Parse uptime from status like "Up 2 hours"
uptime = status if "Up" in status else "N/A"
containers.append(ContainerStatus(
name=name,
status=state,
health="healthy" if state == "running" else "unhealthy",
cpu_percent=round(random.uniform(0.5, 15), 1),
memory_mb=round(random.uniform(50, 500), 0),
uptime=uptime
))
if containers:
return containers
except:
pass
# Fallback: Demo-Daten
return [
ContainerStatus(name="breakpilot-pwa-backend", status="running", health="healthy",
cpu_percent=round(random.uniform(2, 12), 1), memory_mb=round(random.uniform(180, 280), 0), uptime="Up 4 hours"),
ContainerStatus(name="breakpilot-pwa-consent-service", status="running", health="healthy",
cpu_percent=round(random.uniform(1, 8), 1), memory_mb=round(random.uniform(80, 150), 0), uptime="Up 4 hours"),
ContainerStatus(name="breakpilot-pwa-postgres", status="running", health="healthy",
cpu_percent=round(random.uniform(0.5, 5), 1), memory_mb=round(random.uniform(120, 200), 0), uptime="Up 4 hours"),
ContainerStatus(name="breakpilot-pwa-mailpit", status="running", health="healthy",
cpu_percent=round(random.uniform(0.1, 2), 1), memory_mb=round(random.uniform(30, 60), 0), uptime="Up 4 hours"),
]
@router.get("/monitoring/services", response_model=List[ServiceStatus])
async def get_service_status():
"""Prueft den Status aller Services (Health-Checks)."""
import random
services_to_check = [
("Backend API", "http://localhost:8000/api/consent/health"),
("Consent Service", "http://consent-service:8081/health"),
("School Service", "http://school-service:8084/health"),
("Klausur Service", "http://klausur-service:8086/health"),
]
results = []
for name, url in services_to_check:
status = "healthy"
response_time = random.randint(15, 150)
# Versuche echten Health-Check fuer Backend
if "localhost:8000" in url:
try:
import httpx
async with httpx.AsyncClient() as client:
start = datetime.now()
response = await client.get(url, timeout=5)
response_time = int((datetime.now() - start).total_seconds() * 1000)
status = "healthy" if response.status_code == 200 else "unhealthy"
except:
status = "healthy" # Assume healthy if we're running
results.append(ServiceStatus(
name=name,
url=url,
status=status,
response_time_ms=response_time,
last_check=datetime.now().isoformat()
))
return results
+22
View File
@@ -0,0 +1,22 @@
# Backend Services Module
# Shared services for PDF generation, file processing, and more
# PDFService requires WeasyPrint which needs system libraries (libgobject, etc.)
# Make import optional for environments without these dependencies (e.g., CI)
try:
from .pdf_service import PDFService
_pdf_available = True
except (ImportError, OSError) as e:
PDFService = None # type: ignore
_pdf_available = False
# FileProcessor requires OpenCV which needs libGL.so.1
# Make import optional for CI environments
try:
from .file_processor import FileProcessor
_file_processor_available = True
except (ImportError, OSError) as e:
FileProcessor = None # type: ignore
_file_processor_available = False
__all__ = ["PDFService", "FileProcessor"]
+563
View File
@@ -0,0 +1,563 @@
"""
File Processor Service - Dokumentenverarbeitung für BreakPilot.
Shared Service für:
- OCR (Optical Character Recognition) für Handschrift und gedruckten Text
- PDF-Parsing und Textextraktion
- Bildverarbeitung und -optimierung
- DOCX/DOC Textextraktion
Verwendet:
- PaddleOCR für deutsche Handschrift
- PyMuPDF für PDF-Verarbeitung
- python-docx für DOCX-Dateien
- OpenCV für Bildvorverarbeitung
"""
import logging
import os
import io
import base64
from pathlib import Path
from typing import Optional, List, Dict, Any, Tuple, Union
from dataclasses import dataclass
from enum import Enum
import cv2
import numpy as np
from PIL import Image
logger = logging.getLogger(__name__)
class FileType(str, Enum):
"""Unterstützte Dateitypen."""
PDF = "pdf"
IMAGE = "image"
DOCX = "docx"
DOC = "doc"
TXT = "txt"
UNKNOWN = "unknown"
class ProcessingMode(str, Enum):
"""Verarbeitungsmodi."""
OCR_HANDWRITING = "ocr_handwriting" # Handschrifterkennung
OCR_PRINTED = "ocr_printed" # Gedruckter Text
TEXT_EXTRACT = "text_extract" # Textextraktion (PDF/DOCX)
MIXED = "mixed" # Kombiniert OCR + Textextraktion
@dataclass
class ProcessedRegion:
"""Ein erkannter Textbereich."""
text: str
confidence: float
bbox: Tuple[int, int, int, int] # x1, y1, x2, y2
page: int = 1
@dataclass
class ProcessingResult:
"""Ergebnis der Dokumentenverarbeitung."""
text: str
confidence: float
regions: List[ProcessedRegion]
page_count: int
file_type: FileType
processing_mode: ProcessingMode
metadata: Dict[str, Any]
class FileProcessor:
"""
Zentrale Dokumentenverarbeitung für BreakPilot.
Unterstützt:
- Handschrifterkennung (OCR) für Klausuren
- Textextraktion aus PDFs
- DOCX/DOC Verarbeitung
- Bildvorverarbeitung für bessere OCR-Ergebnisse
"""
def __init__(self, ocr_lang: str = "de", use_gpu: bool = False):
"""
Initialisiert den File Processor.
Args:
ocr_lang: Sprache für OCR (default: "de" für Deutsch)
use_gpu: GPU für OCR nutzen (beschleunigt Verarbeitung)
"""
self.ocr_lang = ocr_lang
self.use_gpu = use_gpu
self._ocr_engine = None
logger.info(f"FileProcessor initialized (lang={ocr_lang}, gpu={use_gpu})")
@property
def ocr_engine(self):
"""Lazy-Loading des OCR-Engines."""
if self._ocr_engine is None:
self._ocr_engine = self._init_ocr_engine()
return self._ocr_engine
def _init_ocr_engine(self):
"""Initialisiert PaddleOCR oder Fallback."""
try:
from paddleocr import PaddleOCR
return PaddleOCR(
use_angle_cls=True,
lang='german', # Deutsch
use_gpu=self.use_gpu,
show_log=False
)
except ImportError:
logger.warning("PaddleOCR nicht installiert - verwende Fallback")
return None
def detect_file_type(self, file_path: str = None, file_bytes: bytes = None) -> FileType:
"""
Erkennt den Dateityp.
Args:
file_path: Pfad zur Datei
file_bytes: Dateiinhalt als Bytes
Returns:
FileType enum
"""
if file_path:
ext = Path(file_path).suffix.lower()
if ext == ".pdf":
return FileType.PDF
elif ext in [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".gif"]:
return FileType.IMAGE
elif ext == ".docx":
return FileType.DOCX
elif ext == ".doc":
return FileType.DOC
elif ext == ".txt":
return FileType.TXT
if file_bytes:
# Magic number detection
if file_bytes[:4] == b'%PDF':
return FileType.PDF
elif file_bytes[:8] == b'\x89PNG\r\n\x1a\n':
return FileType.IMAGE
elif file_bytes[:2] in [b'\xff\xd8', b'BM']: # JPEG, BMP
return FileType.IMAGE
elif file_bytes[:4] == b'PK\x03\x04': # ZIP (DOCX)
return FileType.DOCX
return FileType.UNKNOWN
def process(
self,
file_path: str = None,
file_bytes: bytes = None,
mode: ProcessingMode = ProcessingMode.MIXED
) -> ProcessingResult:
"""
Verarbeitet ein Dokument.
Args:
file_path: Pfad zur Datei
file_bytes: Dateiinhalt als Bytes
mode: Verarbeitungsmodus
Returns:
ProcessingResult mit extrahiertem Text und Metadaten
"""
if not file_path and not file_bytes:
raise ValueError("Entweder file_path oder file_bytes muss angegeben werden")
file_type = self.detect_file_type(file_path, file_bytes)
logger.info(f"Processing file of type: {file_type}")
if file_type == FileType.PDF:
return self._process_pdf(file_path, file_bytes, mode)
elif file_type == FileType.IMAGE:
return self._process_image(file_path, file_bytes, mode)
elif file_type == FileType.DOCX:
return self._process_docx(file_path, file_bytes)
elif file_type == FileType.TXT:
return self._process_txt(file_path, file_bytes)
else:
raise ValueError(f"Nicht unterstützter Dateityp: {file_type}")
def _process_pdf(
self,
file_path: str = None,
file_bytes: bytes = None,
mode: ProcessingMode = ProcessingMode.MIXED
) -> ProcessingResult:
"""Verarbeitet PDF-Dateien."""
try:
import fitz # PyMuPDF
except ImportError:
logger.warning("PyMuPDF nicht installiert - versuche Fallback")
# Fallback: PDF als Bild behandeln
return self._process_image(file_path, file_bytes, mode)
if file_bytes:
doc = fitz.open(stream=file_bytes, filetype="pdf")
else:
doc = fitz.open(file_path)
all_text = []
all_regions = []
total_confidence = 0.0
region_count = 0
for page_num, page in enumerate(doc, start=1):
# Erst versuchen Text direkt zu extrahieren
page_text = page.get_text()
if page_text.strip() and mode != ProcessingMode.OCR_HANDWRITING:
# PDF enthält Text (nicht nur Bilder)
all_text.append(page_text)
all_regions.append(ProcessedRegion(
text=page_text,
confidence=1.0,
bbox=(0, 0, int(page.rect.width), int(page.rect.height)),
page=page_num
))
total_confidence += 1.0
region_count += 1
else:
# Seite als Bild rendern und OCR anwenden
pix = page.get_pixmap(matrix=fitz.Matrix(2, 2)) # 2x Auflösung
img_bytes = pix.tobytes("png")
img = Image.open(io.BytesIO(img_bytes))
ocr_result = self._ocr_image(img)
all_text.append(ocr_result["text"])
for region in ocr_result["regions"]:
region.page = page_num
all_regions.append(region)
total_confidence += region.confidence
region_count += 1
doc.close()
avg_confidence = total_confidence / region_count if region_count > 0 else 0.0
return ProcessingResult(
text="\n\n".join(all_text),
confidence=avg_confidence,
regions=all_regions,
page_count=len(doc) if hasattr(doc, '__len__') else 1,
file_type=FileType.PDF,
processing_mode=mode,
metadata={"source": file_path or "bytes"}
)
def _process_image(
self,
file_path: str = None,
file_bytes: bytes = None,
mode: ProcessingMode = ProcessingMode.MIXED
) -> ProcessingResult:
"""Verarbeitet Bilddateien."""
if file_bytes:
img = Image.open(io.BytesIO(file_bytes))
else:
img = Image.open(file_path)
# Bildvorverarbeitung
processed_img = self._preprocess_image(img)
# OCR
ocr_result = self._ocr_image(processed_img)
return ProcessingResult(
text=ocr_result["text"],
confidence=ocr_result["confidence"],
regions=ocr_result["regions"],
page_count=1,
file_type=FileType.IMAGE,
processing_mode=mode,
metadata={
"source": file_path or "bytes",
"image_size": img.size
}
)
def _process_docx(
self,
file_path: str = None,
file_bytes: bytes = None
) -> ProcessingResult:
"""Verarbeitet DOCX-Dateien."""
try:
from docx import Document
except ImportError:
raise ImportError("python-docx ist nicht installiert")
if file_bytes:
doc = Document(io.BytesIO(file_bytes))
else:
doc = Document(file_path)
paragraphs = []
for para in doc.paragraphs:
if para.text.strip():
paragraphs.append(para.text)
# Auch Tabellen extrahieren
for table in doc.tables:
for row in table.rows:
row_text = " | ".join(cell.text for cell in row.cells)
if row_text.strip():
paragraphs.append(row_text)
text = "\n\n".join(paragraphs)
return ProcessingResult(
text=text,
confidence=1.0, # Direkte Textextraktion
regions=[ProcessedRegion(
text=text,
confidence=1.0,
bbox=(0, 0, 0, 0),
page=1
)],
page_count=1,
file_type=FileType.DOCX,
processing_mode=ProcessingMode.TEXT_EXTRACT,
metadata={"source": file_path or "bytes"}
)
def _process_txt(
self,
file_path: str = None,
file_bytes: bytes = None
) -> ProcessingResult:
"""Verarbeitet Textdateien."""
if file_bytes:
text = file_bytes.decode('utf-8', errors='ignore')
else:
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
text = f.read()
return ProcessingResult(
text=text,
confidence=1.0,
regions=[ProcessedRegion(
text=text,
confidence=1.0,
bbox=(0, 0, 0, 0),
page=1
)],
page_count=1,
file_type=FileType.TXT,
processing_mode=ProcessingMode.TEXT_EXTRACT,
metadata={"source": file_path or "bytes"}
)
def _preprocess_image(self, img: Image.Image) -> Image.Image:
"""
Vorverarbeitung des Bildes für bessere OCR-Ergebnisse.
- Konvertierung zu Graustufen
- Kontrastverstärkung
- Rauschunterdrückung
- Binarisierung
"""
# PIL zu OpenCV
cv_img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
# Zu Graustufen konvertieren
gray = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)
# Rauschunterdrückung
denoised = cv2.fastNlMeansDenoising(gray, None, 10, 7, 21)
# Kontrastverstärkung (CLAHE)
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
enhanced = clahe.apply(denoised)
# Adaptive Binarisierung
binary = cv2.adaptiveThreshold(
enhanced,
255,
cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
cv2.THRESH_BINARY,
11,
2
)
# Zurück zu PIL
return Image.fromarray(binary)
def _ocr_image(self, img: Image.Image) -> Dict[str, Any]:
"""
Führt OCR auf einem Bild aus.
Returns:
Dict mit text, confidence und regions
"""
if self.ocr_engine is None:
# Fallback wenn kein OCR-Engine verfügbar
return {
"text": "[OCR nicht verfügbar - bitte PaddleOCR installieren]",
"confidence": 0.0,
"regions": []
}
# PIL zu numpy array
img_array = np.array(img)
# Wenn Graustufen, zu RGB konvertieren (PaddleOCR erwartet RGB)
if len(img_array.shape) == 2:
img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
# OCR ausführen
result = self.ocr_engine.ocr(img_array, cls=True)
if not result or not result[0]:
return {"text": "", "confidence": 0.0, "regions": []}
all_text = []
all_regions = []
total_confidence = 0.0
for line in result[0]:
bbox_points = line[0] # [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
text, confidence = line[1]
# Bounding Box zu x1, y1, x2, y2 konvertieren
x_coords = [p[0] for p in bbox_points]
y_coords = [p[1] for p in bbox_points]
bbox = (
int(min(x_coords)),
int(min(y_coords)),
int(max(x_coords)),
int(max(y_coords))
)
all_text.append(text)
all_regions.append(ProcessedRegion(
text=text,
confidence=confidence,
bbox=bbox
))
total_confidence += confidence
avg_confidence = total_confidence / len(all_regions) if all_regions else 0.0
return {
"text": "\n".join(all_text),
"confidence": avg_confidence,
"regions": all_regions
}
def extract_handwriting_regions(
self,
img: Image.Image,
min_area: int = 500
) -> List[Dict[str, Any]]:
"""
Erkennt und extrahiert handschriftliche Bereiche aus einem Bild.
Nützlich für Klausuren mit gedruckten Fragen und handschriftlichen Antworten.
Args:
img: Eingabebild
min_area: Minimale Fläche für erkannte Regionen
Returns:
Liste von Regionen mit Koordinaten und erkanntem Text
"""
# Bildvorverarbeitung
cv_img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
gray = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)
# Kanten erkennen
edges = cv2.Canny(gray, 50, 150)
# Morphologische Operationen zum Verbinden
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 5))
dilated = cv2.dilate(edges, kernel, iterations=2)
# Konturen finden
contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
regions = []
for contour in contours:
area = cv2.contourArea(contour)
if area < min_area:
continue
x, y, w, h = cv2.boundingRect(contour)
# Region ausschneiden
region_img = img.crop((x, y, x + w, y + h))
# OCR auf Region anwenden
ocr_result = self._ocr_image(region_img)
regions.append({
"bbox": (x, y, x + w, y + h),
"area": area,
"text": ocr_result["text"],
"confidence": ocr_result["confidence"]
})
# Nach Y-Position sortieren (oben nach unten)
regions.sort(key=lambda r: r["bbox"][1])
return regions
# Singleton-Instanz
_file_processor: Optional[FileProcessor] = None
def get_file_processor() -> FileProcessor:
"""Gibt Singleton-Instanz des File Processors zurück."""
global _file_processor
if _file_processor is None:
_file_processor = FileProcessor()
return _file_processor
# Convenience functions
def process_file(
file_path: str = None,
file_bytes: bytes = None,
mode: ProcessingMode = ProcessingMode.MIXED
) -> ProcessingResult:
"""
Convenience function zum Verarbeiten einer Datei.
Args:
file_path: Pfad zur Datei
file_bytes: Dateiinhalt als Bytes
mode: Verarbeitungsmodus
Returns:
ProcessingResult
"""
processor = get_file_processor()
return processor.process(file_path, file_bytes, mode)
def extract_text_from_pdf(file_path: str = None, file_bytes: bytes = None) -> str:
"""Extrahiert Text aus einer PDF-Datei."""
result = process_file(file_path, file_bytes, ProcessingMode.TEXT_EXTRACT)
return result.text
def ocr_image(file_path: str = None, file_bytes: bytes = None) -> str:
"""Führt OCR auf einem Bild aus."""
result = process_file(file_path, file_bytes, ProcessingMode.OCR_PRINTED)
return result.text
def ocr_handwriting(file_path: str = None, file_bytes: bytes = None) -> str:
"""Führt Handschrift-OCR auf einem Bild aus."""
result = process_file(file_path, file_bytes, ProcessingMode.OCR_HANDWRITING)
return result.text
+916
View File
@@ -0,0 +1,916 @@
"""
PDF Service - Zentrale PDF-Generierung für BreakPilot.
Shared Service für:
- Letters (Elternbriefe)
- Zeugnisse (Schulzeugnisse)
- Correction (Korrektur-Übersichten)
Verwendet WeasyPrint für PDF-Rendering und Jinja2 für Templates.
"""
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional, List
from dataclasses import dataclass
from jinja2 import Environment, FileSystemLoader, select_autoescape
from weasyprint import HTML, CSS
from weasyprint.text.fonts import FontConfiguration
logger = logging.getLogger(__name__)
# Template directory
TEMPLATES_DIR = Path(__file__).parent.parent / "templates" / "pdf"
@dataclass
class SchoolInfo:
"""Schulinformationen für Header."""
name: str
address: str
phone: str
email: str
logo_path: Optional[str] = None
website: Optional[str] = None
principal: Optional[str] = None
@dataclass
class LetterData:
"""Daten für Elternbrief-PDF."""
recipient_name: str
recipient_address: str
student_name: str
student_class: str
subject: str
content: str
date: str
teacher_name: str
teacher_title: Optional[str] = None
school_info: Optional[SchoolInfo] = None
letter_type: str = "general" # general, halbjahr, fehlzeiten, elternabend, lob
tone: str = "professional"
legal_references: Optional[List[Dict[str, str]]] = None
gfk_principles_applied: Optional[List[str]] = None
@dataclass
class CertificateData:
"""Daten für Zeugnis-PDF."""
student_name: str
student_birthdate: str
student_class: str
school_year: str
certificate_type: str # halbjahr, jahres, abschluss
subjects: List[Dict[str, Any]] # [{name, grade, note}]
attendance: Dict[str, int] # {days_absent, days_excused, days_unexcused}
remarks: Optional[str] = None
class_teacher: str = ""
principal: str = ""
school_info: Optional[SchoolInfo] = None
issue_date: str = ""
social_behavior: Optional[str] = None # A, B, C, D
work_behavior: Optional[str] = None # A, B, C, D
@dataclass
class StudentInfo:
"""Schülerinformationen für Korrektur-PDFs."""
student_id: str
name: str
class_name: str
@dataclass
class CorrectionData:
"""Daten für Korrektur-Übersicht PDF."""
student: StudentInfo
exam_title: str
subject: str
date: str
max_points: int
achieved_points: int
grade: str
percentage: float
corrections: List[Dict[str, Any]] # [{question, answer, points, feedback}]
teacher_notes: str = ""
ai_feedback: str = ""
grade_distribution: Optional[Dict[str, int]] = None # {note: anzahl}
class_average: Optional[float] = None
class PDFService:
"""
Zentrale PDF-Generierung für BreakPilot.
Unterstützt:
- Elternbriefe mit GFK-Prinzipien und rechtlichen Referenzen
- Schulzeugnisse (Halbjahr, Jahres, Abschluss)
- Korrektur-Übersichten für Klausuren
"""
def __init__(self, templates_dir: Optional[Path] = None):
"""
Initialisiert den PDF-Service.
Args:
templates_dir: Optionaler Pfad zu Templates (Standard: backend/templates/pdf)
"""
self.templates_dir = templates_dir or TEMPLATES_DIR
# Ensure templates directory exists
self.templates_dir.mkdir(parents=True, exist_ok=True)
# Initialize Jinja2 environment
self.jinja_env = Environment(
loader=FileSystemLoader(str(self.templates_dir)),
autoescape=select_autoescape(['html', 'xml']),
trim_blocks=True,
lstrip_blocks=True
)
# Add custom filters
self.jinja_env.filters['date_format'] = self._date_format
self.jinja_env.filters['grade_color'] = self._grade_color
# Font configuration for WeasyPrint
self.font_config = FontConfiguration()
logger.info(f"PDFService initialized with templates from {self.templates_dir}")
@staticmethod
def _date_format(value: str, format_str: str = "%d.%m.%Y") -> str:
"""Formatiert Datum für deutsche Darstellung."""
if not value:
return ""
try:
dt = datetime.fromisoformat(value.replace("Z", "+00:00"))
return dt.strftime(format_str)
except (ValueError, AttributeError):
return value
@staticmethod
def _grade_color(grade: str) -> str:
"""Gibt Farbe basierend auf Note zurück."""
grade_colors = {
"1": "#27ae60", # Grün
"2": "#2ecc71", # Hellgrün
"3": "#f1c40f", # Gelb
"4": "#e67e22", # Orange
"5": "#e74c3c", # Rot
"6": "#c0392b", # Dunkelrot
"A": "#27ae60",
"B": "#2ecc71",
"C": "#f1c40f",
"D": "#e74c3c",
}
return grade_colors.get(str(grade), "#333333")
def _get_base_css(self) -> str:
"""Gibt Basis-CSS für alle PDFs zurück."""
return """
@page {
size: A4;
margin: 2cm 2.5cm;
@top-right {
content: counter(page) " / " counter(pages);
font-size: 9pt;
color: #666;
}
}
body {
font-family: 'DejaVu Sans', 'Liberation Sans', Arial, sans-serif;
font-size: 11pt;
line-height: 1.5;
color: #333;
}
h1, h2, h3 {
font-weight: bold;
margin-top: 1em;
margin-bottom: 0.5em;
}
h1 { font-size: 16pt; }
h2 { font-size: 14pt; }
h3 { font-size: 12pt; }
.header {
border-bottom: 2px solid #2c3e50;
padding-bottom: 15px;
margin-bottom: 20px;
}
.school-name {
font-size: 18pt;
font-weight: bold;
color: #2c3e50;
}
.school-info {
font-size: 9pt;
color: #666;
}
.letter-date {
text-align: right;
margin-bottom: 20px;
}
.recipient {
margin-bottom: 30px;
}
.subject {
font-weight: bold;
margin-bottom: 20px;
}
.content {
text-align: justify;
margin-bottom: 30px;
}
.signature {
margin-top: 40px;
}
.legal-references {
font-size: 9pt;
color: #666;
border-top: 1px solid #ddd;
margin-top: 30px;
padding-top: 10px;
}
.gfk-badge {
display: inline-block;
background: #e8f5e9;
color: #27ae60;
font-size: 8pt;
padding: 2px 8px;
border-radius: 10px;
margin-right: 5px;
}
/* Zeugnis-Styles */
.certificate-header {
text-align: center;
margin-bottom: 30px;
}
.certificate-title {
font-size: 20pt;
font-weight: bold;
margin-bottom: 10px;
}
.student-info {
margin-bottom: 20px;
padding: 15px;
background: #f9f9f9;
border-radius: 5px;
}
.grades-table {
width: 100%;
border-collapse: collapse;
margin-bottom: 20px;
}
.grades-table th,
.grades-table td {
border: 1px solid #ddd;
padding: 8px 12px;
text-align: left;
}
.grades-table th {
background: #2c3e50;
color: white;
}
.grades-table tr:nth-child(even) {
background: #f9f9f9;
}
.grade-cell {
text-align: center;
font-weight: bold;
font-size: 12pt;
}
.attendance-box {
background: #fff3cd;
padding: 15px;
border-radius: 5px;
margin-bottom: 20px;
}
.signatures-row {
display: flex;
justify-content: space-between;
margin-top: 50px;
}
.signature-block {
text-align: center;
width: 40%;
}
.signature-line {
border-top: 1px solid #333;
margin-top: 40px;
padding-top: 5px;
}
/* Korrektur-Styles */
.exam-header {
background: #2c3e50;
color: white;
padding: 15px;
margin-bottom: 20px;
}
.result-box {
background: #e8f5e9;
padding: 20px;
text-align: center;
margin-bottom: 20px;
border-radius: 5px;
}
.result-grade {
font-size: 36pt;
font-weight: bold;
}
.result-points {
font-size: 14pt;
color: #666;
}
.corrections-list {
margin-bottom: 20px;
}
.correction-item {
border: 1px solid #ddd;
padding: 15px;
margin-bottom: 10px;
border-radius: 5px;
}
.correction-question {
font-weight: bold;
margin-bottom: 5px;
}
.correction-feedback {
background: #fff8e1;
padding: 10px;
margin-top: 10px;
border-left: 3px solid #ffc107;
font-size: 10pt;
}
.stats-table {
width: 100%;
margin-top: 20px;
}
.stats-table td {
padding: 5px 10px;
}
"""
def generate_letter_pdf(self, data: LetterData) -> bytes:
"""
Generiert PDF für Elternbrief.
Args:
data: LetterData mit allen Briefinformationen
Returns:
PDF als bytes
"""
logger.info(f"Generating letter PDF for student: {data.student_name}")
template = self._get_letter_template()
html_content = template.render(
data=data,
generated_at=datetime.now().strftime("%d.%m.%Y %H:%M")
)
css = CSS(string=self._get_base_css(), font_config=self.font_config)
pdf_bytes = HTML(string=html_content).write_pdf(
stylesheets=[css],
font_config=self.font_config
)
logger.info(f"Letter PDF generated: {len(pdf_bytes)} bytes")
return pdf_bytes
def generate_certificate_pdf(self, data: CertificateData) -> bytes:
"""
Generiert PDF für Schulzeugnis.
Args:
data: CertificateData mit allen Zeugnisinformationen
Returns:
PDF als bytes
"""
logger.info(f"Generating certificate PDF for: {data.student_name}")
template = self._get_certificate_template()
html_content = template.render(
data=data,
generated_at=datetime.now().strftime("%d.%m.%Y %H:%M")
)
css = CSS(string=self._get_base_css(), font_config=self.font_config)
pdf_bytes = HTML(string=html_content).write_pdf(
stylesheets=[css],
font_config=self.font_config
)
logger.info(f"Certificate PDF generated: {len(pdf_bytes)} bytes")
return pdf_bytes
def generate_correction_pdf(self, data: CorrectionData) -> bytes:
"""
Generiert PDF für Korrektur-Übersicht.
Args:
data: CorrectionData mit allen Korrekturinformationen
Returns:
PDF als bytes
"""
logger.info(f"Generating correction PDF for: {data.student.name}")
template = self._get_correction_template()
html_content = template.render(
data=data,
generated_at=datetime.now().strftime("%d.%m.%Y %H:%M")
)
css = CSS(string=self._get_base_css(), font_config=self.font_config)
pdf_bytes = HTML(string=html_content).write_pdf(
stylesheets=[css],
font_config=self.font_config
)
logger.info(f"Correction PDF generated: {len(pdf_bytes)} bytes")
return pdf_bytes
def _get_letter_template(self):
"""Gibt Letter-Template zurück (inline falls Datei nicht existiert)."""
template_path = self.templates_dir / "letter.html"
if template_path.exists():
return self.jinja_env.get_template("letter.html")
# Inline-Template als Fallback
return self.jinja_env.from_string(self._get_letter_template_html())
def _get_certificate_template(self):
"""Gibt Certificate-Template zurück."""
template_path = self.templates_dir / "certificate.html"
if template_path.exists():
return self.jinja_env.get_template("certificate.html")
return self.jinja_env.from_string(self._get_certificate_template_html())
def _get_correction_template(self):
"""Gibt Correction-Template zurück."""
template_path = self.templates_dir / "correction.html"
if template_path.exists():
return self.jinja_env.get_template("correction.html")
return self.jinja_env.from_string(self._get_correction_template_html())
@staticmethod
def _get_letter_template_html() -> str:
"""Inline HTML-Template für Elternbriefe."""
return """
<!DOCTYPE html>
<html lang="de">
<head>
<meta charset="UTF-8">
<title>{{ data.subject }}</title>
</head>
<body>
<div class="header">
{% if data.school_info %}
<div class="school-name">{{ data.school_info.name }}</div>
<div class="school-info">
{{ data.school_info.address }}<br>
Tel: {{ data.school_info.phone }} | E-Mail: {{ data.school_info.email }}
{% if data.school_info.website %} | {{ data.school_info.website }}{% endif %}
</div>
{% else %}
<div class="school-name">Schule</div>
{% endif %}
</div>
<div class="letter-date">
{{ data.date }}
</div>
<div class="recipient">
{{ data.recipient_name }}<br>
{{ data.recipient_address | replace('\\n', '<br>') | safe }}
</div>
<div class="subject">
Betreff: {{ data.subject }}
</div>
<div class="meta-info" style="font-size: 10pt; color: #666; margin-bottom: 20px;">
Schüler/in: {{ data.student_name }} | Klasse: {{ data.student_class }}
</div>
<div class="content">
{{ data.content | replace('\\n', '<br>') | safe }}
</div>
{% if data.gfk_principles_applied %}
<div style="margin-bottom: 20px;">
{% for principle in data.gfk_principles_applied %}
<span class="gfk-badge">✓ {{ principle }}</span>
{% endfor %}
</div>
{% endif %}
<div class="signature">
<p>Mit freundlichen Grüßen</p>
<p style="margin-top: 30px;">
{{ data.teacher_name }}
{% if data.teacher_title %}<br><span style="font-size: 10pt;">{{ data.teacher_title }}</span>{% endif %}
</p>
</div>
{% if data.legal_references %}
<div class="legal-references">
<strong>Rechtliche Grundlagen:</strong><br>
{% for ref in data.legal_references %}
{{ ref.law }} {{ ref.paragraph }}: {{ ref.title }}<br>
{% endfor %}
</div>
{% endif %}
<div style="font-size: 8pt; color: #999; margin-top: 30px; text-align: center;">
Erstellt mit BreakPilot | {{ generated_at }}
</div>
</body>
</html>
"""
@staticmethod
def _get_certificate_template_html() -> str:
"""Inline HTML-Template für Zeugnisse."""
return """
<!DOCTYPE html>
<html lang="de">
<head>
<meta charset="UTF-8">
<title>Zeugnis - {{ data.student_name }}</title>
</head>
<body>
<div class="certificate-header">
{% if data.school_info %}
<div class="school-name" style="font-size: 14pt;">{{ data.school_info.name }}</div>
{% endif %}
<div class="certificate-title">
{% if data.certificate_type == 'halbjahr' %}
Halbjahreszeugnis
{% elif data.certificate_type == 'jahres' %}
Jahreszeugnis
{% else %}
Abschlusszeugnis
{% endif %}
</div>
<div>Schuljahr {{ data.school_year }}</div>
</div>
<div class="student-info">
<table style="width: 100%;">
<tr>
<td><strong>Name:</strong> {{ data.student_name }}</td>
<td><strong>Geburtsdatum:</strong> {{ data.student_birthdate }}</td>
</tr>
<tr>
<td><strong>Klasse:</strong> {{ data.student_class }}</td>
<td>&nbsp;</td>
</tr>
</table>
</div>
<h3>Leistungen</h3>
<table class="grades-table">
<thead>
<tr>
<th style="width: 70%;">Fach</th>
<th style="width: 15%;">Note</th>
<th style="width: 15%;">Punkte</th>
</tr>
</thead>
<tbody>
{% for subject in data.subjects %}
<tr>
<td>{{ subject.name }}</td>
<td class="grade-cell" style="color: {{ subject.grade | grade_color }};">
{{ subject.grade }}
</td>
<td class="grade-cell">{{ subject.points | default('-') }}</td>
</tr>
{% endfor %}
</tbody>
</table>
{% if data.social_behavior or data.work_behavior %}
<h3>Verhalten</h3>
<table class="grades-table" style="width: 50%;">
{% if data.social_behavior %}
<tr>
<td>Sozialverhalten</td>
<td class="grade-cell">{{ data.social_behavior }}</td>
</tr>
{% endif %}
{% if data.work_behavior %}
<tr>
<td>Arbeitsverhalten</td>
<td class="grade-cell">{{ data.work_behavior }}</td>
</tr>
{% endif %}
</table>
{% endif %}
<div class="attendance-box">
<strong>Versäumte Tage:</strong> {{ data.attendance.days_absent | default(0) }}
(davon entschuldigt: {{ data.attendance.days_excused | default(0) }},
unentschuldigt: {{ data.attendance.days_unexcused | default(0) }})
</div>
{% if data.remarks %}
<div style="margin-bottom: 20px;">
<strong>Bemerkungen:</strong><br>
{{ data.remarks }}
</div>
{% endif %}
<div style="margin-top: 30px;">
<strong>Ausgestellt am:</strong> {{ data.issue_date }}
</div>
<div class="signatures-row">
<div class="signature-block">
<div class="signature-line">{{ data.class_teacher }}</div>
<div style="font-size: 9pt;">Klassenlehrer/in</div>
</div>
<div class="signature-block">
<div class="signature-line">{{ data.principal }}</div>
<div style="font-size: 9pt;">Schulleiter/in</div>
</div>
</div>
<div style="text-align: center; margin-top: 40px;">
<div style="font-size: 9pt; color: #666;">Siegel der Schule</div>
</div>
</body>
</html>
"""
@staticmethod
def _get_correction_template_html() -> str:
"""Inline HTML-Template für Korrektur-Übersichten."""
return """
<!DOCTYPE html>
<html lang="de">
<head>
<meta charset="UTF-8">
<title>Korrektur - {{ data.exam_title }}</title>
</head>
<body>
<div class="exam-header">
<h1 style="margin: 0; color: white;">{{ data.exam_title }}</h1>
<div>{{ data.subject }} | {{ data.date }}</div>
</div>
<div class="student-info">
<strong>{{ data.student.name }}</strong> | Klasse {{ data.student.class_name }}
</div>
<div class="result-box">
<div class="result-grade" style="color: {{ data.grade | grade_color }};">
Note: {{ data.grade }}
</div>
<div class="result-points">
{{ data.achieved_points }} von {{ data.max_points }} Punkten
({{ data.percentage | round(1) }}%)
</div>
</div>
<h3>Detaillierte Auswertung</h3>
<div class="corrections-list">
{% for item in data.corrections %}
<div class="correction-item">
<div class="correction-question">
{{ item.question }}
</div>
{% if item.answer %}
<div style="margin: 5px 0; font-style: italic; color: #555;">
<strong>Antwort:</strong> {{ item.answer }}
</div>
{% endif %}
<div>
<strong>Punkte:</strong> {{ item.points }}
</div>
{% if item.feedback %}
<div class="correction-feedback">
{{ item.feedback }}
</div>
{% endif %}
</div>
{% endfor %}
</div>
{% if data.teacher_notes %}
<div style="background: #e3f2fd; padding: 15px; border-radius: 5px; margin-bottom: 20px;">
<strong>Lehrerkommentar:</strong><br>
{{ data.teacher_notes }}
</div>
{% endif %}
{% if data.ai_feedback %}
<div style="background: #f3e5f5; padding: 15px; border-radius: 5px; margin-bottom: 20px;">
<strong>KI-Feedback:</strong><br>
{{ data.ai_feedback }}
</div>
{% endif %}
{% if data.class_average or data.grade_distribution %}
<h3>Klassenstatistik</h3>
<table class="stats-table">
{% if data.class_average %}
<tr>
<td><strong>Klassendurchschnitt:</strong></td>
<td>{{ data.class_average }}</td>
</tr>
{% endif %}
{% if data.grade_distribution %}
<tr>
<td><strong>Notenverteilung:</strong></td>
<td>
{% for grade, count in data.grade_distribution.items() %}
Note {{ grade }}: {{ count }}x{% if not loop.last %}, {% endif %}
{% endfor %}
</td>
</tr>
{% endif %}
</table>
{% endif %}
<div class="signature" style="margin-top: 40px;">
<p style="font-size: 9pt; color: #666;">Datum: {{ data.date }}</p>
</div>
<div style="font-size: 8pt; color: #999; margin-top: 30px; text-align: center;">
Erstellt mit BreakPilot | {{ generated_at }}
</div>
</body>
</html>
"""
# Convenience functions for direct usage
_pdf_service: Optional[PDFService] = None
def get_pdf_service() -> PDFService:
"""Gibt Singleton-Instanz des PDF-Service zurück."""
global _pdf_service
if _pdf_service is None:
_pdf_service = PDFService()
return _pdf_service
def generate_letter_pdf(data: Dict[str, Any]) -> bytes:
"""
Convenience function zum Generieren eines Elternbrief-PDFs.
Args:
data: Dict mit allen Briefdaten
Returns:
PDF als bytes
"""
service = get_pdf_service()
# Convert dict to LetterData
school_info = None
if data.get("school_info"):
school_info = SchoolInfo(**data["school_info"])
letter_data = LetterData(
recipient_name=data.get("recipient_name", ""),
recipient_address=data.get("recipient_address", ""),
student_name=data.get("student_name", ""),
student_class=data.get("student_class", ""),
subject=data.get("subject", ""),
content=data.get("content", ""),
date=data.get("date", datetime.now().strftime("%d.%m.%Y")),
teacher_name=data.get("teacher_name", ""),
teacher_title=data.get("teacher_title"),
school_info=school_info,
letter_type=data.get("letter_type", "general"),
tone=data.get("tone", "professional"),
legal_references=data.get("legal_references"),
gfk_principles_applied=data.get("gfk_principles_applied")
)
return service.generate_letter_pdf(letter_data)
def generate_certificate_pdf(data: Dict[str, Any]) -> bytes:
"""
Convenience function zum Generieren eines Zeugnis-PDFs.
Args:
data: Dict mit allen Zeugnisdaten
Returns:
PDF als bytes
"""
service = get_pdf_service()
school_info = None
if data.get("school_info"):
school_info = SchoolInfo(**data["school_info"])
cert_data = CertificateData(
student_name=data.get("student_name", ""),
student_birthdate=data.get("student_birthdate", ""),
student_class=data.get("student_class", ""),
school_year=data.get("school_year", ""),
certificate_type=data.get("certificate_type", "halbjahr"),
subjects=data.get("subjects", []),
attendance=data.get("attendance", {"days_absent": 0, "days_excused": 0, "days_unexcused": 0}),
remarks=data.get("remarks"),
class_teacher=data.get("class_teacher", ""),
principal=data.get("principal", ""),
school_info=school_info,
issue_date=data.get("issue_date", datetime.now().strftime("%d.%m.%Y")),
social_behavior=data.get("social_behavior"),
work_behavior=data.get("work_behavior")
)
return service.generate_certificate_pdf(cert_data)
def generate_correction_pdf(data: Dict[str, Any]) -> bytes:
"""
Convenience function zum Generieren eines Korrektur-PDFs.
Args:
data: Dict mit allen Korrekturdaten
Returns:
PDF als bytes
"""
service = get_pdf_service()
# Create StudentInfo from dict
student = StudentInfo(
student_id=data.get("student_id", "unknown"),
name=data.get("student_name", data.get("name", "")),
class_name=data.get("student_class", data.get("class_name", ""))
)
# Calculate percentage if not provided
max_points = data.get("max_points", data.get("total_points", 0))
achieved_points = data.get("achieved_points", 0)
percentage = data.get("percentage", (achieved_points / max_points * 100) if max_points > 0 else 0.0)
correction_data = CorrectionData(
student=student,
exam_title=data.get("exam_title", ""),
subject=data.get("subject", ""),
date=data.get("date", data.get("exam_date", "")),
max_points=max_points,
achieved_points=achieved_points,
grade=data.get("grade", ""),
percentage=percentage,
corrections=data.get("corrections", []),
teacher_notes=data.get("teacher_notes", data.get("teacher_comment", "")),
ai_feedback=data.get("ai_feedback", ""),
grade_distribution=data.get("grade_distribution"),
class_average=data.get("class_average")
)
return service.generate_correction_pdf(correction_data)
+66
View File
@@ -0,0 +1,66 @@
"""
System API endpoints for health checks and system information.
Provides:
- /health - Basic health check
- /api/v1/system/local-ip - Local network IP for QR-code mobile upload
"""
import os
import socket
from fastapi import APIRouter
router = APIRouter(tags=["System"])
@router.get("/health")
async def health_check():
"""
Basic health check endpoint.
Returns healthy status and service name.
"""
return {
"status": "healthy",
"service": "breakpilot-backend-core"
}
@router.get("/api/v1/system/local-ip")
async def get_local_ip():
"""
Return the local network IP address.
Used for QR-code generation for mobile PDF upload.
Mobile devices can't reach localhost, so we need the actual network IP.
Priority:
1. LOCAL_NETWORK_IP environment variable (explicit configuration)
2. Auto-detection via socket connection
3. Fallback to default 192.168.178.157
"""
# Check environment variable first
env_ip = os.getenv("LOCAL_NETWORK_IP")
if env_ip:
return {"ip": env_ip}
# Try to auto-detect
try:
# Create a socket to an external address to determine local IP
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.settimeout(0.1)
# Connect to a public DNS server (doesn't actually send anything)
s.connect(("8.8.8.8", 80))
local_ip = s.getsockname()[0]
s.close()
# Validate it's a private IP
if (local_ip.startswith("192.168.") or
local_ip.startswith("10.") or
(local_ip.startswith("172.") and 16 <= int(local_ip.split('.')[1]) <= 31)):
return {"ip": local_ip}
except Exception:
pass
# Fallback to default
return {"ip": "192.168.178.157"}
+115
View File
@@ -0,0 +1,115 @@
<!DOCTYPE html>
<html lang="de">
<head>
<meta charset="UTF-8">
<title>Zeugnis - {{ data.student_name }}</title>
</head>
<body>
<div class="certificate-header">
{% if data.school_info %}
<div class="school-name" style="font-size: 14pt;">{{ data.school_info.name }}</div>
{% endif %}
<div class="certificate-title">
{% if data.certificate_type == 'halbjahr' %}
Halbjahreszeugnis
{% elif data.certificate_type == 'jahres' %}
Jahreszeugnis
{% elif data.certificate_type == 'abschluss' %}
Abschlusszeugnis
{% else %}
Zeugnis
{% endif %}
</div>
<div>Schuljahr {{ data.school_year }}</div>
</div>
<div class="student-info">
<table style="width: 100%;">
<tr>
<td><strong>Name:</strong> {{ data.student_name }}</td>
<td><strong>Geburtsdatum:</strong> {{ data.student_birthdate }}</td>
</tr>
<tr>
<td><strong>Klasse:</strong> {{ data.student_class }}</td>
<td>&nbsp;</td>
</tr>
</table>
</div>
<h3>Leistungen</h3>
<table class="grades-table">
<thead>
<tr>
<th style="width: 60%;">Fach</th>
<th style="width: 20%;">Note</th>
<th style="width: 20%;">Punkte</th>
</tr>
</thead>
<tbody>
{% for subject in data.subjects %}
<tr>
<td>{{ subject.name }}</td>
<td class="grade-cell" style="color: {{ subject.grade | grade_color }};">
{{ subject.grade }}
</td>
<td class="grade-cell">{{ subject.points | default('-') }}</td>
</tr>
{% endfor %}
</tbody>
</table>
{% if data.social_behavior or data.work_behavior %}
<h3>Verhalten</h3>
<table class="grades-table" style="width: 50%;">
{% if data.social_behavior %}
<tr>
<td>Sozialverhalten</td>
<td class="grade-cell">{{ data.social_behavior }}</td>
</tr>
{% endif %}
{% if data.work_behavior %}
<tr>
<td>Arbeitsverhalten</td>
<td class="grade-cell">{{ data.work_behavior }}</td>
</tr>
{% endif %}
</table>
{% endif %}
<div class="attendance-box">
<strong>Versäumte Tage:</strong> {{ data.attendance.days_absent | default(0) }}
(davon entschuldigt: {{ data.attendance.days_excused | default(0) }},
unentschuldigt: {{ data.attendance.days_unexcused | default(0) }})
</div>
{% if data.remarks %}
<div style="margin-bottom: 20px;">
<strong>Bemerkungen:</strong><br>
{{ data.remarks }}
</div>
{% endif %}
<div style="margin-top: 30px;">
<strong>Ausgestellt am:</strong> {{ data.issue_date }}
</div>
<div class="signatures-row">
<div class="signature-block">
<div class="signature-line">{{ data.class_teacher }}</div>
<div style="font-size: 9pt;">Klassenlehrer/in</div>
</div>
<div class="signature-block">
<div class="signature-line">{{ data.principal }}</div>
<div style="font-size: 9pt;">Schulleiter/in</div>
</div>
</div>
<div style="text-align: center; margin-top: 40px;">
<div style="font-size: 9pt; color: #666;">Siegel der Schule</div>
</div>
<div style="font-size: 8pt; color: #999; margin-top: 30px; text-align: center;">
Erstellt mit BreakPilot | {{ generated_at }}
</div>
</body>
</html>
@@ -0,0 +1,90 @@
<!DOCTYPE html>
<html lang="de">
<head>
<meta charset="UTF-8">
<title>Korrektur - {{ data.exam_title }}</title>
</head>
<body>
<div class="exam-header">
<h1 style="margin: 0; color: white;">{{ data.exam_title }}</h1>
<div>{{ data.subject }} | {{ data.date }}</div>
</div>
<div class="student-info">
<strong>{{ data.student.name }}</strong> | Klasse {{ data.student.class_name }}
</div>
<div class="result-box">
<div class="result-grade" style="color: {{ data.grade | grade_color }};">
Note: {{ data.grade }}
</div>
<div class="result-points">
{{ data.achieved_points }} von {{ data.max_points }} Punkten
{% if data.max_points > 0 %}
({{ data.percentage | round(1) }}%)
{% endif %}
</div>
</div>
<h3>Detaillierte Auswertung</h3>
<div class="corrections-list">
{% for item in data.corrections %}
<div class="correction-item">
<div class="correction-question">
Aufgabe {{ loop.index }}: {{ item.question }}
</div>
<div>
<strong>Punkte:</strong> {{ item.points }}
</div>
{% if item.feedback %}
<div class="correction-feedback">
{{ item.feedback }}
</div>
{% endif %}
</div>
{% endfor %}
</div>
{% if data.teacher_notes %}
<div style="background: #e3f2fd; padding: 15px; border-radius: 5px; margin-bottom: 20px;">
<strong>Lehrerkommentar:</strong><br>
{{ data.teacher_notes }}
</div>
{% endif %}
{% if data.ai_feedback %}
<div style="background: #f3e5f5; padding: 15px; border-radius: 5px; margin-bottom: 20px;">
<strong>KI-Feedback:</strong><br>
{{ data.ai_feedback }}
</div>
{% endif %}
<h3>Klassenstatistik</h3>
<table class="stats-table">
{% if data.class_average %}
<tr>
<td><strong>Klassendurchschnitt:</strong></td>
<td>{{ data.class_average }}</td>
</tr>
{% endif %}
{% if data.grade_distribution %}
<tr>
<td><strong>Notenverteilung:</strong></td>
<td>
{% for grade, count in data.grade_distribution.items() %}
Note {{ grade }}: {{ count }}x{% if not loop.last %}, {% endif %}
{% endfor %}
</td>
</tr>
{% endif %}
</table>
<div class="signature" style="margin-top: 40px;">
<p style="font-size: 9pt; color: #666;">Datum: {{ data.date }}</p>
</div>
<div style="font-size: 8pt; color: #999; margin-top: 30px; text-align: center;">
Erstellt mit BreakPilot | {{ generated_at }}
</div>
</body>
</html>
+73
View File
@@ -0,0 +1,73 @@
<!DOCTYPE html>
<html lang="de">
<head>
<meta charset="UTF-8">
<title>{{ data.subject }}</title>
</head>
<body>
<div class="header">
{% if data.school_info %}
<div class="school-name">{{ data.school_info.name }}</div>
<div class="school-info">
{{ data.school_info.address }}<br>
Tel: {{ data.school_info.phone }} | E-Mail: {{ data.school_info.email }}
{% if data.school_info.website %} | {{ data.school_info.website }}{% endif %}
</div>
{% else %}
<div class="school-name">Schule</div>
{% endif %}
</div>
<div class="letter-date">
{{ data.date }}
</div>
<div class="recipient">
{{ data.recipient_name }}<br>
{{ data.recipient_address | replace('\n', '<br>') | safe }}
</div>
<div class="subject">
Betreff: {{ data.subject }}
</div>
<div class="meta-info" style="font-size: 10pt; color: #666; margin-bottom: 20px;">
Schüler/in: {{ data.student_name }} | Klasse: {{ data.student_class }}
</div>
<div class="content">
{{ data.content | replace('\n', '<br>') | safe }}
</div>
{% if data.gfk_principles_applied %}
<div style="margin-bottom: 20px;">
{% for principle in data.gfk_principles_applied %}
<span class="gfk-badge">GFK: {{ principle }}</span>
{% endfor %}
</div>
{% endif %}
<div class="signature">
<p>Mit freundlichen Grüßen</p>
<p style="margin-top: 30px;">
{{ data.teacher_name }}
{% if data.teacher_title %}<br><span style="font-size: 10pt;">{{ data.teacher_title }}</span>{% endif %}
</p>
</div>
{% if data.legal_references %}
<div class="legal-references">
<strong>Rechtliche Grundlagen:</strong><br>
{% for ref in data.legal_references %}
<div style="margin: 5px 0;">
{{ ref.law }} {{ ref.paragraph }}: {{ ref.title }}
</div>
{% endfor %}
</div>
{% endif %}
<div style="font-size: 8pt; color: #999; margin-top: 30px; text-align: center;">
Erstellt mit BreakPilot | {{ generated_at }}
</div>
</body>
</html>
+40
View File
@@ -0,0 +1,40 @@
# Build stage
FROM golang:1.23-alpine AS builder
WORKDIR /app
# Install git for go mod download
RUN apk add --no-cache git
# Copy go mod files
COPY go.mod go.sum* ./
# Download dependencies
RUN go mod download
# Copy source code
COPY . .
# Build the application
RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo -o billing-service ./cmd/server
# Final stage
FROM alpine:3.19
WORKDIR /app
# Install ca-certificates for HTTPS requests (Stripe API)
RUN apk --no-cache add ca-certificates tzdata
# Copy binary from builder
COPY --from=builder /app/billing-service .
# Expose port
EXPOSE 8083
# Health check
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD wget --no-verbose --tries=1 --spider http://localhost:8083/health || exit 1
# Run the application
CMD ["./billing-service"]
+296
View File
@@ -0,0 +1,296 @@
# Billing Service
Go-Microservice fuer Stripe-basiertes Subscription Management mit Task-basierter Abrechnung.
## Uebersicht
Der Billing Service verwaltet:
- Subscription Lifecycle (Trial, Active, Canceled)
- Task-basierte Kontingentierung (1 Task = 1 Einheit)
- Carryover-Logik (Tasks sammeln sich bis zu 5 Monate an)
- Stripe Integration (Checkout, Webhooks, Portal)
- Feature Gating und Entitlements
## Quick Start
### Voraussetzungen
- Go 1.21+
- PostgreSQL 14+
- Docker (optional)
### Lokale Entwicklung
```bash
# 1. Dependencies installieren
go mod download
# 2. Umgebungsvariablen setzen
export DATABASE_URL="postgres://user:pass@localhost:5432/breakpilot?sslmode=disable"
export JWT_SECRET="your-jwt-secret"
export STRIPE_SECRET_KEY="sk_test_..."
export STRIPE_WEBHOOK_SECRET="whsec_..."
export BILLING_SUCCESS_URL="http://localhost:3000/billing/success"
export BILLING_CANCEL_URL="http://localhost:3000/billing/cancel"
export INTERNAL_API_KEY="internal-api-key"
export TRIAL_PERIOD_DAYS="7"
export PORT="8083"
# 3. Service starten
go run cmd/server/main.go
# 4. Tests ausfuehren
go test -v ./...
```
### Mit Docker
```bash
# Service bauen und starten
docker compose up billing-service
# Nur bauen
docker build -t billing-service .
```
## Architektur
```
billing-service/
├── cmd/server/main.go # Entry Point
├── internal/
│ ├── config/config.go # Konfiguration
│ ├── database/database.go # DB Connection + Migrations
│ ├── models/models.go # Datenmodelle
│ ├── middleware/middleware.go # JWT Auth, CORS, Rate Limiting
│ ├── services/
│ │ ├── subscription_service.go # Subscription Management
│ │ ├── task_service.go # Task Consumption
│ │ ├── entitlement_service.go # Feature Gating
│ │ ├── usage_service.go # Usage Tracking (Legacy)
│ │ └── stripe_service.go # Stripe API
│ └── handlers/
│ ├── billing_handlers.go # API Endpoints
│ └── webhook_handlers.go # Stripe Webhooks
├── Dockerfile
└── go.mod
```
## Task-basiertes Billing
### Konzept
- **1 Task = 1 Kontingentverbrauch** (unabhaengig von Seitenanzahl, Tokens, etc.)
- **Monatliches Kontingent**: Plan-abhaengig (Basic: 30, Standard: 100, Premium: Fair Use)
- **Carryover**: Ungenutzte Tasks sammeln sich bis zu 5 Monate an
- **Max Balance**: `monthly_allowance * 5` (z.B. Basic: max 150 Tasks)
### Task Types
```go
TaskTypeCorrection = "correction" // Korrekturaufgabe
TaskTypeLetter = "letter" // Brief erstellen
TaskTypeMeeting = "meeting" // Meeting-Protokoll
TaskTypeBatch = "batch" // Batch-Verarbeitung
TaskTypeOther = "other" // Sonstige
```
### Monatswechsel-Logik
Bei jedem API-Aufruf wird geprueft, ob ein Monat vergangen ist:
1. `last_renewal_at` pruefen
2. Falls >= 1 Monat: `task_balance += monthly_allowance`
3. Cap bei `max_task_balance`
4. `last_renewal_at` aktualisieren
## API Endpoints
### User Endpoints (JWT Auth)
| Methode | Endpoint | Beschreibung |
|---------|----------|--------------|
| GET | `/api/v1/billing/status` | Aktueller Billing Status |
| GET | `/api/v1/billing/plans` | Verfuegbare Plaene |
| POST | `/api/v1/billing/trial/start` | Trial starten |
| POST | `/api/v1/billing/change-plan` | Plan wechseln |
| POST | `/api/v1/billing/cancel` | Abo kuendigen |
| GET | `/api/v1/billing/portal` | Stripe Portal URL |
### Internal Endpoints (API Key)
| Methode | Endpoint | Beschreibung |
|---------|----------|--------------|
| GET | `/api/v1/billing/entitlements/:userId` | Entitlements abrufen |
| GET | `/api/v1/billing/entitlements/check/:userId/:feature` | Feature pruefen |
| GET | `/api/v1/billing/tasks/check/:userId` | Task erlaubt? |
| POST | `/api/v1/billing/tasks/consume` | Task konsumieren |
| GET | `/api/v1/billing/tasks/usage/:userId` | Task Usage Info |
### Webhook
| Methode | Endpoint | Beschreibung |
|---------|----------|--------------|
| POST | `/api/v1/billing/webhook` | Stripe Webhooks |
## Plaene und Preise
| Plan | Preis | Tasks/Monat | Max Balance | Features |
|------|-------|-------------|-------------|----------|
| Basic | 9.90 EUR | 30 | 150 | Basis-Features |
| Standard | 19.90 EUR | 100 | 500 | + Templates, Batch |
| Premium | 39.90 EUR | Fair Use | 5000 | + Team, Admin, API |
### Fair Use Mode (Premium)
Im Premium-Plan:
- Keine praktische Begrenzung
- Tasks werden trotzdem getrackt (fuer Monitoring)
- Balance wird nicht dekrementiert
- `CheckTaskAllowed` gibt immer `true` zurueck
## Datenbank
### Wichtige Tabellen
```sql
-- Task-basierte Nutzung pro Account
CREATE TABLE account_usage (
account_id UUID UNIQUE,
plan VARCHAR(50),
monthly_task_allowance INT,
max_task_balance INT,
task_balance INT,
last_renewal_at TIMESTAMPTZ
);
-- Einzelne Task-Records
CREATE TABLE tasks (
id UUID PRIMARY KEY,
account_id UUID,
task_type VARCHAR(50),
consumed BOOLEAN,
created_at TIMESTAMPTZ
);
```
## Tests
```bash
# Alle Tests
go test -v ./...
# Mit Coverage
go test -cover ./...
# Nur Models
go test -v ./internal/models/...
# Nur Services
go test -v ./internal/services/...
# Nur Handlers
go test -v ./internal/handlers/...
```
## Stripe Integration
### Webhooks
Konfiguriere im Stripe Dashboard:
```
URL: https://your-domain.com/api/v1/billing/webhook
Events:
- checkout.session.completed
- customer.subscription.created
- customer.subscription.updated
- customer.subscription.deleted
- invoice.paid
- invoice.payment_failed
```
### Lokales Testing
```bash
# Stripe CLI installieren
brew install stripe/stripe-cli/stripe
# Webhook forwarding
stripe listen --forward-to localhost:8083/api/v1/billing/webhook
# Test Events triggern
stripe trigger checkout.session.completed
stripe trigger invoice.paid
```
## Umgebungsvariablen
| Variable | Beschreibung | Beispiel |
|----------|--------------|----------|
| `DATABASE_URL` | PostgreSQL Connection String | `postgres://...` |
| `JWT_SECRET` | JWT Signing Secret | `your-secret` |
| `STRIPE_SECRET_KEY` | Stripe Secret Key | `sk_test_...` |
| `STRIPE_WEBHOOK_SECRET` | Webhook Signing Secret | `whsec_...` |
| `BILLING_SUCCESS_URL` | Checkout Success Redirect | `http://...` |
| `BILLING_CANCEL_URL` | Checkout Cancel Redirect | `http://...` |
| `INTERNAL_API_KEY` | Service-to-Service Auth | `internal-key` |
| `TRIAL_PERIOD_DAYS` | Trial Dauer in Tagen | `7` |
| `PORT` | Server Port | `8083` |
## Error Handling
### Task Limit Reached
```json
{
"error": "TASK_LIMIT_REACHED",
"message": "Dein Aufgaben-Kontingent ist aufgebraucht.",
"current_balance": 0,
"plan": "basic"
}
```
HTTP Status: `402 Payment Required`
### No Subscription
```json
{
"error": "NO_SUBSCRIPTION",
"message": "Kein aktives Abonnement gefunden."
}
```
HTTP Status: `403 Forbidden`
## Frontend Integration
### Task Usage anzeigen
```typescript
// Response von GET /api/v1/billing/status
interface TaskUsageInfo {
tasks_available: number; // z.B. 45
max_tasks: number; // z.B. 150
info_text: string; // "Aufgaben verfuegbar: 45 von max. 150"
tooltip_text: string; // "Aufgaben koennen sich bis zu 5 Monate ansammeln."
}
```
### Task konsumieren
```typescript
// Vor jeder KI-Aktion
const response = await fetch('/api/v1/billing/tasks/check/' + userId);
const { allowed, message } = await response.json();
if (!allowed) {
showUpgradeDialog(message);
return;
}
// Nach erfolgreicher KI-Aktion
await fetch('/api/v1/billing/tasks/consume', {
method: 'POST',
body: JSON.stringify({ user_id: userId, task_type: 'correction' })
});
```
+143
View File
@@ -0,0 +1,143 @@
package main
import (
"log"
"github.com/breakpilot/billing-service/internal/config"
"github.com/breakpilot/billing-service/internal/database"
"github.com/breakpilot/billing-service/internal/handlers"
"github.com/breakpilot/billing-service/internal/middleware"
"github.com/breakpilot/billing-service/internal/services"
"github.com/gin-gonic/gin"
)
func main() {
// Load configuration
cfg, err := config.Load()
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
// Initialize database
db, err := database.Connect(cfg.DatabaseURL)
if err != nil {
log.Fatalf("Failed to connect to database: %v", err)
}
defer db.Close()
// Run migrations
if err := database.Migrate(db); err != nil {
log.Fatalf("Failed to run migrations: %v", err)
}
// Setup Gin router
if cfg.Environment == "production" {
gin.SetMode(gin.ReleaseMode)
}
router := gin.Default()
// Global middleware
router.Use(middleware.CORS())
router.Use(middleware.RequestLogger())
router.Use(middleware.RateLimiter())
// Health check (no auth required)
router.GET("/health", func(c *gin.Context) {
c.JSON(200, gin.H{
"status": "healthy",
"service": "billing-service",
"version": "1.0.0",
})
})
// Initialize services
subscriptionService := services.NewSubscriptionService(db)
// Create Stripe service (mock or real depending on config)
var stripeService *services.StripeService
if cfg.IsMockMode() {
log.Println("Starting in MOCK MODE - Stripe API calls will be simulated")
stripeService = services.NewMockStripeService(
cfg.BillingSuccessURL,
cfg.BillingCancelURL,
cfg.TrialPeriodDays,
subscriptionService,
)
} else {
stripeService = services.NewStripeService(
cfg.StripeSecretKey,
cfg.StripeWebhookSecret,
cfg.BillingSuccessURL,
cfg.BillingCancelURL,
cfg.TrialPeriodDays,
subscriptionService,
)
}
entitlementService := services.NewEntitlementService(db, subscriptionService)
usageService := services.NewUsageService(db, entitlementService)
// Initialize handlers
billingHandler := handlers.NewBillingHandler(
db,
subscriptionService,
stripeService,
entitlementService,
usageService,
)
webhookHandler := handlers.NewWebhookHandler(
db,
cfg.StripeWebhookSecret,
subscriptionService,
entitlementService,
)
// API v1 routes
v1 := router.Group("/api/v1/billing")
{
// Stripe webhook (no auth - uses Stripe signature)
v1.POST("/webhook", webhookHandler.HandleStripeWebhook)
// =============================================
// User Endpoints (require JWT auth)
// =============================================
user := v1.Group("")
user.Use(middleware.AuthMiddleware(cfg.JWTSecret))
{
// Subscription status and management
user.GET("/status", billingHandler.GetBillingStatus)
user.GET("/plans", billingHandler.GetPlans)
user.POST("/trial/start", billingHandler.StartTrial)
user.POST("/change-plan", billingHandler.ChangePlan)
user.POST("/cancel", billingHandler.CancelSubscription)
user.GET("/portal", billingHandler.GetCustomerPortal)
}
// =============================================
// Internal Endpoints (service-to-service)
// =============================================
internal := v1.Group("")
internal.Use(middleware.InternalAPIKeyMiddleware(cfg.InternalAPIKey))
{
// Entitlements
internal.GET("/entitlements/:userId", billingHandler.GetEntitlements)
internal.GET("/entitlements/check/:userId/:feature", billingHandler.CheckEntitlement)
// Usage tracking
internal.POST("/usage/track", billingHandler.TrackUsage)
internal.GET("/usage/check/:userId/:type", billingHandler.CheckUsage)
}
}
// Start server
port := cfg.Port
if port == "" {
port = "8083"
}
log.Printf("Starting Billing Service on port %s", port)
if err := router.Run(":" + port); err != nil {
log.Fatalf("Failed to start server: %v", err)
}
}
+49
View File
@@ -0,0 +1,49 @@
module github.com/breakpilot/billing-service
go 1.23.0
require (
github.com/gin-gonic/gin v1.11.0
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.7.6
github.com/joho/godotenv v1.5.1
github.com/stripe/stripe-go/v76 v76.25.0
)
require (
github.com/bytedance/sonic v1.14.0 // indirect
github.com/bytedance/sonic/loader v0.3.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
github.com/gin-contrib/sse v1.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.27.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/goccy/go-yaml v1.18.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/quic-go/qpack v0.5.1 // indirect
github.com/quic-go/quic-go v0.54.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.3.0 // indirect
go.uber.org/mock v0.5.0 // indirect
golang.org/x/arch v0.20.0 // indirect
golang.org/x/crypto v0.40.0 // indirect
golang.org/x/mod v0.25.0 // indirect
golang.org/x/net v0.42.0 // indirect
golang.org/x/sync v0.16.0 // indirect
golang.org/x/sys v0.35.0 // indirect
golang.org/x/text v0.27.0 // indirect
golang.org/x/tools v0.34.0 // indirect
google.golang.org/protobuf v1.36.9 // indirect
)
+111
View File
@@ -0,0 +1,111 @@
github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ=
github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA=
github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA=
github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI=
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
github.com/gin-gonic/gin v1.11.0 h1:OW/6PLjyusp2PPXtyxKHU0RbX6I/l28FTdDlae5ueWk=
github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.27.0 h1:w8+XrWVMhGkxOaaowyKH35gFydVHOvC0/uWoy2Fzwn4=
github.com/go-playground/validator/v10 v10.27.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk=
github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
github.com/quic-go/quic-go v0.54.0 h1:6s1YB9QotYI6Ospeiguknbp2Znb/jZYjZLRXn9kMQBg=
github.com/quic-go/quic-go v0.54.0/go.mod h1:e68ZEaCdyviluZmy44P6Iey98v/Wfz6HCjQEm+l8zTY=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/stripe/stripe-go/v76 v76.25.0 h1:kmDoOTvdQSTQssQzWZQQkgbAR2Q8eXdMWbN/ylNalWA=
github.com/stripe/stripe-go/v76 v76.25.0/go.mod h1:rw1MxjlAKKcZ+3FOXgTHgwiOa2ya6CPq6ykpJ0Q6Po4=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.3.0 h1:Qd2W2sQawAfG8XSvzwhBeoGq71zXOC/Q1E9y/wUcsUA=
github.com/ugorji/go/codec v1.3.0/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c=
golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk=
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=
google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw=
google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+157
View File
@@ -0,0 +1,157 @@
package config
import (
"fmt"
"os"
"github.com/joho/godotenv"
)
// Config holds all configuration for the billing service
type Config struct {
// Server
Port string
Environment string
// Database
DatabaseURL string
// JWT (shared with consent-service)
JWTSecret string
// Stripe
StripeSecretKey string
StripeWebhookSecret string
StripePublishableKey string
StripeMockMode bool // If true, Stripe calls are mocked (for dev without Stripe keys)
// URLs
BillingSuccessURL string
BillingCancelURL string
FrontendURL string
// Trial
TrialPeriodDays int
// CORS
AllowedOrigins []string
// Rate Limiting
RateLimitRequests int
RateLimitWindow int // in seconds
// Internal API Key (for service-to-service communication)
InternalAPIKey string
}
// Load loads configuration from environment variables
func Load() (*Config, error) {
// Load .env file if exists (for development)
_ = godotenv.Load()
cfg := &Config{
Port: getEnv("PORT", "8083"),
Environment: getEnv("ENVIRONMENT", "development"),
DatabaseURL: getEnv("DATABASE_URL", ""),
JWTSecret: getEnv("JWT_SECRET", ""),
// Stripe
StripeSecretKey: getEnv("STRIPE_SECRET_KEY", ""),
StripeWebhookSecret: getEnv("STRIPE_WEBHOOK_SECRET", ""),
StripePublishableKey: getEnv("STRIPE_PUBLISHABLE_KEY", ""),
StripeMockMode: getEnvBool("STRIPE_MOCK_MODE", false),
// URLs
BillingSuccessURL: getEnv("BILLING_SUCCESS_URL", "http://localhost:8000/app/billing/success"),
BillingCancelURL: getEnv("BILLING_CANCEL_URL", "http://localhost:8000/app/billing/cancel"),
FrontendURL: getEnv("FRONTEND_URL", "http://localhost:8000"),
// Trial
TrialPeriodDays: getEnvInt("TRIAL_PERIOD_DAYS", 7),
// Rate Limiting
RateLimitRequests: getEnvInt("RATE_LIMIT_REQUESTS", 100),
RateLimitWindow: getEnvInt("RATE_LIMIT_WINDOW", 60),
// Internal API
InternalAPIKey: getEnv("INTERNAL_API_KEY", ""),
}
// Parse allowed origins
originsStr := getEnv("ALLOWED_ORIGINS", "http://localhost:3000,http://localhost:8000")
cfg.AllowedOrigins = parseCommaSeparated(originsStr)
// Validate required fields
if cfg.DatabaseURL == "" {
return nil, fmt.Errorf("DATABASE_URL is required")
}
if cfg.JWTSecret == "" {
return nil, fmt.Errorf("JWT_SECRET is required")
}
// Stripe key is required unless mock mode is enabled
if cfg.StripeSecretKey == "" && !cfg.StripeMockMode {
// In development mode, auto-enable mock mode if no Stripe key
if cfg.Environment == "development" {
cfg.StripeMockMode = true
} else {
return nil, fmt.Errorf("STRIPE_SECRET_KEY is required (set STRIPE_MOCK_MODE=true to bypass in dev)")
}
}
return cfg, nil
}
// IsMockMode returns true if Stripe should be mocked
func (c *Config) IsMockMode() bool {
return c.StripeMockMode
}
func getEnv(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
func getEnvInt(key string, defaultValue int) int {
if value := os.Getenv(key); value != "" {
var result int
fmt.Sscanf(value, "%d", &result)
return result
}
return defaultValue
}
func getEnvBool(key string, defaultValue bool) bool {
if value := os.Getenv(key); value != "" {
return value == "true" || value == "1" || value == "yes"
}
return defaultValue
}
func parseCommaSeparated(s string) []string {
if s == "" {
return []string{}
}
var result []string
start := 0
for i := 0; i <= len(s); i++ {
if i == len(s) || s[i] == ',' {
item := s[start:i]
// Trim whitespace
for len(item) > 0 && item[0] == ' ' {
item = item[1:]
}
for len(item) > 0 && item[len(item)-1] == ' ' {
item = item[:len(item)-1]
}
if item != "" {
result = append(result, item)
}
start = i + 1
}
}
return result
}
@@ -0,0 +1,260 @@
package database
import (
"context"
"fmt"
"time"
"github.com/jackc/pgx/v5/pgxpool"
)
// DB wraps the pgx pool
type DB struct {
Pool *pgxpool.Pool
}
// Connect establishes a connection to the PostgreSQL database
func Connect(databaseURL string) (*DB, error) {
config, err := pgxpool.ParseConfig(databaseURL)
if err != nil {
return nil, fmt.Errorf("failed to parse database URL: %w", err)
}
// Configure connection pool
config.MaxConns = 15
config.MinConns = 3
config.MaxConnLifetime = time.Hour
config.MaxConnIdleTime = 30 * time.Minute
config.HealthCheckPeriod = time.Minute
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
pool, err := pgxpool.NewWithConfig(ctx, config)
if err != nil {
return nil, fmt.Errorf("failed to create connection pool: %w", err)
}
// Test the connection
if err := pool.Ping(ctx); err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
}
return &DB{Pool: pool}, nil
}
// Close closes the database connection pool
func (db *DB) Close() {
db.Pool.Close()
}
// Migrate runs database migrations for the billing service
func Migrate(db *DB) error {
ctx := context.Background()
migrations := []string{
// =============================================
// Billing Service Tables
// =============================================
// Subscriptions - core subscription data
`CREATE TABLE IF NOT EXISTS subscriptions (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL,
stripe_customer_id VARCHAR(255),
stripe_subscription_id VARCHAR(255) UNIQUE,
plan_id VARCHAR(50) NOT NULL,
status VARCHAR(30) NOT NULL DEFAULT 'trialing',
trial_end TIMESTAMPTZ,
current_period_start TIMESTAMPTZ,
current_period_end TIMESTAMPTZ,
cancel_at_period_end BOOLEAN DEFAULT FALSE,
canceled_at TIMESTAMPTZ,
ended_at TIMESTAMPTZ,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW(),
UNIQUE(user_id)
)`,
// Billing Plans - cached from Stripe
`CREATE TABLE IF NOT EXISTS billing_plans (
id VARCHAR(50) PRIMARY KEY,
stripe_price_id VARCHAR(255) UNIQUE,
stripe_product_id VARCHAR(255),
name VARCHAR(100) NOT NULL,
description TEXT,
price_cents INT NOT NULL,
currency VARCHAR(3) DEFAULT 'eur',
interval VARCHAR(10) DEFAULT 'month',
features JSONB DEFAULT '{}',
is_active BOOLEAN DEFAULT TRUE,
sort_order INT DEFAULT 0,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW()
)`,
// Usage Summary - aggregated usage per period
`CREATE TABLE IF NOT EXISTS usage_summary (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL,
usage_type VARCHAR(50) NOT NULL,
period_start TIMESTAMPTZ NOT NULL,
total_count INT DEFAULT 0,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW(),
UNIQUE(user_id, usage_type, period_start)
)`,
// User Entitlements - cached entitlements for fast lookups
`CREATE TABLE IF NOT EXISTS user_entitlements (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL UNIQUE,
plan_id VARCHAR(50) NOT NULL,
ai_requests_limit INT DEFAULT 0,
ai_requests_used INT DEFAULT 0,
documents_limit INT DEFAULT 0,
documents_used INT DEFAULT 0,
features JSONB DEFAULT '{}',
period_start TIMESTAMPTZ,
period_end TIMESTAMPTZ,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW()
)`,
// Stripe Webhook Events - for idempotency
`CREATE TABLE IF NOT EXISTS stripe_webhook_events (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
stripe_event_id VARCHAR(255) UNIQUE NOT NULL,
event_type VARCHAR(100) NOT NULL,
processed BOOLEAN DEFAULT FALSE,
processed_at TIMESTAMPTZ,
payload JSONB,
error_message TEXT,
created_at TIMESTAMPTZ DEFAULT NOW()
)`,
// Billing Audit Log
`CREATE TABLE IF NOT EXISTS billing_audit_log (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID,
action VARCHAR(50) NOT NULL,
entity_type VARCHAR(50),
entity_id VARCHAR(255),
old_value JSONB,
new_value JSONB,
metadata JSONB,
ip_address INET,
user_agent TEXT,
created_at TIMESTAMPTZ DEFAULT NOW()
)`,
// Invoices - cached from Stripe
`CREATE TABLE IF NOT EXISTS invoices (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL,
stripe_invoice_id VARCHAR(255) UNIQUE NOT NULL,
stripe_subscription_id VARCHAR(255),
status VARCHAR(30) NOT NULL,
amount_due INT NOT NULL,
amount_paid INT DEFAULT 0,
currency VARCHAR(3) DEFAULT 'eur',
hosted_invoice_url TEXT,
invoice_pdf TEXT,
period_start TIMESTAMPTZ,
period_end TIMESTAMPTZ,
due_date TIMESTAMPTZ,
paid_at TIMESTAMPTZ,
created_at TIMESTAMPTZ DEFAULT NOW()
)`,
// =============================================
// Task-based Billing Tables
// =============================================
// Account Usage - tracks task balance per account
`CREATE TABLE IF NOT EXISTS account_usage (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
account_id UUID NOT NULL UNIQUE,
plan VARCHAR(50) NOT NULL,
monthly_task_allowance INT NOT NULL,
carryover_months_cap INT DEFAULT 5,
max_task_balance INT NOT NULL,
task_balance INT NOT NULL,
last_renewal_at TIMESTAMPTZ NOT NULL,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW()
)`,
// Tasks - individual task consumption records
`CREATE TABLE IF NOT EXISTS tasks (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
account_id UUID NOT NULL,
task_type VARCHAR(50) NOT NULL,
consumed BOOLEAN DEFAULT TRUE,
page_count INT DEFAULT 0,
token_count INT DEFAULT 0,
process_time INT DEFAULT 0,
created_at TIMESTAMPTZ DEFAULT NOW()
)`,
// =============================================
// Indexes
// =============================================
`CREATE INDEX IF NOT EXISTS idx_subscriptions_user ON subscriptions(user_id)`,
`CREATE INDEX IF NOT EXISTS idx_subscriptions_stripe_customer ON subscriptions(stripe_customer_id)`,
`CREATE INDEX IF NOT EXISTS idx_subscriptions_stripe_sub ON subscriptions(stripe_subscription_id)`,
`CREATE INDEX IF NOT EXISTS idx_subscriptions_status ON subscriptions(status)`,
`CREATE INDEX IF NOT EXISTS idx_subscriptions_trial_end ON subscriptions(trial_end)`,
`CREATE INDEX IF NOT EXISTS idx_usage_summary_user ON usage_summary(user_id)`,
`CREATE INDEX IF NOT EXISTS idx_usage_summary_period ON usage_summary(period_start)`,
`CREATE INDEX IF NOT EXISTS idx_usage_summary_type ON usage_summary(usage_type)`,
`CREATE INDEX IF NOT EXISTS idx_user_entitlements_user ON user_entitlements(user_id)`,
`CREATE INDEX IF NOT EXISTS idx_user_entitlements_plan ON user_entitlements(plan_id)`,
`CREATE INDEX IF NOT EXISTS idx_stripe_webhook_events_event_id ON stripe_webhook_events(stripe_event_id)`,
`CREATE INDEX IF NOT EXISTS idx_stripe_webhook_events_type ON stripe_webhook_events(event_type)`,
`CREATE INDEX IF NOT EXISTS idx_stripe_webhook_events_processed ON stripe_webhook_events(processed)`,
`CREATE INDEX IF NOT EXISTS idx_billing_audit_log_user ON billing_audit_log(user_id)`,
`CREATE INDEX IF NOT EXISTS idx_billing_audit_log_action ON billing_audit_log(action)`,
`CREATE INDEX IF NOT EXISTS idx_billing_audit_log_created ON billing_audit_log(created_at)`,
`CREATE INDEX IF NOT EXISTS idx_invoices_user ON invoices(user_id)`,
`CREATE INDEX IF NOT EXISTS idx_invoices_stripe_invoice ON invoices(stripe_invoice_id)`,
`CREATE INDEX IF NOT EXISTS idx_invoices_status ON invoices(status)`,
`CREATE INDEX IF NOT EXISTS idx_account_usage_account ON account_usage(account_id)`,
`CREATE INDEX IF NOT EXISTS idx_account_usage_plan ON account_usage(plan)`,
`CREATE INDEX IF NOT EXISTS idx_account_usage_renewal ON account_usage(last_renewal_at)`,
`CREATE INDEX IF NOT EXISTS idx_tasks_account ON tasks(account_id)`,
`CREATE INDEX IF NOT EXISTS idx_tasks_type ON tasks(task_type)`,
`CREATE INDEX IF NOT EXISTS idx_tasks_created ON tasks(created_at)`,
// =============================================
// Insert default plans
// =============================================
`INSERT INTO billing_plans (id, name, description, price_cents, currency, interval, features, sort_order)
VALUES
('basic', 'Basic', 'Perfekt für den Einstieg', 990, 'eur', 'month',
'{"ai_requests_limit": 300, "documents_limit": 50, "feature_flags": ["basic_ai", "basic_documents"], "max_team_members": 1, "priority_support": false, "custom_branding": false}',
1),
('standard', 'Standard', 'Für regelmäßige Nutzer', 1990, 'eur', 'month',
'{"ai_requests_limit": 1500, "documents_limit": 200, "feature_flags": ["basic_ai", "basic_documents", "templates", "batch_processing"], "max_team_members": 3, "priority_support": false, "custom_branding": false}',
2),
('premium', 'Premium', 'Für Teams und Power-User', 3990, 'eur', 'month',
'{"ai_requests_limit": 5000, "documents_limit": 1000, "feature_flags": ["basic_ai", "basic_documents", "templates", "batch_processing", "team_features", "admin_panel", "audit_log", "api_access"], "max_team_members": 10, "priority_support": true, "custom_branding": true}',
3)
ON CONFLICT (id) DO NOTHING`,
}
for _, migration := range migrations {
if _, err := db.Pool.Exec(ctx, migration); err != nil {
return fmt.Errorf("failed to run migration: %w", err)
}
}
return nil
}
@@ -0,0 +1,427 @@
package handlers
import (
"net/http"
"github.com/breakpilot/billing-service/internal/database"
"github.com/breakpilot/billing-service/internal/middleware"
"github.com/breakpilot/billing-service/internal/models"
"github.com/breakpilot/billing-service/internal/services"
"github.com/gin-gonic/gin"
)
// BillingHandler handles billing-related HTTP requests
type BillingHandler struct {
db *database.DB
subscriptionService *services.SubscriptionService
stripeService *services.StripeService
entitlementService *services.EntitlementService
usageService *services.UsageService
}
// NewBillingHandler creates a new BillingHandler
func NewBillingHandler(
db *database.DB,
subscriptionService *services.SubscriptionService,
stripeService *services.StripeService,
entitlementService *services.EntitlementService,
usageService *services.UsageService,
) *BillingHandler {
return &BillingHandler{
db: db,
subscriptionService: subscriptionService,
stripeService: stripeService,
entitlementService: entitlementService,
usageService: usageService,
}
}
// GetBillingStatus returns the current billing status for a user
// GET /api/v1/billing/status
func (h *BillingHandler) GetBillingStatus(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID.String() == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User not authenticated",
})
return
}
ctx := c.Request.Context()
// Get subscription
subscription, err := h.subscriptionService.GetByUserID(ctx, userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "internal_error",
"message": "Failed to get subscription",
})
return
}
// Get available plans
plans, err := h.subscriptionService.GetAvailablePlans(ctx)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "internal_error",
"message": "Failed to get plans",
})
return
}
response := models.BillingStatusResponse{
HasSubscription: subscription != nil,
AvailablePlans: plans,
}
if subscription != nil {
// Get plan details
plan, _ := h.subscriptionService.GetPlanByID(ctx, string(subscription.PlanID))
subInfo := &models.SubscriptionInfo{
PlanID: subscription.PlanID,
Status: subscription.Status,
IsTrialing: subscription.Status == models.StatusTrialing,
CancelAtPeriodEnd: subscription.CancelAtPeriodEnd,
CurrentPeriodEnd: subscription.CurrentPeriodEnd,
}
if plan != nil {
subInfo.PlanName = plan.Name
subInfo.PriceCents = plan.PriceCents
subInfo.Currency = plan.Currency
}
// Calculate trial days left
if subscription.TrialEnd != nil && subscription.Status == models.StatusTrialing {
// TODO: Calculate days left
}
response.Subscription = subInfo
// Get task usage info (legacy usage tracking - see TaskService for new task-based usage)
// TODO: Replace with TaskService.GetTaskUsageInfo for task-based billing
_, _ = h.usageService.GetUsageSummary(ctx, userID)
// Get entitlements
entitlements, _ := h.entitlementService.GetEntitlements(ctx, userID)
if entitlements != nil {
response.Entitlements = entitlements
}
}
c.JSON(http.StatusOK, response)
}
// GetPlans returns all available billing plans
// GET /api/v1/billing/plans
func (h *BillingHandler) GetPlans(c *gin.Context) {
ctx := c.Request.Context()
plans, err := h.subscriptionService.GetAvailablePlans(ctx)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "internal_error",
"message": "Failed to get plans",
})
return
}
c.JSON(http.StatusOK, gin.H{
"plans": plans,
})
}
// StartTrial starts a trial for the user with a specific plan
// POST /api/v1/billing/trial/start
func (h *BillingHandler) StartTrial(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID.String() == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User not authenticated",
})
return
}
var req models.StartTrialRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_request",
"message": "Invalid request body",
})
return
}
ctx := c.Request.Context()
// Check if user already has a subscription
existing, _ := h.subscriptionService.GetByUserID(ctx, userID)
if existing != nil {
c.JSON(http.StatusConflict, gin.H{
"error": "subscription_exists",
"message": "User already has a subscription",
})
return
}
// Get user email from context
email, _ := c.Get("email")
emailStr, _ := email.(string)
// Create Stripe checkout session
checkoutURL, sessionID, err := h.stripeService.CreateCheckoutSession(ctx, userID, emailStr, req.PlanID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "stripe_error",
"message": "Failed to create checkout session",
"details": err.Error(),
})
return
}
c.JSON(http.StatusOK, models.StartTrialResponse{
CheckoutURL: checkoutURL,
SessionID: sessionID,
})
}
// ChangePlan changes the user's subscription plan
// POST /api/v1/billing/change-plan
func (h *BillingHandler) ChangePlan(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID.String() == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User not authenticated",
})
return
}
var req models.ChangePlanRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_request",
"message": "Invalid request body",
})
return
}
ctx := c.Request.Context()
// Get current subscription
subscription, err := h.subscriptionService.GetByUserID(ctx, userID)
if err != nil || subscription == nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "no_subscription",
"message": "No active subscription found",
})
return
}
// Change plan via Stripe
err = h.stripeService.ChangePlan(ctx, subscription.StripeSubscriptionID, req.NewPlanID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "stripe_error",
"message": "Failed to change plan",
"details": err.Error(),
})
return
}
c.JSON(http.StatusOK, models.ChangePlanResponse{
Success: true,
Message: "Plan changed successfully",
})
}
// CancelSubscription cancels the user's subscription at period end
// POST /api/v1/billing/cancel
func (h *BillingHandler) CancelSubscription(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID.String() == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User not authenticated",
})
return
}
ctx := c.Request.Context()
// Get current subscription
subscription, err := h.subscriptionService.GetByUserID(ctx, userID)
if err != nil || subscription == nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "no_subscription",
"message": "No active subscription found",
})
return
}
// Cancel at period end via Stripe
err = h.stripeService.CancelSubscription(ctx, subscription.StripeSubscriptionID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "stripe_error",
"message": "Failed to cancel subscription",
"details": err.Error(),
})
return
}
c.JSON(http.StatusOK, models.CancelSubscriptionResponse{
Success: true,
Message: "Subscription will be canceled at the end of the billing period",
})
}
// GetCustomerPortal returns a URL to the Stripe customer portal
// GET /api/v1/billing/portal
func (h *BillingHandler) GetCustomerPortal(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID.String() == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User not authenticated",
})
return
}
ctx := c.Request.Context()
// Get current subscription
subscription, err := h.subscriptionService.GetByUserID(ctx, userID)
if err != nil || subscription == nil || subscription.StripeCustomerID == "" {
c.JSON(http.StatusNotFound, gin.H{
"error": "no_subscription",
"message": "No active subscription found",
})
return
}
// Create portal session
portalURL, err := h.stripeService.CreateCustomerPortalSession(ctx, subscription.StripeCustomerID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "stripe_error",
"message": "Failed to create portal session",
"details": err.Error(),
})
return
}
c.JSON(http.StatusOK, models.CustomerPortalResponse{
PortalURL: portalURL,
})
}
// =============================================
// Internal Endpoints (Service-to-Service)
// =============================================
// GetEntitlements returns entitlements for a user (internal)
// GET /api/v1/billing/entitlements/:userId
func (h *BillingHandler) GetEntitlements(c *gin.Context) {
userIDStr := c.Param("userId")
ctx := c.Request.Context()
entitlements, err := h.entitlementService.GetEntitlementsByUserIDString(ctx, userIDStr)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "internal_error",
"message": "Failed to get entitlements",
})
return
}
if entitlements == nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "not_found",
"message": "No entitlements found for user",
})
return
}
c.JSON(http.StatusOK, entitlements)
}
// TrackUsage tracks usage for a user (internal)
// POST /api/v1/billing/usage/track
func (h *BillingHandler) TrackUsage(c *gin.Context) {
var req models.TrackUsageRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_request",
"message": "Invalid request body",
})
return
}
ctx := c.Request.Context()
quantity := req.Quantity
if quantity <= 0 {
quantity = 1
}
err := h.usageService.TrackUsage(ctx, req.UserID, req.UsageType, quantity)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "internal_error",
"message": "Failed to track usage",
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Usage tracked",
})
}
// CheckUsage checks if usage is allowed (internal)
// GET /api/v1/billing/usage/check/:userId/:type
func (h *BillingHandler) CheckUsage(c *gin.Context) {
userIDStr := c.Param("userId")
usageType := c.Param("type")
ctx := c.Request.Context()
response, err := h.usageService.CheckUsageAllowed(ctx, userIDStr, usageType)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "internal_error",
"message": "Failed to check usage",
})
return
}
c.JSON(http.StatusOK, response)
}
// CheckEntitlement checks if a user has a specific entitlement (internal)
// GET /api/v1/billing/entitlements/check/:userId/:feature
func (h *BillingHandler) CheckEntitlement(c *gin.Context) {
userIDStr := c.Param("userId")
feature := c.Param("feature")
ctx := c.Request.Context()
hasEntitlement, planID, err := h.entitlementService.CheckEntitlement(ctx, userIDStr, feature)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "internal_error",
"message": "Failed to check entitlement",
})
return
}
c.JSON(http.StatusOK, models.EntitlementCheckResponse{
HasEntitlement: hasEntitlement,
PlanID: planID,
})
}
@@ -0,0 +1,612 @@
package handlers
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/breakpilot/billing-service/internal/models"
"github.com/gin-gonic/gin"
)
func init() {
// Set Gin to test mode
gin.SetMode(gin.TestMode)
}
func TestGetPlans_ResponseFormat(t *testing.T) {
// Test that GetPlans returns the expected response structure
// Since we don't have a real database connection in unit tests,
// we test the expected structure and format
// Test that default plans are well-formed
plans := models.GetDefaultPlans()
if len(plans) == 0 {
t.Error("Default plans should not be empty")
}
for _, plan := range plans {
// Verify JSON serialization works
data, err := json.Marshal(plan)
if err != nil {
t.Errorf("Failed to marshal plan %s: %v", plan.ID, err)
}
// Verify we can unmarshal back
var decoded models.BillingPlan
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Errorf("Failed to unmarshal plan %s: %v", plan.ID, err)
}
// Verify key fields
if decoded.ID != plan.ID {
t.Errorf("Plan ID mismatch: got %s, expected %s", decoded.ID, plan.ID)
}
}
}
func TestBillingStatusResponse_Structure(t *testing.T) {
// Test the response structure
response := models.BillingStatusResponse{
HasSubscription: true,
Subscription: &models.SubscriptionInfo{
PlanID: models.PlanStandard,
PlanName: "Standard",
Status: models.StatusActive,
IsTrialing: false,
CancelAtPeriodEnd: false,
PriceCents: 1990,
Currency: "eur",
},
TaskUsage: &models.TaskUsageInfo{
TasksAvailable: 85,
MaxTasks: 500,
InfoText: "Aufgaben verfuegbar: 85 von max. 500",
TooltipText: "Aufgaben koennen sich bis zu 5 Monate ansammeln.",
},
Entitlements: &models.EntitlementInfo{
Features: []string{"basic_ai", "basic_documents", "templates", "batch_processing"},
MaxTeamMembers: 3,
PrioritySupport: false,
CustomBranding: false,
BatchProcessing: true,
CustomTemplates: true,
FairUseMode: false,
},
AvailablePlans: models.GetDefaultPlans(),
}
// Test JSON serialization
data, err := json.Marshal(response)
if err != nil {
t.Fatalf("Failed to marshal BillingStatusResponse: %v", err)
}
// Verify it's valid JSON
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
// Check required fields exist
if _, ok := decoded["has_subscription"]; !ok {
t.Error("Response should have 'has_subscription' field")
}
}
func TestStartTrialRequest_Validation(t *testing.T) {
tests := []struct {
name string
request models.StartTrialRequest
wantError bool
}{
{
name: "Valid basic plan",
request: models.StartTrialRequest{PlanID: models.PlanBasic},
wantError: false,
},
{
name: "Valid standard plan",
request: models.StartTrialRequest{PlanID: models.PlanStandard},
wantError: false,
},
{
name: "Valid premium plan",
request: models.StartTrialRequest{PlanID: models.PlanPremium},
wantError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test JSON serialization
data, err := json.Marshal(tt.request)
if err != nil {
t.Fatalf("Failed to marshal request: %v", err)
}
var decoded models.StartTrialRequest
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal request: %v", err)
}
if decoded.PlanID != tt.request.PlanID {
t.Errorf("PlanID mismatch: got %s, expected %s", decoded.PlanID, tt.request.PlanID)
}
})
}
}
func TestChangePlanRequest_Structure(t *testing.T) {
request := models.ChangePlanRequest{
NewPlanID: models.PlanPremium,
}
data, err := json.Marshal(request)
if err != nil {
t.Fatalf("Failed to marshal ChangePlanRequest: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
if _, ok := decoded["new_plan_id"]; !ok {
t.Error("Request should have 'new_plan_id' field")
}
}
func TestStartTrialResponse_Structure(t *testing.T) {
response := models.StartTrialResponse{
CheckoutURL: "https://checkout.stripe.com/c/pay/cs_test_123",
SessionID: "cs_test_123",
}
data, err := json.Marshal(response)
if err != nil {
t.Fatalf("Failed to marshal StartTrialResponse: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
if _, ok := decoded["checkout_url"]; !ok {
t.Error("Response should have 'checkout_url' field")
}
if _, ok := decoded["session_id"]; !ok {
t.Error("Response should have 'session_id' field")
}
}
func TestCancelSubscriptionResponse_Structure(t *testing.T) {
response := models.CancelSubscriptionResponse{
Success: true,
Message: "Subscription will be canceled at the end of the billing period",
CancelDate: "2025-01-16",
ActiveUntil: "2025-01-16",
}
_, err := json.Marshal(response)
if err != nil {
t.Fatalf("Failed to marshal CancelSubscriptionResponse: %v", err)
}
if !response.Success {
t.Error("Success should be true")
}
}
func TestCustomerPortalResponse_Structure(t *testing.T) {
response := models.CustomerPortalResponse{
PortalURL: "https://billing.stripe.com/p/session/test_123",
}
data, err := json.Marshal(response)
if err != nil {
t.Fatalf("Failed to marshal CustomerPortalResponse: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
if _, ok := decoded["portal_url"]; !ok {
t.Error("Response should have 'portal_url' field")
}
}
func TestEntitlementCheckResponse_Structure(t *testing.T) {
tests := []struct {
name string
response models.EntitlementCheckResponse
}{
{
name: "Has entitlement",
response: models.EntitlementCheckResponse{
HasEntitlement: true,
PlanID: models.PlanStandard,
},
},
{
name: "No entitlement",
response: models.EntitlementCheckResponse{
HasEntitlement: false,
PlanID: models.PlanBasic,
Message: "Feature not available in this plan",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.response)
if err != nil {
t.Fatalf("Failed to marshal EntitlementCheckResponse: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
if _, ok := decoded["has_entitlement"]; !ok {
t.Error("Response should have 'has_entitlement' field")
}
})
}
}
func TestTrackUsageRequest_Validation(t *testing.T) {
tests := []struct {
name string
request models.TrackUsageRequest
valid bool
}{
{
name: "Valid AI request",
request: models.TrackUsageRequest{
UserID: "550e8400-e29b-41d4-a716-446655440000",
UsageType: "ai_request",
Quantity: 1,
},
valid: true,
},
{
name: "Valid document created",
request: models.TrackUsageRequest{
UserID: "550e8400-e29b-41d4-a716-446655440000",
UsageType: "document_created",
Quantity: 1,
},
valid: true,
},
{
name: "Multiple quantity",
request: models.TrackUsageRequest{
UserID: "550e8400-e29b-41d4-a716-446655440000",
UsageType: "ai_request",
Quantity: 5,
},
valid: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.request)
if err != nil {
t.Fatalf("Failed to marshal TrackUsageRequest: %v", err)
}
var decoded models.TrackUsageRequest
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal TrackUsageRequest: %v", err)
}
if decoded.UserID != tt.request.UserID {
t.Errorf("UserID mismatch: got %s, expected %s", decoded.UserID, tt.request.UserID)
}
})
}
}
func TestCheckUsageResponse_Format(t *testing.T) {
tests := []struct {
name string
response models.CheckUsageResponse
}{
{
name: "Allowed response",
response: models.CheckUsageResponse{
Allowed: true,
CurrentUsage: 450,
Limit: 1500,
Remaining: 1050,
},
},
{
name: "Limit reached",
response: models.CheckUsageResponse{
Allowed: false,
CurrentUsage: 1500,
Limit: 1500,
Remaining: 0,
Message: "Usage limit reached for ai_request (1500/1500)",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.response)
if err != nil {
t.Fatalf("Failed to marshal CheckUsageResponse: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
if _, ok := decoded["allowed"]; !ok {
t.Error("Response should have 'allowed' field")
}
})
}
}
func TestConsumeTaskRequest_Format(t *testing.T) {
tests := []struct {
name string
request models.ConsumeTaskRequest
}{
{
name: "Correction task",
request: models.ConsumeTaskRequest{
UserID: "550e8400-e29b-41d4-a716-446655440000",
TaskType: models.TaskTypeCorrection,
},
},
{
name: "Letter task",
request: models.ConsumeTaskRequest{
UserID: "550e8400-e29b-41d4-a716-446655440000",
TaskType: models.TaskTypeLetter,
},
},
{
name: "Batch task",
request: models.ConsumeTaskRequest{
UserID: "550e8400-e29b-41d4-a716-446655440000",
TaskType: models.TaskTypeBatch,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.request)
if err != nil {
t.Fatalf("Failed to marshal ConsumeTaskRequest: %v", err)
}
var decoded models.ConsumeTaskRequest
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal ConsumeTaskRequest: %v", err)
}
if decoded.TaskType != tt.request.TaskType {
t.Errorf("TaskType mismatch: got %s, expected %s", decoded.TaskType, tt.request.TaskType)
}
})
}
}
func TestConsumeTaskResponse_Format(t *testing.T) {
tests := []struct {
name string
response models.ConsumeTaskResponse
}{
{
name: "Successful consumption",
response: models.ConsumeTaskResponse{
Success: true,
TaskID: "task-uuid-123",
TasksRemaining: 49,
},
},
{
name: "Limit reached",
response: models.ConsumeTaskResponse{
Success: false,
TasksRemaining: 0,
Message: "Dein Aufgaben-Kontingent ist aufgebraucht.",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.response)
if err != nil {
t.Fatalf("Failed to marshal ConsumeTaskResponse: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
if _, ok := decoded["success"]; !ok {
t.Error("Response should have 'success' field")
}
if _, ok := decoded["tasks_remaining"]; !ok {
t.Error("Response should have 'tasks_remaining' field")
}
})
}
}
func TestCheckTaskAllowedResponse_Format(t *testing.T) {
tests := []struct {
name string
response models.CheckTaskAllowedResponse
}{
{
name: "Task allowed",
response: models.CheckTaskAllowedResponse{
Allowed: true,
TasksAvailable: 50,
MaxTasks: 150,
PlanID: models.PlanBasic,
},
},
{
name: "Task not allowed",
response: models.CheckTaskAllowedResponse{
Allowed: false,
TasksAvailable: 0,
MaxTasks: 150,
PlanID: models.PlanBasic,
Message: "Dein Aufgaben-Kontingent ist aufgebraucht.",
},
},
{
name: "Premium Fair Use",
response: models.CheckTaskAllowedResponse{
Allowed: true,
TasksAvailable: 1000,
MaxTasks: 5000,
PlanID: models.PlanPremium,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.response)
if err != nil {
t.Fatalf("Failed to marshal CheckTaskAllowedResponse: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Response is not valid JSON: %v", err)
}
if _, ok := decoded["allowed"]; !ok {
t.Error("Response should have 'allowed' field")
}
if _, ok := decoded["tasks_available"]; !ok {
t.Error("Response should have 'tasks_available' field")
}
if _, ok := decoded["plan_id"]; !ok {
t.Error("Response should have 'plan_id' field")
}
})
}
}
// HTTP Handler Tests (without DB)
func TestHTTPErrorResponse_Format(t *testing.T) {
// Test standard error response format
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Simulate an error response
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User not authenticated",
})
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status 401, got %d", w.Code)
}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if _, ok := response["error"]; !ok {
t.Error("Error response should have 'error' field")
}
if _, ok := response["message"]; !ok {
t.Error("Error response should have 'message' field")
}
}
func TestHTTPSuccessResponse_Format(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Simulate a success response
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Operation completed",
})
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["success"] != true {
t.Error("Success response should have success=true")
}
}
func TestRequestParsing_InvalidJSON(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Create request with invalid JSON
invalidJSON := []byte(`{"plan_id": }`) // Invalid JSON
c.Request = httptest.NewRequest("POST", "/test", bytes.NewReader(invalidJSON))
c.Request.Header.Set("Content-Type", "application/json")
var req models.StartTrialRequest
err := c.ShouldBindJSON(&req)
if err == nil {
t.Error("Should return error for invalid JSON")
}
}
func TestHTTPHeaders_ContentType(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.JSON(http.StatusOK, gin.H{"test": "value"})
contentType := w.Header().Get("Content-Type")
if contentType != "application/json; charset=utf-8" {
t.Errorf("Expected JSON content type, got %s", contentType)
}
}
@@ -0,0 +1,205 @@
package handlers
import (
"io"
"log"
"net/http"
"github.com/breakpilot/billing-service/internal/database"
"github.com/breakpilot/billing-service/internal/services"
"github.com/gin-gonic/gin"
"github.com/stripe/stripe-go/v76/webhook"
)
// WebhookHandler handles Stripe webhook events
type WebhookHandler struct {
db *database.DB
webhookSecret string
subscriptionService *services.SubscriptionService
entitlementService *services.EntitlementService
}
// NewWebhookHandler creates a new WebhookHandler
func NewWebhookHandler(
db *database.DB,
webhookSecret string,
subscriptionService *services.SubscriptionService,
entitlementService *services.EntitlementService,
) *WebhookHandler {
return &WebhookHandler{
db: db,
webhookSecret: webhookSecret,
subscriptionService: subscriptionService,
entitlementService: entitlementService,
}
}
// HandleStripeWebhook handles incoming Stripe webhook events
// POST /api/v1/billing/webhook
func (h *WebhookHandler) HandleStripeWebhook(c *gin.Context) {
// Read the request body
body, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("Webhook: Error reading body: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "cannot read body"})
return
}
// Get the Stripe signature header
sigHeader := c.GetHeader("Stripe-Signature")
if sigHeader == "" {
log.Printf("Webhook: Missing Stripe-Signature header")
c.JSON(http.StatusBadRequest, gin.H{"error": "missing signature"})
return
}
// Verify the webhook signature
event, err := webhook.ConstructEvent(body, sigHeader, h.webhookSecret)
if err != nil {
log.Printf("Webhook: Signature verification failed: %v", err)
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid signature"})
return
}
ctx := c.Request.Context()
// Check if we've already processed this event (idempotency)
processed, err := h.subscriptionService.IsEventProcessed(ctx, event.ID)
if err != nil {
log.Printf("Webhook: Error checking event: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal error"})
return
}
if processed {
log.Printf("Webhook: Event %s already processed", event.ID)
c.JSON(http.StatusOK, gin.H{"status": "already_processed"})
return
}
// Mark event as being processed
if err := h.subscriptionService.MarkEventProcessing(ctx, event.ID, string(event.Type)); err != nil {
log.Printf("Webhook: Error marking event: %v", err)
}
// Handle the event based on type
var handleErr error
switch event.Type {
case "checkout.session.completed":
handleErr = h.handleCheckoutSessionCompleted(ctx, event.Data.Raw)
case "customer.subscription.created":
handleErr = h.handleSubscriptionCreated(ctx, event.Data.Raw)
case "customer.subscription.updated":
handleErr = h.handleSubscriptionUpdated(ctx, event.Data.Raw)
case "customer.subscription.deleted":
handleErr = h.handleSubscriptionDeleted(ctx, event.Data.Raw)
case "invoice.paid":
handleErr = h.handleInvoicePaid(ctx, event.Data.Raw)
case "invoice.payment_failed":
handleErr = h.handleInvoicePaymentFailed(ctx, event.Data.Raw)
case "customer.created":
log.Printf("Webhook: Customer created - %s", event.ID)
default:
log.Printf("Webhook: Unhandled event type: %s", event.Type)
}
if handleErr != nil {
log.Printf("Webhook: Error handling %s: %v", event.Type, handleErr)
// Mark event as failed
h.subscriptionService.MarkEventFailed(ctx, event.ID, handleErr.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler error"})
return
}
// Mark event as processed
if err := h.subscriptionService.MarkEventProcessed(ctx, event.ID); err != nil {
log.Printf("Webhook: Error marking event processed: %v", err)
}
c.JSON(http.StatusOK, gin.H{"status": "processed"})
}
// handleCheckoutSessionCompleted handles successful checkout
func (h *WebhookHandler) handleCheckoutSessionCompleted(ctx interface{}, data []byte) error {
log.Printf("Webhook: Processing checkout.session.completed")
// Parse checkout session from data
// The actual implementation will parse the JSON and create/update subscription
// TODO: Implementation
// 1. Parse checkout session data
// 2. Extract customer_id, subscription_id, user_id (from metadata)
// 3. Create or update subscription record
// 4. Update entitlements
return nil
}
// handleSubscriptionCreated handles new subscription creation
func (h *WebhookHandler) handleSubscriptionCreated(ctx interface{}, data []byte) error {
log.Printf("Webhook: Processing customer.subscription.created")
// TODO: Implementation
// 1. Parse subscription data
// 2. Extract status, plan, trial_end, etc.
// 3. Create subscription record
// 4. Set up initial entitlements
return nil
}
// handleSubscriptionUpdated handles subscription updates
func (h *WebhookHandler) handleSubscriptionUpdated(ctx interface{}, data []byte) error {
log.Printf("Webhook: Processing customer.subscription.updated")
// TODO: Implementation
// 1. Parse subscription data
// 2. Update subscription record (status, plan, cancel_at_period_end, etc.)
// 3. Update entitlements if plan changed
return nil
}
// handleSubscriptionDeleted handles subscription cancellation
func (h *WebhookHandler) handleSubscriptionDeleted(ctx interface{}, data []byte) error {
log.Printf("Webhook: Processing customer.subscription.deleted")
// TODO: Implementation
// 1. Parse subscription data
// 2. Update subscription status to canceled/expired
// 3. Remove or downgrade entitlements
return nil
}
// handleInvoicePaid handles successful invoice payment
func (h *WebhookHandler) handleInvoicePaid(ctx interface{}, data []byte) error {
log.Printf("Webhook: Processing invoice.paid")
// TODO: Implementation
// 1. Parse invoice data
// 2. Update subscription period
// 3. Reset usage counters for new period
// 4. Store invoice record
return nil
}
// handleInvoicePaymentFailed handles failed invoice payment
func (h *WebhookHandler) handleInvoicePaymentFailed(ctx interface{}, data []byte) error {
log.Printf("Webhook: Processing invoice.payment_failed")
// TODO: Implementation
// 1. Parse invoice data
// 2. Update subscription status to past_due
// 3. Send notification to user
// 4. Possibly restrict access
return nil
}
@@ -0,0 +1,433 @@
package handlers
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
// TestWebhookEventTypes tests the event types we handle
func TestWebhookEventTypes(t *testing.T) {
eventTypes := []struct {
eventType string
shouldHandle bool
}{
{"checkout.session.completed", true},
{"customer.subscription.created", true},
{"customer.subscription.updated", true},
{"customer.subscription.deleted", true},
{"invoice.paid", true},
{"invoice.payment_failed", true},
{"customer.created", true}, // Handled but just logged
{"unknown.event.type", false},
}
for _, tt := range eventTypes {
t.Run(tt.eventType, func(t *testing.T) {
if tt.eventType == "" {
t.Error("Event type should not be empty")
}
})
}
}
// TestWebhookRequest_MissingSignature tests handling of missing signature
func TestWebhookRequest_MissingSignature(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Create request without Stripe-Signature header
body := []byte(`{"id": "evt_test_123", "type": "test.event"}`)
c.Request = httptest.NewRequest("POST", "/webhook", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
// Note: No Stripe-Signature header
// Simulate the check we do in the handler
sigHeader := c.GetHeader("Stripe-Signature")
if sigHeader == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing signature"})
}
if w.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for missing signature, got %d", w.Code)
}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["error"] != "missing signature" {
t.Errorf("Expected 'missing signature' error, got '%v'", response["error"])
}
}
// TestWebhookRequest_EmptyBody tests handling of empty request body
func TestWebhookRequest_EmptyBody(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Create request with empty body
c.Request = httptest.NewRequest("POST", "/webhook", bytes.NewReader([]byte{}))
c.Request.Header.Set("Content-Type", "application/json")
c.Request.Header.Set("Stripe-Signature", "t=123,v1=signature")
// Read the body
body := make([]byte, 0)
// Simulate empty body handling
if len(body) == 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "empty body"})
}
if w.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for empty body, got %d", w.Code)
}
}
// TestWebhookIdempotency tests idempotency behavior
func TestWebhookIdempotency(t *testing.T) {
// Test that the same event ID should not be processed twice
eventID := "evt_test_123456789"
// Simulate event tracking
processedEvents := make(map[string]bool)
// First time - should process
if !processedEvents[eventID] {
processedEvents[eventID] = true
}
// Second time - should skip
alreadyProcessed := processedEvents[eventID]
if !alreadyProcessed {
t.Error("Event should be marked as processed")
}
}
// TestWebhookResponse_Processed tests successful webhook response
func TestWebhookResponse_Processed(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.JSON(http.StatusOK, gin.H{"status": "processed"})
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["status"] != "processed" {
t.Errorf("Expected status 'processed', got '%v'", response["status"])
}
}
// TestWebhookResponse_AlreadyProcessed tests idempotent response
func TestWebhookResponse_AlreadyProcessed(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.JSON(http.StatusOK, gin.H{"status": "already_processed"})
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["status"] != "already_processed" {
t.Errorf("Expected status 'already_processed', got '%v'", response["status"])
}
}
// TestWebhookResponse_InternalError tests error response
func TestWebhookResponse_InternalError(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler error"})
if w.Code != http.StatusInternalServerError {
t.Errorf("Expected status 500, got %d", w.Code)
}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["error"] != "handler error" {
t.Errorf("Expected 'handler error', got '%v'", response["error"])
}
}
// TestWebhookResponse_InvalidSignature tests signature verification failure
func TestWebhookResponse_InvalidSignature(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid signature"})
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status 401, got %d", w.Code)
}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["error"] != "invalid signature" {
t.Errorf("Expected 'invalid signature', got '%v'", response["error"])
}
}
// TestCheckoutSessionCompleted_EventStructure tests the event data structure
func TestCheckoutSessionCompleted_EventStructure(t *testing.T) {
// Test the expected structure of a checkout.session.completed event
eventData := map[string]interface{}{
"id": "cs_test_123",
"customer": "cus_test_456",
"subscription": "sub_test_789",
"mode": "subscription",
"payment_status": "paid",
"status": "complete",
"metadata": map[string]interface{}{
"user_id": "550e8400-e29b-41d4-a716-446655440000",
"plan_id": "standard",
},
}
data, err := json.Marshal(eventData)
if err != nil {
t.Fatalf("Failed to marshal event data: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal event data: %v", err)
}
// Verify required fields
if decoded["customer"] == nil {
t.Error("Event should have 'customer' field")
}
if decoded["subscription"] == nil {
t.Error("Event should have 'subscription' field")
}
metadata, ok := decoded["metadata"].(map[string]interface{})
if !ok || metadata["user_id"] == nil {
t.Error("Event should have 'metadata.user_id' field")
}
}
// TestSubscriptionCreated_EventStructure tests subscription.created event structure
func TestSubscriptionCreated_EventStructure(t *testing.T) {
eventData := map[string]interface{}{
"id": "sub_test_123",
"customer": "cus_test_456",
"status": "trialing",
"items": map[string]interface{}{
"data": []map[string]interface{}{
{
"price": map[string]interface{}{
"id": "price_test_789",
"metadata": map[string]interface{}{"plan_id": "standard"},
},
},
},
},
"trial_end": 1735689600,
"current_period_end": 1735689600,
"metadata": map[string]interface{}{
"user_id": "550e8400-e29b-41d4-a716-446655440000",
"plan_id": "standard",
},
}
data, err := json.Marshal(eventData)
if err != nil {
t.Fatalf("Failed to marshal event data: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal event data: %v", err)
}
// Verify required fields
if decoded["status"] != "trialing" {
t.Errorf("Expected status 'trialing', got '%v'", decoded["status"])
}
}
// TestSubscriptionUpdated_StatusTransitions tests subscription status transitions
func TestSubscriptionUpdated_StatusTransitions(t *testing.T) {
validTransitions := []struct {
from string
to string
}{
{"trialing", "active"},
{"active", "past_due"},
{"past_due", "active"},
{"active", "canceled"},
{"trialing", "canceled"},
}
for _, tt := range validTransitions {
t.Run(tt.from+"->"+tt.to, func(t *testing.T) {
if tt.from == "" || tt.to == "" {
t.Error("Status should not be empty")
}
})
}
}
// TestInvoicePaid_EventStructure tests invoice.paid event structure
func TestInvoicePaid_EventStructure(t *testing.T) {
eventData := map[string]interface{}{
"id": "in_test_123",
"subscription": "sub_test_456",
"customer": "cus_test_789",
"status": "paid",
"amount_paid": 1990,
"currency": "eur",
"period_start": 1735689600,
"period_end": 1738368000,
"hosted_invoice_url": "https://invoice.stripe.com/test",
"invoice_pdf": "https://invoice.stripe.com/test.pdf",
}
data, err := json.Marshal(eventData)
if err != nil {
t.Fatalf("Failed to marshal event data: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal event data: %v", err)
}
// Verify required fields
if decoded["status"] != "paid" {
t.Errorf("Expected status 'paid', got '%v'", decoded["status"])
}
if decoded["subscription"] == nil {
t.Error("Event should have 'subscription' field")
}
}
// TestInvoicePaymentFailed_EventStructure tests invoice.payment_failed event structure
func TestInvoicePaymentFailed_EventStructure(t *testing.T) {
eventData := map[string]interface{}{
"id": "in_test_123",
"subscription": "sub_test_456",
"customer": "cus_test_789",
"status": "open",
"attempt_count": 1,
"next_payment_attempt": 1735776000,
}
data, err := json.Marshal(eventData)
if err != nil {
t.Fatalf("Failed to marshal event data: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal event data: %v", err)
}
// Verify fields
if decoded["attempt_count"] == nil {
t.Error("Event should have 'attempt_count' field")
}
}
// TestSubscriptionDeleted_EventStructure tests subscription.deleted event structure
func TestSubscriptionDeleted_EventStructure(t *testing.T) {
eventData := map[string]interface{}{
"id": "sub_test_123",
"customer": "cus_test_456",
"status": "canceled",
"ended_at": 1735689600,
"canceled_at": 1735689600,
}
data, err := json.Marshal(eventData)
if err != nil {
t.Fatalf("Failed to marshal event data: %v", err)
}
var decoded map[string]interface{}
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal event data: %v", err)
}
// Verify required fields
if decoded["status"] != "canceled" {
t.Errorf("Expected status 'canceled', got '%v'", decoded["status"])
}
}
// TestStripeSignatureFormat tests the Stripe signature header format
func TestStripeSignatureFormat(t *testing.T) {
// Stripe signature format: t=timestamp,v1=signature
validSignatures := []string{
"t=1609459200,v1=abc123def456",
"t=1609459200,v1=signature_here,v0=old_signature",
}
for _, sig := range validSignatures {
if len(sig) < 10 {
t.Errorf("Signature seems too short: %s", sig)
}
// Should start with timestamp
if sig[:2] != "t=" {
t.Errorf("Signature should start with 't=': %s", sig)
}
}
}
// TestWebhookEventID_Format tests Stripe event ID format
func TestWebhookEventID_Format(t *testing.T) {
validEventIDs := []string{
"evt_1234567890abcdef",
"evt_test_123456789",
"evt_live_987654321",
}
for _, eventID := range validEventIDs {
// Event IDs should start with "evt_"
if len(eventID) < 10 || eventID[:4] != "evt_" {
t.Errorf("Invalid event ID format: %s", eventID)
}
}
}
@@ -0,0 +1,288 @@
package middleware
import (
"net/http"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
// UserClaims represents the JWT claims for a user
type UserClaims struct {
UserID string `json:"user_id"`
Email string `json:"email"`
Role string `json:"role"`
jwt.RegisteredClaims
}
// CORS returns a CORS middleware
func CORS() gin.HandlerFunc {
return func(c *gin.Context) {
origin := c.Request.Header.Get("Origin")
// Allow localhost for development
allowedOrigins := []string{
"http://localhost:3000",
"http://localhost:8000",
"http://localhost:8080",
"http://localhost:8083",
"https://breakpilot.app",
}
allowed := false
for _, o := range allowedOrigins {
if origin == o {
allowed = true
break
}
}
if allowed {
c.Header("Access-Control-Allow-Origin", origin)
}
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Authorization, X-Requested-With, X-Internal-API-Key")
c.Header("Access-Control-Allow-Credentials", "true")
c.Header("Access-Control-Max-Age", "86400")
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
}
}
// RequestLogger logs each request
func RequestLogger() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
method := c.Request.Method
c.Next()
latency := time.Since(start)
status := c.Writer.Status()
// Log only in development or for errors
if status >= 400 {
gin.DefaultWriter.Write([]byte(
method + " " + path + " " +
string(rune(status)) + " " +
latency.String() + "\n",
))
}
}
}
// RateLimiter implements a simple in-memory rate limiter
func RateLimiter() gin.HandlerFunc {
type client struct {
count int
lastSeen time.Time
}
var (
mu sync.Mutex
clients = make(map[string]*client)
)
// Clean up old entries periodically
go func() {
for {
time.Sleep(time.Minute)
mu.Lock()
for ip, c := range clients {
if time.Since(c.lastSeen) > time.Minute {
delete(clients, ip)
}
}
mu.Unlock()
}
}()
return func(c *gin.Context) {
ip := c.ClientIP()
mu.Lock()
defer mu.Unlock()
if _, exists := clients[ip]; !exists {
clients[ip] = &client{}
}
cli := clients[ip]
// Reset count if more than a minute has passed
if time.Since(cli.lastSeen) > time.Minute {
cli.count = 0
}
cli.count++
cli.lastSeen = time.Now()
// Allow 100 requests per minute
if cli.count > 100 {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
"error": "rate_limit_exceeded",
"message": "Too many requests. Please try again later.",
})
return
}
c.Next()
}
}
// AuthMiddleware validates JWT tokens
func AuthMiddleware(jwtSecret string) gin.HandlerFunc {
return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "missing_authorization",
"message": "Authorization header is required",
})
return
}
// Extract token from "Bearer <token>"
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "invalid_authorization",
"message": "Authorization header must be in format: Bearer <token>",
})
return
}
tokenString := parts[1]
// Parse and validate token
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
return []byte(jwtSecret), nil
})
if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "invalid_token",
"message": "Invalid or expired token",
})
return
}
if claims, ok := token.Claims.(*UserClaims); ok && token.Valid {
// Set user info in context
c.Set("user_id", claims.UserID)
c.Set("email", claims.Email)
c.Set("role", claims.Role)
c.Next()
} else {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "invalid_claims",
"message": "Invalid token claims",
})
return
}
}
}
// InternalAPIKeyMiddleware validates internal API key for service-to-service communication
func InternalAPIKeyMiddleware(apiKey string) gin.HandlerFunc {
return func(c *gin.Context) {
if apiKey == "" {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
"error": "config_error",
"message": "Internal API key not configured",
})
return
}
providedKey := c.GetHeader("X-Internal-API-Key")
if providedKey == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "missing_api_key",
"message": "X-Internal-API-Key header is required",
})
return
}
if providedKey != apiKey {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "invalid_api_key",
"message": "Invalid API key",
})
return
}
c.Next()
}
}
// AdminOnly ensures only admin users can access the route
func AdminOnly() gin.HandlerFunc {
return func(c *gin.Context) {
role, exists := c.Get("role")
if !exists {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User role not found",
})
return
}
roleStr, ok := role.(string)
if !ok || (roleStr != "admin" && roleStr != "super_admin" && roleStr != "data_protection_officer") {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"error": "forbidden",
"message": "Admin access required",
})
return
}
c.Next()
}
}
// GetUserID extracts the user ID from the context
func GetUserID(c *gin.Context) (uuid.UUID, error) {
userIDStr, exists := c.Get("user_id")
if !exists {
return uuid.Nil, nil
}
userID, err := uuid.Parse(userIDStr.(string))
if err != nil {
return uuid.Nil, err
}
return userID, nil
}
// GetClientIP returns the client's IP address
func GetClientIP(c *gin.Context) string {
// Check X-Forwarded-For header first (for proxied requests)
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
ips := strings.Split(xff, ",")
return strings.TrimSpace(ips[0])
}
// Check X-Real-IP header
if xri := c.GetHeader("X-Real-IP"); xri != "" {
return xri
}
return c.ClientIP()
}
// GetUserAgent returns the client's User-Agent
func GetUserAgent(c *gin.Context) string {
return c.GetHeader("User-Agent")
}
+372
View File
@@ -0,0 +1,372 @@
package models
import (
"time"
"github.com/google/uuid"
)
// SubscriptionStatus represents the status of a subscription
type SubscriptionStatus string
const (
StatusTrialing SubscriptionStatus = "trialing"
StatusActive SubscriptionStatus = "active"
StatusPastDue SubscriptionStatus = "past_due"
StatusCanceled SubscriptionStatus = "canceled"
StatusExpired SubscriptionStatus = "expired"
)
// PlanID represents the available plan IDs
type PlanID string
const (
PlanBasic PlanID = "basic"
PlanStandard PlanID = "standard"
PlanPremium PlanID = "premium"
)
// TaskType represents the type of task
type TaskType string
const (
TaskTypeCorrection TaskType = "correction"
TaskTypeLetter TaskType = "letter"
TaskTypeMeeting TaskType = "meeting"
TaskTypeBatch TaskType = "batch"
TaskTypeOther TaskType = "other"
)
// CarryoverMonthsCap is the maximum number of months tasks can accumulate
const CarryoverMonthsCap = 5
// Subscription represents a user's subscription
type Subscription struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
StripeCustomerID string `json:"stripe_customer_id"`
StripeSubscriptionID string `json:"stripe_subscription_id"`
PlanID PlanID `json:"plan_id"`
Status SubscriptionStatus `json:"status"`
TrialEnd *time.Time `json:"trial_end,omitempty"`
CurrentPeriodEnd *time.Time `json:"current_period_end,omitempty"`
CancelAtPeriodEnd bool `json:"cancel_at_period_end"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// BillingPlan represents a billing plan with its features and limits
type BillingPlan struct {
ID PlanID `json:"id"`
StripePriceID string `json:"stripe_price_id"`
Name string `json:"name"`
Description string `json:"description"`
PriceCents int `json:"price_cents"` // Price in cents (990 = 9.90 EUR)
Currency string `json:"currency"`
Interval string `json:"interval"` // "month" or "year"
Features PlanFeatures `json:"features"`
IsActive bool `json:"is_active"`
SortOrder int `json:"sort_order"`
}
// PlanFeatures represents the features and limits of a plan
type PlanFeatures struct {
// Task-based limits (primary billing unit)
MonthlyTaskAllowance int `json:"monthly_task_allowance"` // Tasks per month
MaxTaskBalance int `json:"max_task_balance"` // Max accumulated tasks (allowance * CarryoverMonthsCap)
// Legacy fields for backward compatibility (deprecated, use task-based limits)
AIRequestsLimit int `json:"ai_requests_limit,omitempty"`
DocumentsLimit int `json:"documents_limit,omitempty"`
// Feature flags
FeatureFlags []string `json:"feature_flags"`
MaxTeamMembers int `json:"max_team_members,omitempty"`
PrioritySupport bool `json:"priority_support"`
CustomBranding bool `json:"custom_branding"`
BatchProcessing bool `json:"batch_processing"`
CustomTemplates bool `json:"custom_templates"`
// Premium: Fair Use (no visible limit)
FairUseMode bool `json:"fair_use_mode"`
}
// Task represents a single task that consumes 1 unit from the balance
type Task struct {
ID uuid.UUID `json:"id"`
AccountID uuid.UUID `json:"account_id"`
TaskType TaskType `json:"task_type"`
CreatedAt time.Time `json:"created_at"`
Consumed bool `json:"consumed"` // Always true when created
// Internal metrics (not shown to user)
PageCount int `json:"-"`
TokenCount int `json:"-"`
ProcessTime int `json:"-"` // in seconds
}
// AccountUsage represents the task-based usage for an account
type AccountUsage struct {
ID uuid.UUID `json:"id"`
AccountID uuid.UUID `json:"account_id"`
PlanID PlanID `json:"plan"`
MonthlyTaskAllowance int `json:"monthly_task_allowance"`
CarryoverMonthsCap int `json:"carryover_months_cap"` // Always 5
MaxTaskBalance int `json:"max_task_balance"` // allowance * cap
TaskBalance int `json:"task_balance"` // Current available tasks
LastRenewalAt time.Time `json:"last_renewal_at"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// UsageSummary tracks usage for a specific period (internal metrics)
type UsageSummary struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
UsageType string `json:"usage_type"` // "task", "page", "token"
PeriodStart time.Time `json:"period_start"`
TotalCount int `json:"total_count"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// UserEntitlements represents cached entitlements for a user
type UserEntitlements struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
PlanID PlanID `json:"plan_id"`
TaskBalance int `json:"task_balance"`
MaxBalance int `json:"max_balance"`
Features PlanFeatures `json:"features"`
UpdatedAt time.Time `json:"updated_at"`
// Legacy fields for backward compatibility with old entitlement service
AIRequestsLimit int `json:"ai_requests_limit"`
AIRequestsUsed int `json:"ai_requests_used"`
DocumentsLimit int `json:"documents_limit"`
DocumentsUsed int `json:"documents_used"`
}
// StripeWebhookEvent tracks processed webhook events for idempotency
type StripeWebhookEvent struct {
StripeEventID string `json:"stripe_event_id"`
EventType string `json:"event_type"`
Processed bool `json:"processed"`
ProcessedAt time.Time `json:"processed_at"`
CreatedAt time.Time `json:"created_at"`
}
// BillingStatusResponse is the response for the billing status endpoint
type BillingStatusResponse struct {
HasSubscription bool `json:"has_subscription"`
Subscription *SubscriptionInfo `json:"subscription,omitempty"`
TaskUsage *TaskUsageInfo `json:"task_usage,omitempty"`
Entitlements *EntitlementInfo `json:"entitlements,omitempty"`
AvailablePlans []BillingPlan `json:"available_plans,omitempty"`
}
// SubscriptionInfo contains subscription details for the response
type SubscriptionInfo struct {
PlanID PlanID `json:"plan_id"`
PlanName string `json:"plan_name"`
Status SubscriptionStatus `json:"status"`
IsTrialing bool `json:"is_trialing"`
TrialDaysLeft int `json:"trial_days_left,omitempty"`
CurrentPeriodEnd *time.Time `json:"current_period_end,omitempty"`
CancelAtPeriodEnd bool `json:"cancel_at_period_end"`
PriceCents int `json:"price_cents"`
Currency string `json:"currency"`
}
// TaskUsageInfo contains current task usage information
// This is the ONLY usage info shown to users
type TaskUsageInfo struct {
TasksAvailable int `json:"tasks_available"` // Current balance
MaxTasks int `json:"max_tasks"` // Max possible balance
InfoText string `json:"info_text"` // "Aufgaben verfuegbar: X von max. Y"
TooltipText string `json:"tooltip_text"` // "Aufgaben koennen sich bis zu 5 Monate ansammeln."
}
// EntitlementInfo contains feature entitlements
type EntitlementInfo struct {
Features []string `json:"features"`
MaxTeamMembers int `json:"max_team_members,omitempty"`
PrioritySupport bool `json:"priority_support"`
CustomBranding bool `json:"custom_branding"`
BatchProcessing bool `json:"batch_processing"`
CustomTemplates bool `json:"custom_templates"`
FairUseMode bool `json:"fair_use_mode"` // Premium only
}
// StartTrialRequest is the request to start a trial
type StartTrialRequest struct {
PlanID PlanID `json:"plan_id" binding:"required"`
}
// StartTrialResponse is the response after starting a trial
type StartTrialResponse struct {
CheckoutURL string `json:"checkout_url"`
SessionID string `json:"session_id"`
}
// ChangePlanRequest is the request to change plans
type ChangePlanRequest struct {
NewPlanID PlanID `json:"new_plan_id" binding:"required"`
}
// ChangePlanResponse is the response after changing plans
type ChangePlanResponse struct {
Success bool `json:"success"`
Message string `json:"message"`
EffectiveDate string `json:"effective_date,omitempty"`
}
// CancelSubscriptionResponse is the response after canceling
type CancelSubscriptionResponse struct {
Success bool `json:"success"`
Message string `json:"message"`
CancelDate string `json:"cancel_date"`
ActiveUntil string `json:"active_until"`
}
// CustomerPortalResponse contains the portal URL
type CustomerPortalResponse struct {
PortalURL string `json:"portal_url"`
}
// ConsumeTaskRequest is the request to consume a task (internal)
type ConsumeTaskRequest struct {
UserID string `json:"user_id" binding:"required"`
TaskType TaskType `json:"task_type" binding:"required"`
}
// ConsumeTaskResponse is the response after consuming a task
type ConsumeTaskResponse struct {
Success bool `json:"success"`
TaskID string `json:"task_id,omitempty"`
TasksRemaining int `json:"tasks_remaining"`
Message string `json:"message,omitempty"`
}
// CheckTaskAllowedResponse is the response for task limit checks
type CheckTaskAllowedResponse struct {
Allowed bool `json:"allowed"`
TasksAvailable int `json:"tasks_available"`
MaxTasks int `json:"max_tasks"`
PlanID PlanID `json:"plan_id"`
Message string `json:"message,omitempty"`
}
// EntitlementCheckResponse is the response for entitlement checks (internal)
type EntitlementCheckResponse struct {
HasEntitlement bool `json:"has_entitlement"`
PlanID PlanID `json:"plan_id,omitempty"`
Message string `json:"message,omitempty"`
}
// TaskLimitError represents the error when task limit is reached
type TaskLimitError struct {
Error string `json:"error"`
CurrentBalance int `json:"current_balance"`
Plan PlanID `json:"plan"`
}
// UsageInfo represents current usage information (legacy, prefer TaskUsageInfo)
type UsageInfo struct {
AIRequestsUsed int `json:"ai_requests_used"`
AIRequestsLimit int `json:"ai_requests_limit"`
AIRequestsPercent float64 `json:"ai_requests_percent"`
DocumentsUsed int `json:"documents_used"`
DocumentsLimit int `json:"documents_limit"`
DocumentsPercent float64 `json:"documents_percent"`
PeriodStart string `json:"period_start"`
PeriodEnd string `json:"period_end"`
}
// CheckUsageResponse is the response for legacy usage checks
type CheckUsageResponse struct {
Allowed bool `json:"allowed"`
CurrentUsage int `json:"current_usage"`
Limit int `json:"limit"`
Remaining int `json:"remaining"`
Message string `json:"message,omitempty"`
}
// TrackUsageRequest is the request to track usage (internal)
type TrackUsageRequest struct {
UserID string `json:"user_id" binding:"required"`
UsageType string `json:"usage_type" binding:"required"`
Quantity int `json:"quantity"`
}
// GetDefaultPlans returns the default billing plans with task-based limits
func GetDefaultPlans() []BillingPlan {
return []BillingPlan{
{
ID: PlanBasic,
Name: "Basic",
Description: "Perfekt fuer den Einstieg - Gelegentliche Nutzung",
PriceCents: 990, // 9.90 EUR
Currency: "eur",
Interval: "month",
Features: PlanFeatures{
MonthlyTaskAllowance: 30, // 30 tasks/month
MaxTaskBalance: 30 * CarryoverMonthsCap, // 150 max
FeatureFlags: []string{"basic_ai", "basic_documents"},
MaxTeamMembers: 1,
PrioritySupport: false,
CustomBranding: false,
BatchProcessing: false,
CustomTemplates: false,
FairUseMode: false,
},
IsActive: true,
SortOrder: 1,
},
{
ID: PlanStandard,
Name: "Standard",
Description: "Fuer regelmaessige Nutzer - Mehrere Klassen und regelmaessige Korrekturen",
PriceCents: 1990, // 19.90 EUR
Currency: "eur",
Interval: "month",
Features: PlanFeatures{
MonthlyTaskAllowance: 100, // 100 tasks/month
MaxTaskBalance: 100 * CarryoverMonthsCap, // 500 max
FeatureFlags: []string{"basic_ai", "basic_documents", "templates", "batch_processing"},
MaxTeamMembers: 3,
PrioritySupport: false,
CustomBranding: false,
BatchProcessing: true,
CustomTemplates: true,
FairUseMode: false,
},
IsActive: true,
SortOrder: 2,
},
{
ID: PlanPremium,
Name: "Premium",
Description: "Sorglos-Tarif - Vielnutzer, Teams, schulischer Kontext",
PriceCents: 3990, // 39.90 EUR
Currency: "eur",
Interval: "month",
Features: PlanFeatures{
MonthlyTaskAllowance: 1000, // Very high (Fair Use)
MaxTaskBalance: 1000 * CarryoverMonthsCap, // 5000 max (not shown to user)
FeatureFlags: []string{"basic_ai", "basic_documents", "templates", "batch_processing", "team_features", "admin_panel", "audit_log", "api_access"},
MaxTeamMembers: 10,
PrioritySupport: true,
CustomBranding: true,
BatchProcessing: true,
CustomTemplates: true,
FairUseMode: true, // No visible limit
},
IsActive: true,
SortOrder: 3,
},
}
}
// CalculateMaxTaskBalance calculates max task balance from monthly allowance
func CalculateMaxTaskBalance(monthlyAllowance int) int {
return monthlyAllowance * CarryoverMonthsCap
}
@@ -0,0 +1,319 @@
package models
import (
"testing"
)
func TestCarryoverMonthsCap(t *testing.T) {
// Verify the constant is set correctly
if CarryoverMonthsCap != 5 {
t.Errorf("CarryoverMonthsCap should be 5, got %d", CarryoverMonthsCap)
}
}
func TestCalculateMaxTaskBalance(t *testing.T) {
tests := []struct {
name string
monthlyAllowance int
expected int
}{
{"Basic plan", 30, 150},
{"Standard plan", 100, 500},
{"Premium plan", 1000, 5000},
{"Zero allowance", 0, 0},
{"Single task", 1, 5},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := CalculateMaxTaskBalance(tt.monthlyAllowance)
if result != tt.expected {
t.Errorf("CalculateMaxTaskBalance(%d) = %d, expected %d",
tt.monthlyAllowance, result, tt.expected)
}
})
}
}
func TestGetDefaultPlans(t *testing.T) {
plans := GetDefaultPlans()
if len(plans) != 3 {
t.Fatalf("Expected 3 plans, got %d", len(plans))
}
// Test Basic plan
basic := plans[0]
if basic.ID != PlanBasic {
t.Errorf("First plan should be Basic, got %s", basic.ID)
}
if basic.PriceCents != 990 {
t.Errorf("Basic price should be 990 cents, got %d", basic.PriceCents)
}
if basic.Features.MonthlyTaskAllowance != 30 {
t.Errorf("Basic monthly allowance should be 30, got %d", basic.Features.MonthlyTaskAllowance)
}
if basic.Features.MaxTaskBalance != 150 {
t.Errorf("Basic max balance should be 150, got %d", basic.Features.MaxTaskBalance)
}
if basic.Features.FairUseMode {
t.Error("Basic should not have FairUseMode")
}
// Test Standard plan
standard := plans[1]
if standard.ID != PlanStandard {
t.Errorf("Second plan should be Standard, got %s", standard.ID)
}
if standard.PriceCents != 1990 {
t.Errorf("Standard price should be 1990 cents, got %d", standard.PriceCents)
}
if standard.Features.MonthlyTaskAllowance != 100 {
t.Errorf("Standard monthly allowance should be 100, got %d", standard.Features.MonthlyTaskAllowance)
}
if !standard.Features.BatchProcessing {
t.Error("Standard should have BatchProcessing")
}
if !standard.Features.CustomTemplates {
t.Error("Standard should have CustomTemplates")
}
// Test Premium plan
premium := plans[2]
if premium.ID != PlanPremium {
t.Errorf("Third plan should be Premium, got %s", premium.ID)
}
if premium.PriceCents != 3990 {
t.Errorf("Premium price should be 3990 cents, got %d", premium.PriceCents)
}
if !premium.Features.FairUseMode {
t.Error("Premium should have FairUseMode")
}
if !premium.Features.PrioritySupport {
t.Error("Premium should have PrioritySupport")
}
if !premium.Features.CustomBranding {
t.Error("Premium should have CustomBranding")
}
}
func TestPlanIDConstants(t *testing.T) {
if PlanBasic != "basic" {
t.Errorf("PlanBasic should be 'basic', got '%s'", PlanBasic)
}
if PlanStandard != "standard" {
t.Errorf("PlanStandard should be 'standard', got '%s'", PlanStandard)
}
if PlanPremium != "premium" {
t.Errorf("PlanPremium should be 'premium', got '%s'", PlanPremium)
}
}
func TestSubscriptionStatusConstants(t *testing.T) {
statuses := []struct {
status SubscriptionStatus
expected string
}{
{StatusTrialing, "trialing"},
{StatusActive, "active"},
{StatusPastDue, "past_due"},
{StatusCanceled, "canceled"},
{StatusExpired, "expired"},
}
for _, tt := range statuses {
if string(tt.status) != tt.expected {
t.Errorf("Status %s should be '%s'", tt.status, tt.expected)
}
}
}
func TestTaskTypeConstants(t *testing.T) {
types := []struct {
taskType TaskType
expected string
}{
{TaskTypeCorrection, "correction"},
{TaskTypeLetter, "letter"},
{TaskTypeMeeting, "meeting"},
{TaskTypeBatch, "batch"},
{TaskTypeOther, "other"},
}
for _, tt := range types {
if string(tt.taskType) != tt.expected {
t.Errorf("TaskType %s should be '%s'", tt.taskType, tt.expected)
}
}
}
func TestPlanFeatures_CarryoverCalculation(t *testing.T) {
plans := GetDefaultPlans()
for _, plan := range plans {
expectedMax := plan.Features.MonthlyTaskAllowance * CarryoverMonthsCap
if plan.Features.MaxTaskBalance != expectedMax {
t.Errorf("Plan %s: MaxTaskBalance should be %d (allowance * 5), got %d",
plan.ID, expectedMax, plan.Features.MaxTaskBalance)
}
}
}
func TestBillingPlan_AllPlansActive(t *testing.T) {
plans := GetDefaultPlans()
for _, plan := range plans {
if !plan.IsActive {
t.Errorf("Plan %s should be active", plan.ID)
}
}
}
func TestBillingPlan_CurrencyIsEuro(t *testing.T) {
plans := GetDefaultPlans()
for _, plan := range plans {
if plan.Currency != "eur" {
t.Errorf("Plan %s currency should be 'eur', got '%s'", plan.ID, plan.Currency)
}
}
}
func TestBillingPlan_IntervalIsMonth(t *testing.T) {
plans := GetDefaultPlans()
for _, plan := range plans {
if plan.Interval != "month" {
t.Errorf("Plan %s interval should be 'month', got '%s'", plan.ID, plan.Interval)
}
}
}
func TestBillingPlan_SortOrder(t *testing.T) {
plans := GetDefaultPlans()
for i, plan := range plans {
expectedOrder := i + 1
if plan.SortOrder != expectedOrder {
t.Errorf("Plan %s sort order should be %d, got %d",
plan.ID, expectedOrder, plan.SortOrder)
}
}
}
func TestTaskUsageInfo_FormatStrings(t *testing.T) {
usage := TaskUsageInfo{
TasksAvailable: 45,
MaxTasks: 150,
InfoText: "Aufgaben verfuegbar: 45 von max. 150",
TooltipText: "Aufgaben koennen sich bis zu 5 Monate ansammeln.",
}
if usage.TasksAvailable != 45 {
t.Errorf("TasksAvailable should be 45, got %d", usage.TasksAvailable)
}
if usage.MaxTasks != 150 {
t.Errorf("MaxTasks should be 150, got %d", usage.MaxTasks)
}
}
func TestCheckTaskAllowedResponse_Allowed(t *testing.T) {
response := CheckTaskAllowedResponse{
Allowed: true,
TasksAvailable: 50,
MaxTasks: 150,
PlanID: PlanBasic,
}
if !response.Allowed {
t.Error("Response should be allowed")
}
if response.Message != "" {
t.Errorf("Message should be empty for allowed response, got '%s'", response.Message)
}
}
func TestCheckTaskAllowedResponse_NotAllowed(t *testing.T) {
response := CheckTaskAllowedResponse{
Allowed: false,
TasksAvailable: 0,
MaxTasks: 150,
PlanID: PlanBasic,
Message: "Dein Aufgaben-Kontingent ist aufgebraucht.",
}
if response.Allowed {
t.Error("Response should not be allowed")
}
if response.TasksAvailable != 0 {
t.Errorf("TasksAvailable should be 0, got %d", response.TasksAvailable)
}
}
func TestTaskLimitError(t *testing.T) {
err := TaskLimitError{
Error: "TASK_LIMIT_REACHED",
CurrentBalance: 0,
Plan: PlanBasic,
}
if err.Error != "TASK_LIMIT_REACHED" {
t.Errorf("Error should be 'TASK_LIMIT_REACHED', got '%s'", err.Error)
}
if err.CurrentBalance != 0 {
t.Errorf("CurrentBalance should be 0, got %d", err.CurrentBalance)
}
if err.Plan != PlanBasic {
t.Errorf("Plan should be basic, got '%s'", err.Plan)
}
}
func TestConsumeTaskRequest(t *testing.T) {
req := ConsumeTaskRequest{
UserID: "550e8400-e29b-41d4-a716-446655440000",
TaskType: TaskTypeCorrection,
}
if req.UserID == "" {
t.Error("UserID should not be empty")
}
if req.TaskType != TaskTypeCorrection {
t.Errorf("TaskType should be correction, got '%s'", req.TaskType)
}
}
func TestConsumeTaskResponse_Success(t *testing.T) {
resp := ConsumeTaskResponse{
Success: true,
TaskID: "task-123",
TasksRemaining: 49,
}
if !resp.Success {
t.Error("Response should be successful")
}
if resp.TasksRemaining != 49 {
t.Errorf("TasksRemaining should be 49, got %d", resp.TasksRemaining)
}
}
func TestEntitlementInfo_Premium(t *testing.T) {
premium := GetDefaultPlans()[2]
info := EntitlementInfo{
Features: premium.Features.FeatureFlags,
MaxTeamMembers: premium.Features.MaxTeamMembers,
PrioritySupport: premium.Features.PrioritySupport,
CustomBranding: premium.Features.CustomBranding,
BatchProcessing: premium.Features.BatchProcessing,
CustomTemplates: premium.Features.CustomTemplates,
FairUseMode: premium.Features.FairUseMode,
}
if !info.FairUseMode {
t.Error("Premium should have FairUseMode")
}
if info.MaxTeamMembers != 10 {
t.Errorf("Premium MaxTeamMembers should be 10, got %d", info.MaxTeamMembers)
}
}
@@ -0,0 +1,232 @@
package services
import (
"context"
"encoding/json"
"time"
"github.com/breakpilot/billing-service/internal/database"
"github.com/breakpilot/billing-service/internal/models"
"github.com/google/uuid"
)
// EntitlementService handles entitlement-related operations
type EntitlementService struct {
db *database.DB
subService *SubscriptionService
}
// NewEntitlementService creates a new EntitlementService
func NewEntitlementService(db *database.DB, subService *SubscriptionService) *EntitlementService {
return &EntitlementService{
db: db,
subService: subService,
}
}
// GetEntitlements returns the entitlement info for a user
func (s *EntitlementService) GetEntitlements(ctx context.Context, userID uuid.UUID) (*models.EntitlementInfo, error) {
entitlements, err := s.getUserEntitlements(ctx, userID)
if err != nil || entitlements == nil {
return nil, err
}
return &models.EntitlementInfo{
Features: entitlements.Features.FeatureFlags,
MaxTeamMembers: entitlements.Features.MaxTeamMembers,
PrioritySupport: entitlements.Features.PrioritySupport,
CustomBranding: entitlements.Features.CustomBranding,
}, nil
}
// GetEntitlementsByUserIDString returns entitlements by user ID string (for internal API)
func (s *EntitlementService) GetEntitlementsByUserIDString(ctx context.Context, userIDStr string) (*models.UserEntitlements, error) {
userID, err := uuid.Parse(userIDStr)
if err != nil {
return nil, err
}
return s.getUserEntitlements(ctx, userID)
}
// getUserEntitlements retrieves or creates entitlements for a user
func (s *EntitlementService) getUserEntitlements(ctx context.Context, userID uuid.UUID) (*models.UserEntitlements, error) {
query := `
SELECT id, user_id, plan_id, ai_requests_limit, ai_requests_used,
documents_limit, documents_used, features, period_start, period_end,
created_at, updated_at
FROM user_entitlements
WHERE user_id = $1
`
var ent models.UserEntitlements
var featuresJSON []byte
var periodStart, periodEnd *time.Time
err := s.db.Pool.QueryRow(ctx, query, userID).Scan(
&ent.ID, &ent.UserID, &ent.PlanID, &ent.AIRequestsLimit, &ent.AIRequestsUsed,
&ent.DocumentsLimit, &ent.DocumentsUsed, &featuresJSON, &periodStart, &periodEnd,
nil, &ent.UpdatedAt,
)
if err != nil {
if err.Error() == "no rows in result set" {
// Try to create entitlements based on subscription
return s.createEntitlementsFromSubscription(ctx, userID)
}
return nil, err
}
if len(featuresJSON) > 0 {
json.Unmarshal(featuresJSON, &ent.Features)
}
return &ent, nil
}
// createEntitlementsFromSubscription creates entitlements based on user's subscription
func (s *EntitlementService) createEntitlementsFromSubscription(ctx context.Context, userID uuid.UUID) (*models.UserEntitlements, error) {
// Get user's subscription
sub, err := s.subService.GetByUserID(ctx, userID)
if err != nil || sub == nil {
return nil, err
}
// Get plan details
plan, err := s.subService.GetPlanByID(ctx, string(sub.PlanID))
if err != nil || plan == nil {
return nil, err
}
// Create entitlements
return s.CreateEntitlements(ctx, userID, sub.PlanID, plan.Features, sub.CurrentPeriodEnd)
}
// CreateEntitlements creates entitlements for a user
func (s *EntitlementService) CreateEntitlements(ctx context.Context, userID uuid.UUID, planID models.PlanID, features models.PlanFeatures, periodEnd *time.Time) (*models.UserEntitlements, error) {
featuresJSON, _ := json.Marshal(features)
now := time.Now()
periodStart := now
query := `
INSERT INTO user_entitlements (
user_id, plan_id, ai_requests_limit, ai_requests_used,
documents_limit, documents_used, features, period_start, period_end
) VALUES ($1, $2, $3, 0, $4, 0, $5, $6, $7)
ON CONFLICT (user_id) DO UPDATE SET
plan_id = EXCLUDED.plan_id,
ai_requests_limit = EXCLUDED.ai_requests_limit,
documents_limit = EXCLUDED.documents_limit,
features = EXCLUDED.features,
period_start = EXCLUDED.period_start,
period_end = EXCLUDED.period_end,
updated_at = NOW()
RETURNING id, user_id, plan_id, ai_requests_limit, ai_requests_used,
documents_limit, documents_used, updated_at
`
var ent models.UserEntitlements
err := s.db.Pool.QueryRow(ctx, query,
userID, planID, features.AIRequestsLimit, features.DocumentsLimit,
featuresJSON, periodStart, periodEnd,
).Scan(
&ent.ID, &ent.UserID, &ent.PlanID, &ent.AIRequestsLimit, &ent.AIRequestsUsed,
&ent.DocumentsLimit, &ent.DocumentsUsed, &ent.UpdatedAt,
)
if err != nil {
return nil, err
}
ent.Features = features
return &ent, nil
}
// UpdateEntitlements updates entitlements for a user (e.g., on plan change)
func (s *EntitlementService) UpdateEntitlements(ctx context.Context, userID uuid.UUID, planID models.PlanID, features models.PlanFeatures) error {
featuresJSON, _ := json.Marshal(features)
query := `
UPDATE user_entitlements SET
plan_id = $2,
ai_requests_limit = $3,
documents_limit = $4,
features = $5,
updated_at = NOW()
WHERE user_id = $1
`
_, err := s.db.Pool.Exec(ctx, query,
userID, planID, features.AIRequestsLimit, features.DocumentsLimit, featuresJSON,
)
return err
}
// ResetUsageCounters resets usage counters for a new period
func (s *EntitlementService) ResetUsageCounters(ctx context.Context, userID uuid.UUID, newPeriodStart, newPeriodEnd *time.Time) error {
query := `
UPDATE user_entitlements SET
ai_requests_used = 0,
documents_used = 0,
period_start = $2,
period_end = $3,
updated_at = NOW()
WHERE user_id = $1
`
_, err := s.db.Pool.Exec(ctx, query, userID, newPeriodStart, newPeriodEnd)
return err
}
// CheckEntitlement checks if a user has a specific feature entitlement
func (s *EntitlementService) CheckEntitlement(ctx context.Context, userIDStr, feature string) (bool, models.PlanID, error) {
userID, err := uuid.Parse(userIDStr)
if err != nil {
return false, "", err
}
ent, err := s.getUserEntitlements(ctx, userID)
if err != nil || ent == nil {
return false, "", err
}
// Check if feature is in the feature flags
for _, f := range ent.Features.FeatureFlags {
if f == feature {
return true, ent.PlanID, nil
}
}
return false, ent.PlanID, nil
}
// IncrementUsage increments a usage counter
func (s *EntitlementService) IncrementUsage(ctx context.Context, userID uuid.UUID, usageType string, amount int) error {
var column string
switch usageType {
case "ai_request":
column = "ai_requests_used"
case "document_created":
column = "documents_used"
default:
return nil
}
query := `
UPDATE user_entitlements SET
` + column + ` = ` + column + ` + $2,
updated_at = NOW()
WHERE user_id = $1
`
_, err := s.db.Pool.Exec(ctx, query, userID, amount)
return err
}
// DeleteEntitlements removes entitlements for a user (on subscription cancellation)
func (s *EntitlementService) DeleteEntitlements(ctx context.Context, userID uuid.UUID) error {
query := `DELETE FROM user_entitlements WHERE user_id = $1`
_, err := s.db.Pool.Exec(ctx, query, userID)
return err
}
@@ -0,0 +1,317 @@
package services
import (
"context"
"fmt"
"github.com/breakpilot/billing-service/internal/models"
"github.com/google/uuid"
"github.com/stripe/stripe-go/v76"
"github.com/stripe/stripe-go/v76/billingportal/session"
checkoutsession "github.com/stripe/stripe-go/v76/checkout/session"
"github.com/stripe/stripe-go/v76/customer"
"github.com/stripe/stripe-go/v76/price"
"github.com/stripe/stripe-go/v76/product"
"github.com/stripe/stripe-go/v76/subscription"
)
// StripeService handles Stripe API interactions
type StripeService struct {
secretKey string
webhookSecret string
successURL string
cancelURL string
trialPeriodDays int64
subService *SubscriptionService
mockMode bool // If true, don't make real Stripe API calls
}
// NewStripeService creates a new StripeService
func NewStripeService(secretKey, webhookSecret, successURL, cancelURL string, trialPeriodDays int, subService *SubscriptionService) *StripeService {
// Initialize Stripe with the secret key (only if not empty)
if secretKey != "" {
stripe.Key = secretKey
}
return &StripeService{
secretKey: secretKey,
webhookSecret: webhookSecret,
successURL: successURL,
cancelURL: cancelURL,
trialPeriodDays: int64(trialPeriodDays),
subService: subService,
mockMode: false,
}
}
// NewMockStripeService creates a mock StripeService for development
func NewMockStripeService(successURL, cancelURL string, trialPeriodDays int, subService *SubscriptionService) *StripeService {
return &StripeService{
secretKey: "",
webhookSecret: "",
successURL: successURL,
cancelURL: cancelURL,
trialPeriodDays: int64(trialPeriodDays),
subService: subService,
mockMode: true,
}
}
// IsMockMode returns true if running in mock mode
func (s *StripeService) IsMockMode() bool {
return s.mockMode
}
// CreateCheckoutSession creates a Stripe Checkout session for trial start
func (s *StripeService) CreateCheckoutSession(ctx context.Context, userID uuid.UUID, email string, planID models.PlanID) (string, string, error) {
// Mock mode: return a fake URL for development
if s.mockMode {
mockSessionID := fmt.Sprintf("mock_cs_%s", uuid.New().String()[:8])
mockURL := fmt.Sprintf("%s?session_id=%s&mock=true&plan=%s", s.successURL, mockSessionID, planID)
return mockURL, mockSessionID, nil
}
// Get plan details
plan, err := s.subService.GetPlanByID(ctx, string(planID))
if err != nil || plan == nil {
return "", "", fmt.Errorf("plan not found: %s", planID)
}
// Ensure we have a Stripe price ID
if plan.StripePriceID == "" {
// Create product and price in Stripe if not exists
priceID, err := s.ensurePriceExists(ctx, plan)
if err != nil {
return "", "", fmt.Errorf("failed to create stripe price: %w", err)
}
plan.StripePriceID = priceID
}
// Create checkout session parameters
params := &stripe.CheckoutSessionParams{
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
LineItems: []*stripe.CheckoutSessionLineItemParams{
{
Price: stripe.String(plan.StripePriceID),
Quantity: stripe.Int64(1),
},
},
SuccessURL: stripe.String(s.successURL + "?session_id={CHECKOUT_SESSION_ID}"),
CancelURL: stripe.String(s.cancelURL),
SubscriptionData: &stripe.CheckoutSessionSubscriptionDataParams{
TrialPeriodDays: stripe.Int64(s.trialPeriodDays),
Metadata: map[string]string{
"user_id": userID.String(),
"plan_id": string(planID),
},
},
PaymentMethodCollection: stripe.String(string(stripe.CheckoutSessionPaymentMethodCollectionAlways)),
Metadata: map[string]string{
"user_id": userID.String(),
"plan_id": string(planID),
},
}
// Set customer email if provided
if email != "" {
params.CustomerEmail = stripe.String(email)
}
// Create the session
sess, err := checkoutsession.New(params)
if err != nil {
return "", "", fmt.Errorf("failed to create checkout session: %w", err)
}
return sess.URL, sess.ID, nil
}
// ensurePriceExists creates a Stripe product and price if they don't exist
func (s *StripeService) ensurePriceExists(ctx context.Context, plan *models.BillingPlan) (string, error) {
// Create product
productParams := &stripe.ProductParams{
Name: stripe.String(plan.Name),
Description: stripe.String(plan.Description),
Metadata: map[string]string{
"plan_id": string(plan.ID),
},
}
prod, err := product.New(productParams)
if err != nil {
return "", fmt.Errorf("failed to create product: %w", err)
}
// Create price
priceParams := &stripe.PriceParams{
Product: stripe.String(prod.ID),
UnitAmount: stripe.Int64(int64(plan.PriceCents)),
Currency: stripe.String(plan.Currency),
Recurring: &stripe.PriceRecurringParams{
Interval: stripe.String(plan.Interval),
},
Metadata: map[string]string{
"plan_id": string(plan.ID),
},
}
pr, err := price.New(priceParams)
if err != nil {
return "", fmt.Errorf("failed to create price: %w", err)
}
// Update plan with Stripe IDs
if err := s.subService.UpdatePlanStripePriceID(ctx, string(plan.ID), pr.ID, prod.ID); err != nil {
// Log but don't fail
fmt.Printf("Warning: Failed to update plan with Stripe IDs: %v\n", err)
}
return pr.ID, nil
}
// GetOrCreateCustomer gets or creates a Stripe customer for a user
func (s *StripeService) GetOrCreateCustomer(ctx context.Context, email, name string, userID uuid.UUID) (string, error) {
// Search for existing customer
params := &stripe.CustomerSearchParams{
SearchParams: stripe.SearchParams{
Query: fmt.Sprintf("email:'%s'", email),
},
}
iter := customer.Search(params)
for iter.Next() {
cust := iter.Customer()
// Check if this customer belongs to our user
if cust.Metadata["user_id"] == userID.String() {
return cust.ID, nil
}
}
// Create new customer
customerParams := &stripe.CustomerParams{
Email: stripe.String(email),
Name: stripe.String(name),
Metadata: map[string]string{
"user_id": userID.String(),
},
}
cust, err := customer.New(customerParams)
if err != nil {
return "", fmt.Errorf("failed to create customer: %w", err)
}
return cust.ID, nil
}
// ChangePlan changes a subscription to a new plan
func (s *StripeService) ChangePlan(ctx context.Context, stripeSubID string, newPlanID models.PlanID) error {
// Mock mode: just return success
if s.mockMode {
return nil
}
// Get new plan details
plan, err := s.subService.GetPlanByID(ctx, string(newPlanID))
if err != nil || plan == nil {
return fmt.Errorf("plan not found: %s", newPlanID)
}
if plan.StripePriceID == "" {
return fmt.Errorf("plan %s has no Stripe price ID", newPlanID)
}
// Get current subscription
sub, err := subscription.Get(stripeSubID, nil)
if err != nil {
return fmt.Errorf("failed to get subscription: %w", err)
}
// Update subscription with new price
params := &stripe.SubscriptionParams{
Items: []*stripe.SubscriptionItemsParams{
{
ID: stripe.String(sub.Items.Data[0].ID),
Price: stripe.String(plan.StripePriceID),
},
},
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
Metadata: map[string]string{
"plan_id": string(newPlanID),
},
}
_, err = subscription.Update(stripeSubID, params)
if err != nil {
return fmt.Errorf("failed to update subscription: %w", err)
}
return nil
}
// CancelSubscription cancels a subscription at period end
func (s *StripeService) CancelSubscription(ctx context.Context, stripeSubID string) error {
// Mock mode: just return success
if s.mockMode {
return nil
}
params := &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(true),
}
_, err := subscription.Update(stripeSubID, params)
if err != nil {
return fmt.Errorf("failed to cancel subscription: %w", err)
}
return nil
}
// ReactivateSubscription removes the cancel_at_period_end flag
func (s *StripeService) ReactivateSubscription(ctx context.Context, stripeSubID string) error {
// Mock mode: just return success
if s.mockMode {
return nil
}
params := &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(false),
}
_, err := subscription.Update(stripeSubID, params)
if err != nil {
return fmt.Errorf("failed to reactivate subscription: %w", err)
}
return nil
}
// CreateCustomerPortalSession creates a Stripe Customer Portal session
func (s *StripeService) CreateCustomerPortalSession(ctx context.Context, customerID string) (string, error) {
// Mock mode: return a mock URL
if s.mockMode {
return fmt.Sprintf("%s?mock_portal=true", s.successURL), nil
}
params := &stripe.BillingPortalSessionParams{
Customer: stripe.String(customerID),
ReturnURL: stripe.String(s.successURL),
}
sess, err := session.New(params)
if err != nil {
return "", fmt.Errorf("failed to create portal session: %w", err)
}
return sess.URL, nil
}
// GetSubscription retrieves a subscription from Stripe
func (s *StripeService) GetSubscription(ctx context.Context, stripeSubID string) (*stripe.Subscription, error) {
sub, err := subscription.Get(stripeSubID, nil)
if err != nil {
return nil, fmt.Errorf("failed to get subscription: %w", err)
}
return sub, nil
}
@@ -0,0 +1,315 @@
package services
import (
"context"
"encoding/json"
"time"
"github.com/breakpilot/billing-service/internal/database"
"github.com/breakpilot/billing-service/internal/models"
"github.com/google/uuid"
)
// SubscriptionService handles subscription-related operations
type SubscriptionService struct {
db *database.DB
}
// NewSubscriptionService creates a new SubscriptionService
func NewSubscriptionService(db *database.DB) *SubscriptionService {
return &SubscriptionService{db: db}
}
// GetByUserID retrieves a subscription by user ID
func (s *SubscriptionService) GetByUserID(ctx context.Context, userID uuid.UUID) (*models.Subscription, error) {
query := `
SELECT id, user_id, stripe_customer_id, stripe_subscription_id, plan_id,
status, trial_end, current_period_end, cancel_at_period_end,
created_at, updated_at
FROM subscriptions
WHERE user_id = $1
`
var sub models.Subscription
var stripeCustomerID, stripeSubID *string
var trialEnd, periodEnd *time.Time
err := s.db.Pool.QueryRow(ctx, query, userID).Scan(
&sub.ID, &sub.UserID, &stripeCustomerID, &stripeSubID, &sub.PlanID,
&sub.Status, &trialEnd, &periodEnd, &sub.CancelAtPeriodEnd,
&sub.CreatedAt, &sub.UpdatedAt,
)
if err != nil {
if err.Error() == "no rows in result set" {
return nil, nil
}
return nil, err
}
if stripeCustomerID != nil {
sub.StripeCustomerID = *stripeCustomerID
}
if stripeSubID != nil {
sub.StripeSubscriptionID = *stripeSubID
}
sub.TrialEnd = trialEnd
sub.CurrentPeriodEnd = periodEnd
return &sub, nil
}
// GetByStripeSubscriptionID retrieves a subscription by Stripe subscription ID
func (s *SubscriptionService) GetByStripeSubscriptionID(ctx context.Context, stripeSubID string) (*models.Subscription, error) {
query := `
SELECT id, user_id, stripe_customer_id, stripe_subscription_id, plan_id,
status, trial_end, current_period_end, cancel_at_period_end,
created_at, updated_at
FROM subscriptions
WHERE stripe_subscription_id = $1
`
var sub models.Subscription
var stripeCustomerID, subID *string
var trialEnd, periodEnd *time.Time
err := s.db.Pool.QueryRow(ctx, query, stripeSubID).Scan(
&sub.ID, &sub.UserID, &stripeCustomerID, &subID, &sub.PlanID,
&sub.Status, &trialEnd, &periodEnd, &sub.CancelAtPeriodEnd,
&sub.CreatedAt, &sub.UpdatedAt,
)
if err != nil {
if err.Error() == "no rows in result set" {
return nil, nil
}
return nil, err
}
if stripeCustomerID != nil {
sub.StripeCustomerID = *stripeCustomerID
}
if subID != nil {
sub.StripeSubscriptionID = *subID
}
sub.TrialEnd = trialEnd
sub.CurrentPeriodEnd = periodEnd
return &sub, nil
}
// Create creates a new subscription
func (s *SubscriptionService) Create(ctx context.Context, sub *models.Subscription) error {
query := `
INSERT INTO subscriptions (
user_id, stripe_customer_id, stripe_subscription_id, plan_id,
status, trial_end, current_period_end, cancel_at_period_end
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, created_at, updated_at
`
return s.db.Pool.QueryRow(ctx, query,
sub.UserID, sub.StripeCustomerID, sub.StripeSubscriptionID, sub.PlanID,
sub.Status, sub.TrialEnd, sub.CurrentPeriodEnd, sub.CancelAtPeriodEnd,
).Scan(&sub.ID, &sub.CreatedAt, &sub.UpdatedAt)
}
// Update updates an existing subscription
func (s *SubscriptionService) Update(ctx context.Context, sub *models.Subscription) error {
query := `
UPDATE subscriptions SET
stripe_customer_id = $2,
stripe_subscription_id = $3,
plan_id = $4,
status = $5,
trial_end = $6,
current_period_end = $7,
cancel_at_period_end = $8,
updated_at = NOW()
WHERE id = $1
`
_, err := s.db.Pool.Exec(ctx, query,
sub.ID, sub.StripeCustomerID, sub.StripeSubscriptionID, sub.PlanID,
sub.Status, sub.TrialEnd, sub.CurrentPeriodEnd, sub.CancelAtPeriodEnd,
)
return err
}
// UpdateStatus updates the subscription status
func (s *SubscriptionService) UpdateStatus(ctx context.Context, id uuid.UUID, status models.SubscriptionStatus) error {
query := `UPDATE subscriptions SET status = $2, updated_at = NOW() WHERE id = $1`
_, err := s.db.Pool.Exec(ctx, query, id, status)
return err
}
// GetAvailablePlans retrieves all active billing plans
func (s *SubscriptionService) GetAvailablePlans(ctx context.Context) ([]models.BillingPlan, error) {
query := `
SELECT id, stripe_price_id, name, description, price_cents,
currency, interval, features, is_active, sort_order
FROM billing_plans
WHERE is_active = true
ORDER BY sort_order ASC
`
rows, err := s.db.Pool.Query(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
var plans []models.BillingPlan
for rows.Next() {
var plan models.BillingPlan
var stripePriceID *string
var featuresJSON []byte
err := rows.Scan(
&plan.ID, &stripePriceID, &plan.Name, &plan.Description,
&plan.PriceCents, &plan.Currency, &plan.Interval,
&featuresJSON, &plan.IsActive, &plan.SortOrder,
)
if err != nil {
return nil, err
}
if stripePriceID != nil {
plan.StripePriceID = *stripePriceID
}
// Parse features JSON
if len(featuresJSON) > 0 {
json.Unmarshal(featuresJSON, &plan.Features)
}
plans = append(plans, plan)
}
return plans, nil
}
// GetPlanByID retrieves a billing plan by ID
func (s *SubscriptionService) GetPlanByID(ctx context.Context, planID string) (*models.BillingPlan, error) {
query := `
SELECT id, stripe_price_id, name, description, price_cents,
currency, interval, features, is_active, sort_order
FROM billing_plans
WHERE id = $1
`
var plan models.BillingPlan
var stripePriceID *string
var featuresJSON []byte
err := s.db.Pool.QueryRow(ctx, query, planID).Scan(
&plan.ID, &stripePriceID, &plan.Name, &plan.Description,
&plan.PriceCents, &plan.Currency, &plan.Interval,
&featuresJSON, &plan.IsActive, &plan.SortOrder,
)
if err != nil {
if err.Error() == "no rows in result set" {
return nil, nil
}
return nil, err
}
if stripePriceID != nil {
plan.StripePriceID = *stripePriceID
}
if len(featuresJSON) > 0 {
json.Unmarshal(featuresJSON, &plan.Features)
}
return &plan, nil
}
// UpdatePlanStripePriceID updates the Stripe price ID for a plan
func (s *SubscriptionService) UpdatePlanStripePriceID(ctx context.Context, planID, stripePriceID, stripeProductID string) error {
query := `
UPDATE billing_plans
SET stripe_price_id = $2, stripe_product_id = $3, updated_at = NOW()
WHERE id = $1
`
_, err := s.db.Pool.Exec(ctx, query, planID, stripePriceID, stripeProductID)
return err
}
// =============================================
// Webhook Event Tracking (Idempotency)
// =============================================
// IsEventProcessed checks if a webhook event has already been processed
func (s *SubscriptionService) IsEventProcessed(ctx context.Context, eventID string) (bool, error) {
query := `SELECT processed FROM stripe_webhook_events WHERE stripe_event_id = $1`
var processed bool
err := s.db.Pool.QueryRow(ctx, query, eventID).Scan(&processed)
if err != nil {
if err.Error() == "no rows in result set" {
return false, nil
}
return false, err
}
return processed, nil
}
// MarkEventProcessing marks an event as being processed
func (s *SubscriptionService) MarkEventProcessing(ctx context.Context, eventID, eventType string) error {
query := `
INSERT INTO stripe_webhook_events (stripe_event_id, event_type, processed)
VALUES ($1, $2, false)
ON CONFLICT (stripe_event_id) DO NOTHING
`
_, err := s.db.Pool.Exec(ctx, query, eventID, eventType)
return err
}
// MarkEventProcessed marks an event as successfully processed
func (s *SubscriptionService) MarkEventProcessed(ctx context.Context, eventID string) error {
query := `
UPDATE stripe_webhook_events
SET processed = true, processed_at = NOW()
WHERE stripe_event_id = $1
`
_, err := s.db.Pool.Exec(ctx, query, eventID)
return err
}
// MarkEventFailed marks an event as failed with an error message
func (s *SubscriptionService) MarkEventFailed(ctx context.Context, eventID, errorMsg string) error {
query := `
UPDATE stripe_webhook_events
SET processed = false, error_message = $2, processed_at = NOW()
WHERE stripe_event_id = $1
`
_, err := s.db.Pool.Exec(ctx, query, eventID, errorMsg)
return err
}
// =============================================
// Audit Logging
// =============================================
// LogAuditEvent logs a billing audit event
func (s *SubscriptionService) LogAuditEvent(ctx context.Context, userID *uuid.UUID, action, entityType, entityID string, oldValue, newValue, metadata interface{}, ipAddress, userAgent string) error {
oldJSON, _ := json.Marshal(oldValue)
newJSON, _ := json.Marshal(newValue)
metaJSON, _ := json.Marshal(metadata)
query := `
INSERT INTO billing_audit_log (
user_id, action, entity_type, entity_id,
old_value, new_value, metadata, ip_address, user_agent
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`
_, err := s.db.Pool.Exec(ctx, query,
userID, action, entityType, entityID,
oldJSON, newJSON, metaJSON, ipAddress, userAgent,
)
return err
}
@@ -0,0 +1,326 @@
package services
import (
"encoding/json"
"testing"
"github.com/breakpilot/billing-service/internal/models"
)
func TestSubscriptionStatus_Transitions(t *testing.T) {
// Test valid subscription status values
validStatuses := []models.SubscriptionStatus{
models.StatusTrialing,
models.StatusActive,
models.StatusPastDue,
models.StatusCanceled,
models.StatusExpired,
}
for _, status := range validStatuses {
if status == "" {
t.Errorf("Status should not be empty")
}
}
}
func TestPlanID_ValidValues(t *testing.T) {
validPlanIDs := []models.PlanID{
models.PlanBasic,
models.PlanStandard,
models.PlanPremium,
}
expected := []string{"basic", "standard", "premium"}
for i, planID := range validPlanIDs {
if string(planID) != expected[i] {
t.Errorf("PlanID should be '%s', got '%s'", expected[i], planID)
}
}
}
func TestPlanFeatures_JSONSerialization(t *testing.T) {
features := models.PlanFeatures{
MonthlyTaskAllowance: 100,
MaxTaskBalance: 500,
FeatureFlags: []string{"basic_ai", "templates"},
MaxTeamMembers: 3,
PrioritySupport: false,
CustomBranding: false,
BatchProcessing: true,
CustomTemplates: true,
FairUseMode: false,
}
// Test JSON serialization
data, err := json.Marshal(features)
if err != nil {
t.Fatalf("Failed to marshal PlanFeatures: %v", err)
}
// Test JSON deserialization
var decoded models.PlanFeatures
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("Failed to unmarshal PlanFeatures: %v", err)
}
// Verify fields
if decoded.MonthlyTaskAllowance != features.MonthlyTaskAllowance {
t.Errorf("MonthlyTaskAllowance mismatch: got %d, expected %d",
decoded.MonthlyTaskAllowance, features.MonthlyTaskAllowance)
}
if decoded.MaxTaskBalance != features.MaxTaskBalance {
t.Errorf("MaxTaskBalance mismatch: got %d, expected %d",
decoded.MaxTaskBalance, features.MaxTaskBalance)
}
if decoded.BatchProcessing != features.BatchProcessing {
t.Errorf("BatchProcessing mismatch: got %v, expected %v",
decoded.BatchProcessing, features.BatchProcessing)
}
}
func TestBillingPlan_DefaultPlansAreValid(t *testing.T) {
plans := models.GetDefaultPlans()
if len(plans) != 3 {
t.Fatalf("Expected 3 default plans, got %d", len(plans))
}
// Verify all plans have required fields
for _, plan := range plans {
if plan.ID == "" {
t.Errorf("Plan ID should not be empty")
}
if plan.Name == "" {
t.Errorf("Plan '%s' should have a name", plan.ID)
}
if plan.Description == "" {
t.Errorf("Plan '%s' should have a description", plan.ID)
}
if plan.PriceCents <= 0 {
t.Errorf("Plan '%s' should have a positive price, got %d", plan.ID, plan.PriceCents)
}
if plan.Currency != "eur" {
t.Errorf("Plan '%s' currency should be 'eur', got '%s'", plan.ID, plan.Currency)
}
if plan.Interval != "month" {
t.Errorf("Plan '%s' interval should be 'month', got '%s'", plan.ID, plan.Interval)
}
if !plan.IsActive {
t.Errorf("Plan '%s' should be active", plan.ID)
}
if plan.SortOrder <= 0 {
t.Errorf("Plan '%s' should have a positive sort order, got %d", plan.ID, plan.SortOrder)
}
}
}
func TestBillingPlan_TaskAllowanceProgression(t *testing.T) {
plans := models.GetDefaultPlans()
// Basic should have lowest allowance
basic := plans[0]
standard := plans[1]
premium := plans[2]
if basic.Features.MonthlyTaskAllowance >= standard.Features.MonthlyTaskAllowance {
t.Error("Standard plan should have more tasks than Basic")
}
if standard.Features.MonthlyTaskAllowance >= premium.Features.MonthlyTaskAllowance {
t.Error("Premium plan should have more tasks than Standard")
}
}
func TestBillingPlan_PriceProgression(t *testing.T) {
plans := models.GetDefaultPlans()
// Prices should increase with each tier
if plans[0].PriceCents >= plans[1].PriceCents {
t.Error("Standard should cost more than Basic")
}
if plans[1].PriceCents >= plans[2].PriceCents {
t.Error("Premium should cost more than Standard")
}
}
func TestBillingPlan_FairUseModeOnlyForPremium(t *testing.T) {
plans := models.GetDefaultPlans()
for _, plan := range plans {
if plan.ID == models.PlanPremium {
if !plan.Features.FairUseMode {
t.Error("Premium plan should have FairUseMode enabled")
}
} else {
if plan.Features.FairUseMode {
t.Errorf("Plan '%s' should not have FairUseMode enabled", plan.ID)
}
}
}
}
func TestBillingPlan_MaxTaskBalanceCalculation(t *testing.T) {
plans := models.GetDefaultPlans()
for _, plan := range plans {
expected := plan.Features.MonthlyTaskAllowance * models.CarryoverMonthsCap
if plan.Features.MaxTaskBalance != expected {
t.Errorf("Plan '%s' MaxTaskBalance should be %d (allowance * 5), got %d",
plan.ID, expected, plan.Features.MaxTaskBalance)
}
}
}
func TestAuditLogJSON_Marshaling(t *testing.T) {
// Test that audit log values can be properly serialized
oldValue := map[string]interface{}{
"plan_id": "basic",
"status": "active",
}
newValue := map[string]interface{}{
"plan_id": "standard",
"status": "active",
}
metadata := map[string]interface{}{
"reason": "upgrade",
}
// Marshal all values
oldJSON, err := json.Marshal(oldValue)
if err != nil {
t.Fatalf("Failed to marshal oldValue: %v", err)
}
newJSON, err := json.Marshal(newValue)
if err != nil {
t.Fatalf("Failed to marshal newValue: %v", err)
}
metaJSON, err := json.Marshal(metadata)
if err != nil {
t.Fatalf("Failed to marshal metadata: %v", err)
}
// Verify non-empty
if len(oldJSON) == 0 || len(newJSON) == 0 || len(metaJSON) == 0 {
t.Error("JSON outputs should not be empty")
}
}
func TestSubscriptionTrialCalculation(t *testing.T) {
// Test trial days calculation logic
trialDays := 7
if trialDays <= 0 {
t.Error("Trial days should be positive")
}
if trialDays > 30 {
t.Error("Trial days should not exceed 30")
}
}
func TestSubscriptionInfo_TrialingStatus(t *testing.T) {
info := models.SubscriptionInfo{
PlanID: models.PlanBasic,
PlanName: "Basic",
Status: models.StatusTrialing,
IsTrialing: true,
TrialDaysLeft: 5,
CancelAtPeriodEnd: false,
PriceCents: 990,
Currency: "eur",
}
if !info.IsTrialing {
t.Error("Should be trialing")
}
if info.Status != models.StatusTrialing {
t.Errorf("Status should be 'trialing', got '%s'", info.Status)
}
if info.TrialDaysLeft <= 0 {
t.Error("TrialDaysLeft should be positive during trial")
}
}
func TestSubscriptionInfo_ActiveStatus(t *testing.T) {
info := models.SubscriptionInfo{
PlanID: models.PlanStandard,
PlanName: "Standard",
Status: models.StatusActive,
IsTrialing: false,
TrialDaysLeft: 0,
CancelAtPeriodEnd: false,
PriceCents: 1990,
Currency: "eur",
}
if info.IsTrialing {
t.Error("Should not be trialing")
}
if info.Status != models.StatusActive {
t.Errorf("Status should be 'active', got '%s'", info.Status)
}
}
func TestSubscriptionInfo_CanceledStatus(t *testing.T) {
info := models.SubscriptionInfo{
PlanID: models.PlanStandard,
PlanName: "Standard",
Status: models.StatusActive,
IsTrialing: false,
CancelAtPeriodEnd: true, // Scheduled for cancellation
PriceCents: 1990,
Currency: "eur",
}
if !info.CancelAtPeriodEnd {
t.Error("CancelAtPeriodEnd should be true")
}
// Status remains active until period end
if info.Status != models.StatusActive {
t.Errorf("Status should still be 'active', got '%s'", info.Status)
}
}
func TestWebhookEventTypes(t *testing.T) {
// Test common Stripe webhook event types we handle
eventTypes := []string{
"checkout.session.completed",
"customer.subscription.created",
"customer.subscription.updated",
"customer.subscription.deleted",
"invoice.paid",
"invoice.payment_failed",
}
for _, eventType := range eventTypes {
if eventType == "" {
t.Error("Event type should not be empty")
}
}
}
func TestIdempotencyKey_Format(t *testing.T) {
// Test that we can handle Stripe event IDs
sampleEventIDs := []string{
"evt_1234567890abcdef",
"evt_test_abc123xyz789",
"evt_live_real_event_id",
}
for _, eventID := range sampleEventIDs {
if len(eventID) < 10 {
t.Errorf("Event ID '%s' seems too short", eventID)
}
// Stripe event IDs typically start with "evt_"
if eventID[:4] != "evt_" {
t.Errorf("Event ID '%s' should start with 'evt_'", eventID)
}
}
}
@@ -0,0 +1,352 @@
package services
import (
"context"
"errors"
"fmt"
"time"
"github.com/breakpilot/billing-service/internal/database"
"github.com/breakpilot/billing-service/internal/models"
"github.com/google/uuid"
)
var (
// ErrTaskLimitReached is returned when task balance is 0
ErrTaskLimitReached = errors.New("TASK_LIMIT_REACHED")
// ErrNoSubscription is returned when user has no subscription
ErrNoSubscription = errors.New("NO_SUBSCRIPTION")
)
// TaskService handles task consumption and balance management
type TaskService struct {
db *database.DB
subService *SubscriptionService
}
// NewTaskService creates a new TaskService
func NewTaskService(db *database.DB, subService *SubscriptionService) *TaskService {
return &TaskService{
db: db,
subService: subService,
}
}
// GetAccountUsage retrieves or creates account usage for a user
func (s *TaskService) GetAccountUsage(ctx context.Context, userID uuid.UUID) (*models.AccountUsage, error) {
query := `
SELECT id, account_id, plan, monthly_task_allowance, carryover_months_cap,
max_task_balance, task_balance, last_renewal_at, created_at, updated_at
FROM account_usage
WHERE account_id = $1
`
var usage models.AccountUsage
err := s.db.Pool.QueryRow(ctx, query, userID).Scan(
&usage.ID, &usage.AccountID, &usage.PlanID, &usage.MonthlyTaskAllowance,
&usage.CarryoverMonthsCap, &usage.MaxTaskBalance, &usage.TaskBalance,
&usage.LastRenewalAt, &usage.CreatedAt, &usage.UpdatedAt,
)
if err != nil {
if err.Error() == "no rows in result set" {
// Create new account usage based on subscription
return s.createAccountUsage(ctx, userID)
}
return nil, err
}
// Check if month renewal is needed
if err := s.checkAndApplyMonthRenewal(ctx, &usage); err != nil {
return nil, err
}
return &usage, nil
}
// createAccountUsage creates account usage based on user's subscription
func (s *TaskService) createAccountUsage(ctx context.Context, userID uuid.UUID) (*models.AccountUsage, error) {
// Get subscription to determine plan
sub, err := s.subService.GetByUserID(ctx, userID)
if err != nil || sub == nil {
return nil, ErrNoSubscription
}
// Get plan features
plan, err := s.subService.GetPlanByID(ctx, string(sub.PlanID))
if err != nil || plan == nil {
return nil, fmt.Errorf("plan not found: %s", sub.PlanID)
}
now := time.Now()
usage := &models.AccountUsage{
AccountID: userID,
PlanID: sub.PlanID,
MonthlyTaskAllowance: plan.Features.MonthlyTaskAllowance,
CarryoverMonthsCap: models.CarryoverMonthsCap,
MaxTaskBalance: plan.Features.MaxTaskBalance,
TaskBalance: plan.Features.MonthlyTaskAllowance, // Start with one month's worth
LastRenewalAt: now,
}
query := `
INSERT INTO account_usage (
account_id, plan, monthly_task_allowance, carryover_months_cap,
max_task_balance, task_balance, last_renewal_at
) VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id, created_at, updated_at
`
err = s.db.Pool.QueryRow(ctx, query,
usage.AccountID, usage.PlanID, usage.MonthlyTaskAllowance,
usage.CarryoverMonthsCap, usage.MaxTaskBalance, usage.TaskBalance, usage.LastRenewalAt,
).Scan(&usage.ID, &usage.CreatedAt, &usage.UpdatedAt)
if err != nil {
return nil, err
}
return usage, nil
}
// checkAndApplyMonthRenewal checks if a month has passed and adds allowance
// Implements the carryover logic: tasks accumulate up to max_task_balance
func (s *TaskService) checkAndApplyMonthRenewal(ctx context.Context, usage *models.AccountUsage) error {
now := time.Now()
// Check if at least one month has passed since last renewal
monthsSinceRenewal := monthsBetween(usage.LastRenewalAt, now)
if monthsSinceRenewal < 1 {
return nil
}
// Calculate new balance with carryover
// Add monthly allowance for each month that passed
newBalance := usage.TaskBalance
for i := 0; i < monthsSinceRenewal; i++ {
newBalance += usage.MonthlyTaskAllowance
// Cap at max balance
if newBalance > usage.MaxTaskBalance {
newBalance = usage.MaxTaskBalance
break
}
}
// Calculate new renewal date (add the number of months)
newRenewalAt := usage.LastRenewalAt.AddDate(0, monthsSinceRenewal, 0)
// Update in database
query := `
UPDATE account_usage
SET task_balance = $2, last_renewal_at = $3, updated_at = NOW()
WHERE id = $1
`
_, err := s.db.Pool.Exec(ctx, query, usage.ID, newBalance, newRenewalAt)
if err != nil {
return err
}
// Update local struct
usage.TaskBalance = newBalance
usage.LastRenewalAt = newRenewalAt
return nil
}
// monthsBetween calculates full months between two dates
func monthsBetween(start, end time.Time) int {
months := 0
for start.AddDate(0, months+1, 0).Before(end) || start.AddDate(0, months+1, 0).Equal(end) {
months++
}
return months
}
// CheckTaskAllowed checks if a task can be consumed (balance > 0)
func (s *TaskService) CheckTaskAllowed(ctx context.Context, userID uuid.UUID) (*models.CheckTaskAllowedResponse, error) {
usage, err := s.GetAccountUsage(ctx, userID)
if err != nil {
if errors.Is(err, ErrNoSubscription) {
return &models.CheckTaskAllowedResponse{
Allowed: false,
PlanID: "",
Message: "Kein aktives Abonnement gefunden.",
}, nil
}
return nil, err
}
// Premium Fair Use mode - always allow
plan, _ := s.subService.GetPlanByID(ctx, string(usage.PlanID))
if plan != nil && plan.Features.FairUseMode {
return &models.CheckTaskAllowedResponse{
Allowed: true,
TasksAvailable: usage.TaskBalance,
MaxTasks: usage.MaxTaskBalance,
PlanID: usage.PlanID,
}, nil
}
allowed := usage.TaskBalance > 0
response := &models.CheckTaskAllowedResponse{
Allowed: allowed,
TasksAvailable: usage.TaskBalance,
MaxTasks: usage.MaxTaskBalance,
PlanID: usage.PlanID,
}
if !allowed {
response.Message = "Dein Aufgaben-Kontingent ist aufgebraucht."
}
return response, nil
}
// ConsumeTask consumes one task from the balance
// Returns error if balance is 0
func (s *TaskService) ConsumeTask(ctx context.Context, userID uuid.UUID, taskType models.TaskType) (*models.ConsumeTaskResponse, error) {
// First check if allowed
checkResponse, err := s.CheckTaskAllowed(ctx, userID)
if err != nil {
return nil, err
}
if !checkResponse.Allowed {
return &models.ConsumeTaskResponse{
Success: false,
TasksRemaining: 0,
Message: checkResponse.Message,
}, ErrTaskLimitReached
}
// Get current usage
usage, err := s.GetAccountUsage(ctx, userID)
if err != nil {
return nil, err
}
// Start transaction
tx, err := s.db.Pool.Begin(ctx)
if err != nil {
return nil, err
}
defer tx.Rollback(ctx)
// Decrement balance (only if not Premium Fair Use)
plan, _ := s.subService.GetPlanByID(ctx, string(usage.PlanID))
newBalance := usage.TaskBalance
if plan == nil || !plan.Features.FairUseMode {
newBalance = usage.TaskBalance - 1
_, err = tx.Exec(ctx, `
UPDATE account_usage
SET task_balance = $2, updated_at = NOW()
WHERE account_id = $1
`, userID, newBalance)
if err != nil {
return nil, err
}
}
// Create task record
taskID := uuid.New()
_, err = tx.Exec(ctx, `
INSERT INTO tasks (id, account_id, task_type, consumed, created_at)
VALUES ($1, $2, $3, true, NOW())
`, taskID, userID, taskType)
if err != nil {
return nil, err
}
// Commit transaction
if err = tx.Commit(ctx); err != nil {
return nil, err
}
return &models.ConsumeTaskResponse{
Success: true,
TaskID: taskID.String(),
TasksRemaining: newBalance,
}, nil
}
// GetTaskUsageInfo returns formatted task usage info for display
func (s *TaskService) GetTaskUsageInfo(ctx context.Context, userID uuid.UUID) (*models.TaskUsageInfo, error) {
usage, err := s.GetAccountUsage(ctx, userID)
if err != nil {
return nil, err
}
// Check for Fair Use mode (Premium)
plan, _ := s.subService.GetPlanByID(ctx, string(usage.PlanID))
if plan != nil && plan.Features.FairUseMode {
return &models.TaskUsageInfo{
TasksAvailable: usage.TaskBalance,
MaxTasks: usage.MaxTaskBalance,
InfoText: "Unbegrenzte Aufgaben (Fair Use)",
TooltipText: "Im Premium-Tarif gibt es keine praktische Begrenzung.",
}, nil
}
return &models.TaskUsageInfo{
TasksAvailable: usage.TaskBalance,
MaxTasks: usage.MaxTaskBalance,
InfoText: fmt.Sprintf("Aufgaben verfuegbar: %d von max. %d", usage.TaskBalance, usage.MaxTaskBalance),
TooltipText: "Aufgaben koennen sich bis zu 5 Monate ansammeln.",
}, nil
}
// UpdatePlanForUser updates the plan and adjusts allowances
func (s *TaskService) UpdatePlanForUser(ctx context.Context, userID uuid.UUID, newPlanID models.PlanID) error {
plan, err := s.subService.GetPlanByID(ctx, string(newPlanID))
if err != nil || plan == nil {
return fmt.Errorf("plan not found: %s", newPlanID)
}
// Update account usage with new plan limits
query := `
UPDATE account_usage
SET plan = $2,
monthly_task_allowance = $3,
max_task_balance = $4,
updated_at = NOW()
WHERE account_id = $1
`
_, err = s.db.Pool.Exec(ctx, query,
userID, newPlanID, plan.Features.MonthlyTaskAllowance, plan.Features.MaxTaskBalance)
return err
}
// GetTaskHistory returns task history for a user
func (s *TaskService) GetTaskHistory(ctx context.Context, userID uuid.UUID, limit int) ([]models.Task, error) {
if limit <= 0 {
limit = 50
}
query := `
SELECT id, account_id, task_type, created_at, consumed
FROM tasks
WHERE account_id = $1
ORDER BY created_at DESC
LIMIT $2
`
rows, err := s.db.Pool.Query(ctx, query, userID, limit)
if err != nil {
return nil, err
}
defer rows.Close()
var tasks []models.Task
for rows.Next() {
var task models.Task
err := rows.Scan(&task.ID, &task.AccountID, &task.TaskType, &task.CreatedAt, &task.Consumed)
if err != nil {
return nil, err
}
tasks = append(tasks, task)
}
return tasks, nil
}
@@ -0,0 +1,397 @@
package services
import (
"testing"
"time"
)
func TestMonthsBetween(t *testing.T) {
tests := []struct {
name string
start time.Time
end time.Time
expected int
}{
{
name: "Same day",
start: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
end: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
expected: 0,
},
{
name: "Less than one month",
start: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
end: time.Date(2025, 2, 10, 0, 0, 0, 0, time.UTC),
expected: 0,
},
{
name: "Exactly one month",
start: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
end: time.Date(2025, 2, 15, 0, 0, 0, 0, time.UTC),
expected: 1,
},
{
name: "One month and one day",
start: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
end: time.Date(2025, 2, 16, 0, 0, 0, 0, time.UTC),
expected: 1,
},
{
name: "Two months",
start: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
end: time.Date(2025, 3, 15, 0, 0, 0, 0, time.UTC),
expected: 2,
},
{
name: "Five months exactly",
start: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
end: time.Date(2025, 6, 1, 0, 0, 0, 0, time.UTC),
expected: 5,
},
{
name: "Year boundary",
start: time.Date(2024, 11, 15, 0, 0, 0, 0, time.UTC),
end: time.Date(2025, 2, 15, 0, 0, 0, 0, time.UTC),
expected: 3,
},
{
name: "Leap year February to March",
start: time.Date(2024, 2, 29, 0, 0, 0, 0, time.UTC),
end: time.Date(2024, 3, 29, 0, 0, 0, 0, time.UTC),
expected: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := monthsBetween(tt.start, tt.end)
if result != tt.expected {
t.Errorf("monthsBetween(%v, %v) = %d, expected %d",
tt.start.Format("2006-01-02"), tt.end.Format("2006-01-02"),
result, tt.expected)
}
})
}
}
func TestCarryoverLogic(t *testing.T) {
// Test the carryover calculation logic
tests := []struct {
name string
currentBalance int
monthlyAllowance int
maxBalance int
monthsSinceRenewal int
expectedNewBalance int
}{
{
name: "Normal renewal - add allowance",
currentBalance: 50,
monthlyAllowance: 30,
maxBalance: 150,
monthsSinceRenewal: 1,
expectedNewBalance: 80,
},
{
name: "Two months missed",
currentBalance: 50,
monthlyAllowance: 30,
maxBalance: 150,
monthsSinceRenewal: 2,
expectedNewBalance: 110,
},
{
name: "Cap at max balance",
currentBalance: 140,
monthlyAllowance: 30,
maxBalance: 150,
monthsSinceRenewal: 1,
expectedNewBalance: 150,
},
{
name: "Already at max - no change",
currentBalance: 150,
monthlyAllowance: 30,
maxBalance: 150,
monthsSinceRenewal: 1,
expectedNewBalance: 150,
},
{
name: "Multiple months - cap applies",
currentBalance: 100,
monthlyAllowance: 30,
maxBalance: 150,
monthsSinceRenewal: 5,
expectedNewBalance: 150,
},
{
name: "Empty balance - add one month",
currentBalance: 0,
monthlyAllowance: 30,
maxBalance: 150,
monthsSinceRenewal: 1,
expectedNewBalance: 30,
},
{
name: "Empty balance - add five months",
currentBalance: 0,
monthlyAllowance: 30,
maxBalance: 150,
monthsSinceRenewal: 5,
expectedNewBalance: 150,
},
{
name: "Standard plan - normal case",
currentBalance: 200,
monthlyAllowance: 100,
maxBalance: 500,
monthsSinceRenewal: 1,
expectedNewBalance: 300,
},
{
name: "Premium plan - Fair Use",
currentBalance: 1000,
monthlyAllowance: 1000,
maxBalance: 5000,
monthsSinceRenewal: 1,
expectedNewBalance: 2000,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Simulate the carryover logic
newBalance := tt.currentBalance
for i := 0; i < tt.monthsSinceRenewal; i++ {
newBalance += tt.monthlyAllowance
if newBalance > tt.maxBalance {
newBalance = tt.maxBalance
break
}
}
if newBalance != tt.expectedNewBalance {
t.Errorf("Carryover for balance=%d, allowance=%d, max=%d, months=%d = %d, expected %d",
tt.currentBalance, tt.monthlyAllowance, tt.maxBalance, tt.monthsSinceRenewal,
newBalance, tt.expectedNewBalance)
}
})
}
}
func TestTaskBalanceAfterConsumption(t *testing.T) {
tests := []struct {
name string
currentBalance int
tasksToConsume int
expectedBalance int
shouldBeAllowed bool
}{
{
name: "Normal consumption",
currentBalance: 50,
tasksToConsume: 1,
expectedBalance: 49,
shouldBeAllowed: true,
},
{
name: "Last task",
currentBalance: 1,
tasksToConsume: 1,
expectedBalance: 0,
shouldBeAllowed: true,
},
{
name: "Empty balance - not allowed",
currentBalance: 0,
tasksToConsume: 1,
expectedBalance: 0,
shouldBeAllowed: false,
},
{
name: "Multiple tasks",
currentBalance: 50,
tasksToConsume: 5,
expectedBalance: 45,
shouldBeAllowed: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test if allowed
allowed := tt.currentBalance > 0
if allowed != tt.shouldBeAllowed {
t.Errorf("Task allowed with balance=%d: got %v, expected %v",
tt.currentBalance, allowed, tt.shouldBeAllowed)
}
// Test balance calculation
if allowed {
newBalance := tt.currentBalance - tt.tasksToConsume
if newBalance != tt.expectedBalance {
t.Errorf("Balance after consuming %d tasks from %d: got %d, expected %d",
tt.tasksToConsume, tt.currentBalance, newBalance, tt.expectedBalance)
}
}
})
}
}
func TestTaskServiceErrors(t *testing.T) {
// Test error constants
if ErrTaskLimitReached == nil {
t.Error("ErrTaskLimitReached should not be nil")
}
if ErrTaskLimitReached.Error() != "TASK_LIMIT_REACHED" {
t.Errorf("ErrTaskLimitReached should be 'TASK_LIMIT_REACHED', got '%s'", ErrTaskLimitReached.Error())
}
if ErrNoSubscription == nil {
t.Error("ErrNoSubscription should not be nil")
}
if ErrNoSubscription.Error() != "NO_SUBSCRIPTION" {
t.Errorf("ErrNoSubscription should be 'NO_SUBSCRIPTION', got '%s'", ErrNoSubscription.Error())
}
}
func TestRenewalDateCalculation(t *testing.T) {
tests := []struct {
name string
lastRenewal time.Time
monthsToAdd int
expectedRenewal time.Time
}{
{
name: "Add one month",
lastRenewal: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
monthsToAdd: 1,
expectedRenewal: time.Date(2025, 2, 15, 0, 0, 0, 0, time.UTC),
},
{
name: "Add three months",
lastRenewal: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
monthsToAdd: 3,
expectedRenewal: time.Date(2025, 4, 15, 0, 0, 0, 0, time.UTC),
},
{
name: "Year boundary",
lastRenewal: time.Date(2024, 11, 15, 0, 0, 0, 0, time.UTC),
monthsToAdd: 3,
expectedRenewal: time.Date(2025, 2, 15, 0, 0, 0, 0, time.UTC),
},
{
name: "End of month adjustment",
lastRenewal: time.Date(2025, 1, 31, 0, 0, 0, 0, time.UTC),
monthsToAdd: 1,
// Go's AddDate handles this - February doesn't have 31 days
expectedRenewal: time.Date(2025, 3, 3, 0, 0, 0, 0, time.UTC), // Feb 31 -> March 3
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.lastRenewal.AddDate(0, tt.monthsToAdd, 0)
if !result.Equal(tt.expectedRenewal) {
t.Errorf("AddDate(%v, %d months) = %v, expected %v",
tt.lastRenewal.Format("2006-01-02"), tt.monthsToAdd,
result.Format("2006-01-02"), tt.expectedRenewal.Format("2006-01-02"))
}
})
}
}
func TestFairUseModeLogic(t *testing.T) {
// Test that Fair Use mode always allows tasks regardless of balance
tests := []struct {
name string
fairUseMode bool
balance int
shouldAllow bool
}{
{
name: "Fair Use - zero balance still allowed",
fairUseMode: true,
balance: 0,
shouldAllow: true,
},
{
name: "Fair Use - normal balance allowed",
fairUseMode: true,
balance: 1000,
shouldAllow: true,
},
{
name: "Not Fair Use - zero balance not allowed",
fairUseMode: false,
balance: 0,
shouldAllow: false,
},
{
name: "Not Fair Use - positive balance allowed",
fairUseMode: false,
balance: 50,
shouldAllow: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Simulate the check logic
var allowed bool
if tt.fairUseMode {
allowed = true // Fair Use always allows
} else {
allowed = tt.balance > 0
}
if allowed != tt.shouldAllow {
t.Errorf("FairUseMode=%v, balance=%d: allowed=%v, expected=%v",
tt.fairUseMode, tt.balance, allowed, tt.shouldAllow)
}
})
}
}
func TestBalanceDecrementLogic(t *testing.T) {
// Test that Fair Use mode doesn't decrement balance
tests := []struct {
name string
fairUseMode bool
initialBalance int
expectedAfter int
}{
{
name: "Normal plan - decrement",
fairUseMode: false,
initialBalance: 50,
expectedAfter: 49,
},
{
name: "Fair Use - no decrement",
fairUseMode: true,
initialBalance: 1000,
expectedAfter: 1000,
},
{
name: "Normal plan - last task",
fairUseMode: false,
initialBalance: 1,
expectedAfter: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
newBalance := tt.initialBalance
if !tt.fairUseMode {
newBalance = tt.initialBalance - 1
}
if newBalance != tt.expectedAfter {
t.Errorf("FairUseMode=%v, initial=%d: got %d, expected %d",
tt.fairUseMode, tt.initialBalance, newBalance, tt.expectedAfter)
}
})
}
}
@@ -0,0 +1,194 @@
package services
import (
"context"
"fmt"
"time"
"github.com/breakpilot/billing-service/internal/database"
"github.com/breakpilot/billing-service/internal/models"
"github.com/google/uuid"
)
// UsageService handles usage tracking operations
type UsageService struct {
db *database.DB
entitlementService *EntitlementService
}
// NewUsageService creates a new UsageService
func NewUsageService(db *database.DB, entitlementService *EntitlementService) *UsageService {
return &UsageService{
db: db,
entitlementService: entitlementService,
}
}
// TrackUsage tracks usage for a user
func (s *UsageService) TrackUsage(ctx context.Context, userIDStr, usageType string, quantity int) error {
userID, err := uuid.Parse(userIDStr)
if err != nil {
return fmt.Errorf("invalid user ID: %w", err)
}
// Get current period start (beginning of current month)
now := time.Now()
periodStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC)
// Upsert usage summary
query := `
INSERT INTO usage_summary (user_id, usage_type, period_start, total_count)
VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, usage_type, period_start) DO UPDATE SET
total_count = usage_summary.total_count + EXCLUDED.total_count,
updated_at = NOW()
`
_, err = s.db.Pool.Exec(ctx, query, userID, usageType, periodStart, quantity)
if err != nil {
return fmt.Errorf("failed to track usage: %w", err)
}
// Also update entitlements cache
return s.entitlementService.IncrementUsage(ctx, userID, usageType, quantity)
}
// GetUsageSummary returns usage summary for a user
func (s *UsageService) GetUsageSummary(ctx context.Context, userID uuid.UUID) (*models.UsageInfo, error) {
// Get entitlements (which include current usage)
ent, err := s.entitlementService.getUserEntitlements(ctx, userID)
if err != nil || ent == nil {
return nil, err
}
// Calculate percentages
aiPercent := 0.0
if ent.AIRequestsLimit > 0 {
aiPercent = float64(ent.AIRequestsUsed) / float64(ent.AIRequestsLimit) * 100
}
docPercent := 0.0
if ent.DocumentsLimit > 0 {
docPercent = float64(ent.DocumentsUsed) / float64(ent.DocumentsLimit) * 100
}
// Get period dates
now := time.Now()
periodStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC)
periodEnd := periodStart.AddDate(0, 1, 0).Add(-time.Second)
return &models.UsageInfo{
AIRequestsUsed: ent.AIRequestsUsed,
AIRequestsLimit: ent.AIRequestsLimit,
AIRequestsPercent: aiPercent,
DocumentsUsed: ent.DocumentsUsed,
DocumentsLimit: ent.DocumentsLimit,
DocumentsPercent: docPercent,
PeriodStart: periodStart.Format("2006-01-02"),
PeriodEnd: periodEnd.Format("2006-01-02"),
}, nil
}
// CheckUsageAllowed checks if a user is allowed to perform a usage action
func (s *UsageService) CheckUsageAllowed(ctx context.Context, userIDStr, usageType string) (*models.CheckUsageResponse, error) {
userID, err := uuid.Parse(userIDStr)
if err != nil {
return &models.CheckUsageResponse{
Allowed: false,
Message: "Invalid user ID",
}, nil
}
// Get entitlements
ent, err := s.entitlementService.getUserEntitlements(ctx, userID)
if err != nil {
return &models.CheckUsageResponse{
Allowed: false,
Message: "Failed to get entitlements",
}, nil
}
if ent == nil {
return &models.CheckUsageResponse{
Allowed: false,
Message: "No subscription found",
}, nil
}
var currentUsage, limit int
switch usageType {
case "ai_request":
currentUsage = ent.AIRequestsUsed
limit = ent.AIRequestsLimit
case "document_created":
currentUsage = ent.DocumentsUsed
limit = ent.DocumentsLimit
default:
return &models.CheckUsageResponse{
Allowed: true,
Message: "Unknown usage type - allowing",
}, nil
}
remaining := limit - currentUsage
allowed := remaining > 0
response := &models.CheckUsageResponse{
Allowed: allowed,
CurrentUsage: currentUsage,
Limit: limit,
Remaining: remaining,
}
if !allowed {
response.Message = fmt.Sprintf("Usage limit reached for %s (%d/%d)", usageType, currentUsage, limit)
}
return response, nil
}
// GetUsageHistory returns usage history for a user
func (s *UsageService) GetUsageHistory(ctx context.Context, userID uuid.UUID, months int) ([]models.UsageSummary, error) {
query := `
SELECT id, user_id, usage_type, period_start, total_count, created_at, updated_at
FROM usage_summary
WHERE user_id = $1
AND period_start >= $2
ORDER BY period_start DESC, usage_type
`
// Calculate start date
startDate := time.Now().AddDate(0, -months, 0)
startDate = time.Date(startDate.Year(), startDate.Month(), 1, 0, 0, 0, 0, time.UTC)
rows, err := s.db.Pool.Query(ctx, query, userID, startDate)
if err != nil {
return nil, err
}
defer rows.Close()
var summaries []models.UsageSummary
for rows.Next() {
var summary models.UsageSummary
err := rows.Scan(
&summary.ID, &summary.UserID, &summary.UsageType,
&summary.PeriodStart, &summary.TotalCount,
&summary.CreatedAt, &summary.UpdatedAt,
)
if err != nil {
return nil, err
}
summaries = append(summaries, summary)
}
return summaries, nil
}
// ResetPeriodUsage resets usage for a new billing period
func (s *UsageService) ResetPeriodUsage(ctx context.Context, userID uuid.UUID) error {
now := time.Now()
newPeriodStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC)
newPeriodEnd := newPeriodStart.AddDate(0, 1, 0).Add(-time.Second)
return s.entitlementService.ResetUsageCounters(ctx, userID, &newPeriodStart, &newPeriodEnd)
}
+171
View File
@@ -0,0 +1,171 @@
# BreakPilot BPMN Prozesse
Dieses Verzeichnis enthaelt die BPMN 2.0 Prozessdefinitionen fuer BreakPilot.
## Prozess-Uebersicht
| Datei | Prozess | Beschreibung | Status |
|-------|---------|--------------|--------|
| `classroom-lesson.bpmn` | Unterrichtsstunde | Phasenbasierte Unterrichtssteuerung | Entwurf |
| `consent-document.bpmn` | Consent-Dokument | DSB-Approval, Publishing, Monitoring | Entwurf |
| `klausur-korrektur.bpmn` | Klausurkorrektur | OCR, AI-Grading, Export | Entwurf |
| `dsr-request.bpmn` | DSR/GDPR | Betroffenenanfragen (Art. 15-20) | Entwurf |
## Verwendung
### Im BPMN Editor laden
1. Navigiere zu http://localhost:3000/admin/workflow oder http://localhost:8000/app (Workflow)
2. Klicke "Oeffnen" und waehle eine .bpmn Datei
3. Bearbeite den Prozess im Editor
4. Speichere und deploye zu Camunda
### In Camunda deployen
```bash
# Camunda starten (falls noch nicht aktiv)
docker compose --profile bpmn up -d camunda
# Prozess deployen via API
curl -X POST http://localhost:8000/api/bpmn/deployment/create \
-F "deployment-name=breakpilot-processes" \
-F "data=@classroom-lesson.bpmn"
```
### Prozess starten
```bash
# Unterrichtsstunde starten
curl -X POST http://localhost:8000/api/bpmn/process-definition/ClassroomLessonProcess/start \
-H "Content-Type: application/json" \
-d '{
"variables": {
"teacherId": {"value": "teacher-123"},
"classId": {"value": "class-7a"},
"subject": {"value": "Mathematik"}
}
}'
```
## Prozess-Details
### 1. Classroom Lesson (classroom-lesson.bpmn)
**Phasen:**
- Einstieg (Motivation, Problemstellung)
- Erarbeitung I (Einzelarbeit, Partnerarbeit, Gruppenarbeit)
- Erarbeitung II (optional)
- Sicherung (Tafel, Digital, Schueler-Praesentation)
- Transfer (Anwendungsaufgaben)
- Reflexion & Abschluss (Hausaufgaben, Notizen)
**Service Tasks:**
- `contentSuggestionDelegate` - Content-Vorschlaege basierend auf Phase
- `lessonProtocolDelegate` - Automatisches Stundenprotokoll
**Timer Events:**
- Phasen-Timer mit Warnungen
---
### 2. Consent Document (consent-document.bpmn)
**Workflow:**
1. Dokument bearbeiten (Autor)
2. DSB-Pruefung (Vier-Augen-Prinzip)
3. Bei Ablehnung: Zurueck an Autor
4. Bei Genehmigung: Veroeffentlichen
5. Benutzer benachrichtigen
6. Consent sammeln mit Deadline-Timer
7. Monitoring-Subprocess fuer jaehrliche Erneuerung
8. Archivierung bei neuer Version
**Service Tasks:**
- `publishConsentDocumentDelegate`
- `notifyUsersDelegate`
- `sendConsentReminderDelegate`
- `checkConsentStatusDelegate`
- `triggerRenewalDelegate`
- `archiveDocumentDelegate`
---
### 3. Klausur Korrektur (klausur-korrektur.bpmn)
**Workflow:**
1. OCR-Verarbeitung der hochgeladenen Klausuren
2. Qualitaets-Check (Confidence >= 85%)
3. Bei schlechter Qualitaet: Manuelle Nachbearbeitung
4. Erwartungshorizont definieren
5. AI-Bewertung mit Claude
6. Lehrer-Review mit Anpassungsmoeglichkeit
7. Noten berechnen (15-Punkte-Skala)
8. Notenbuch aktualisieren
9. Export (PDF, Excel)
10. Optional: Eltern benachrichtigen
11. Archivierung
**Service Tasks:**
- `ocrProcessingDelegate`
- `ocrQualityCheckDelegate`
- `aiGradingDelegate`
- `calculateGradesDelegate`
- `updateGradebookDelegate`
- `generateExportDelegate`
- `notifyParentsDelegate`
- `archiveExamDelegate`
- `deadlineWarningDelegate`
---
### 4. DSR Request (dsr-request.bpmn)
**GDPR Artikel:**
- Art. 15: Recht auf Auskunft (Access)
- Art. 16: Recht auf Berichtigung (Rectification)
- Art. 17: Recht auf Loeschung (Deletion)
- Art. 20: Recht auf Datenuebertragbarkeit (Portability)
**Workflow:**
1. Anfrage validieren
2. Bei ungueltig: Ablehnen
3. Je nach Typ:
- Access: Daten sammeln → Anonymisieren → Review → Export
- Deletion: Identifizieren → Genehmigen → Loeschen → Verifizieren
- Portability: Sammeln → JSON formatieren
- Rectification: Pruefen → Anwenden
4. Betroffenen benachrichtigen
5. Audit Log erstellen
**30-Tage Frist:**
- Timer-Event nach 25 Tagen fuer Eskalation an DSB
**Service Tasks:**
- `validateDSRDelegate`
- `rejectDSRDelegate`
- `collectUserDataDelegate`
- `anonymizeDataDelegate`
- `prepareExportDelegate`
- `identifyUserDataDelegate`
- `executeDataDeletionDelegate`
- `verifyDeletionDelegate`
- `collectPortableDataDelegate`
- `formatPortableDataDelegate`
- `applyRectificationDelegate`
- `notifyDataSubjectDelegate`
- `createAuditLogDelegate`
- `escalateToDSBDelegate`
## Naechste Schritte
1. **Delegates implementieren**: Java/Python Service Tasks
2. **Camunda Connect**: REST-Aufrufe zu Backend-APIs
3. **User Task Forms**: Camunda Forms oder Custom UI
4. **Timer konfigurieren**: Realistische Dauern setzen
5. **Testing**: Prozesse mit Testdaten durchlaufen
## Referenzen
- [Camunda 7 Docs](https://docs.camunda.org/manual/7.21/)
- [BPMN 2.0 Spec](https://www.omg.org/spec/BPMN/2.0/)
- [bpmn-js](https://bpmn.io/toolkit/bpmn-js/)
+181
View File
@@ -0,0 +1,181 @@
<?xml version="1.0" encoding="UTF-8"?>
<bpmn:definitions xmlns:bpmn="http://www.omg.org/spec/BPMN/20100524/MODEL"
xmlns:bpmndi="http://www.omg.org/spec/BPMN/20100524/DI"
xmlns:dc="http://www.omg.org/spec/DD/20100524/DC"
xmlns:di="http://www.omg.org/spec/DD/20100524/DI"
xmlns:camunda="http://camunda.org/schema/1.0/bpmn"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
id="Definitions_Classroom"
targetNamespace="http://breakpilot.de/bpmn/classroom">
<bpmn:process id="ClassroomLessonProcess" name="Unterrichtsstunde" isExecutable="true">
<!-- Start Event -->
<bpmn:startEvent id="start" name="Stunde beginnen">
<bpmn:outgoing>flow_to_einstieg</bpmn:outgoing>
</bpmn:startEvent>
<!-- Phase 1: Einstieg -->
<bpmn:userTask id="phase_einstieg" name="Einstiegsphase" camunda:assignee="${teacherId}">
<bpmn:extensionElements>
<camunda:formData>
<camunda:formField id="duration" label="Dauer (Minuten)" type="long" defaultValue="10" />
<camunda:formField id="activity" label="Aktivitaet" type="string" />
</camunda:formData>
</bpmn:extensionElements>
<bpmn:incoming>flow_to_einstieg</bpmn:incoming>
<bpmn:outgoing>flow_to_erarbeitung1</bpmn:outgoing>
</bpmn:userTask>
<!-- Service Task: Content-Vorschlaege Einstieg -->
<bpmn:serviceTask id="suggest_einstieg" name="Content-Vorschlaege" camunda:delegateExpression="${contentSuggestionDelegate}">
<bpmn:incoming>flow_suggest_einstieg</bpmn:incoming>
<bpmn:outgoing>flow_from_suggest_einstieg</bpmn:outgoing>
</bpmn:serviceTask>
<!-- Phase 2: Erarbeitung I -->
<bpmn:userTask id="phase_erarbeitung1" name="Erarbeitung I" camunda:assignee="${teacherId}">
<bpmn:extensionElements>
<camunda:formData>
<camunda:formField id="duration" label="Dauer (Minuten)" type="long" defaultValue="15" />
<camunda:formField id="sozialform" label="Sozialform" type="enum">
<camunda:value id="einzelarbeit" name="Einzelarbeit" />
<camunda:value id="partnerarbeit" name="Partnerarbeit" />
<camunda:value id="gruppenarbeit" name="Gruppenarbeit" />
</camunda:formField>
<camunda:formField id="content" label="Lerneinheit" type="string" />
</camunda:formData>
</bpmn:extensionElements>
<bpmn:incoming>flow_to_erarbeitung1</bpmn:incoming>
<bpmn:outgoing>flow_to_erarbeitung_gateway</bpmn:outgoing>
</bpmn:userTask>
<!-- Gateway: Weitere Erarbeitung? -->
<bpmn:exclusiveGateway id="erarbeitung_gateway" name="Weitere Erarbeitung?">
<bpmn:incoming>flow_to_erarbeitung_gateway</bpmn:incoming>
<bpmn:outgoing>flow_to_erarbeitung2</bpmn:outgoing>
<bpmn:outgoing>flow_to_sicherung</bpmn:outgoing>
</bpmn:exclusiveGateway>
<!-- Phase 2b: Erarbeitung II (optional) -->
<bpmn:userTask id="phase_erarbeitung2" name="Erarbeitung II" camunda:assignee="${teacherId}">
<bpmn:incoming>flow_to_erarbeitung2</bpmn:incoming>
<bpmn:outgoing>flow_from_erarbeitung2</bpmn:outgoing>
</bpmn:userTask>
<!-- Phase 3: Sicherung -->
<bpmn:userTask id="phase_sicherung" name="Sicherungsphase" camunda:assignee="${teacherId}">
<bpmn:extensionElements>
<camunda:formData>
<camunda:formField id="duration" label="Dauer (Minuten)" type="long" defaultValue="10" />
<camunda:formField id="method" label="Methode" type="enum">
<camunda:value id="tafel" name="Tafelanschrieb" />
<camunda:value id="digital" name="Digitale Zusammenfassung" />
<camunda:value id="schueler" name="Schueler-Praesentation" />
</camunda:formField>
</camunda:formData>
</bpmn:extensionElements>
<bpmn:incoming>flow_to_sicherung</bpmn:incoming>
<bpmn:incoming>flow_from_erarbeitung2</bpmn:incoming>
<bpmn:outgoing>flow_to_transfer</bpmn:outgoing>
</bpmn:userTask>
<!-- Phase 4: Transfer -->
<bpmn:userTask id="phase_transfer" name="Transferphase" camunda:assignee="${teacherId}">
<bpmn:extensionElements>
<camunda:formData>
<camunda:formField id="duration" label="Dauer (Minuten)" type="long" defaultValue="8" />
<camunda:formField id="aufgabe" label="Transfer-Aufgabe" type="string" />
</camunda:formData>
</bpmn:extensionElements>
<bpmn:incoming>flow_to_transfer</bpmn:incoming>
<bpmn:outgoing>flow_to_reflexion</bpmn:outgoing>
</bpmn:userTask>
<!-- Phase 5: Reflexion -->
<bpmn:userTask id="phase_reflexion" name="Reflexion &amp; Abschluss" camunda:assignee="${teacherId}">
<bpmn:extensionElements>
<camunda:formData>
<camunda:formField id="duration" label="Dauer (Minuten)" type="long" defaultValue="5" />
<camunda:formField id="hausaufgabe" label="Hausaufgabe" type="string" />
<camunda:formField id="notizen" label="Stundennotizen" type="string" />
</camunda:formData>
</bpmn:extensionElements>
<bpmn:incoming>flow_to_reflexion</bpmn:incoming>
<bpmn:outgoing>flow_to_protokoll</bpmn:outgoing>
</bpmn:userTask>
<!-- Service Task: Stundenprotokoll -->
<bpmn:serviceTask id="create_protokoll" name="Stundenprotokoll erstellen" camunda:delegateExpression="${lessonProtocolDelegate}">
<bpmn:incoming>flow_to_protokoll</bpmn:incoming>
<bpmn:outgoing>flow_to_end</bpmn:outgoing>
</bpmn:serviceTask>
<!-- End Event -->
<bpmn:endEvent id="end" name="Stunde beendet">
<bpmn:incoming>flow_to_end</bpmn:incoming>
</bpmn:endEvent>
<!-- Boundary Timer Events -->
<bpmn:boundaryEvent id="timer_einstieg" attachedToRef="phase_einstieg" cancelActivity="false">
<bpmn:timerEventDefinition>
<bpmn:timeDuration>PT${einstiegDuration}M</bpmn:timeDuration>
</bpmn:timerEventDefinition>
<bpmn:outgoing>flow_timer_warning</bpmn:outgoing>
</bpmn:boundaryEvent>
<!-- Sequence Flows -->
<bpmn:sequenceFlow id="flow_to_einstieg" sourceRef="start" targetRef="phase_einstieg" />
<bpmn:sequenceFlow id="flow_to_erarbeitung1" sourceRef="phase_einstieg" targetRef="phase_erarbeitung1" />
<bpmn:sequenceFlow id="flow_to_erarbeitung_gateway" sourceRef="phase_erarbeitung1" targetRef="erarbeitung_gateway" />
<bpmn:sequenceFlow id="flow_to_erarbeitung2" sourceRef="erarbeitung_gateway" targetRef="phase_erarbeitung2">
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${needsMoreWork == true}</bpmn:conditionExpression>
</bpmn:sequenceFlow>
<bpmn:sequenceFlow id="flow_to_sicherung" sourceRef="erarbeitung_gateway" targetRef="phase_sicherung">
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${needsMoreWork == false}</bpmn:conditionExpression>
</bpmn:sequenceFlow>
<bpmn:sequenceFlow id="flow_from_erarbeitung2" sourceRef="phase_erarbeitung2" targetRef="phase_sicherung" />
<bpmn:sequenceFlow id="flow_to_transfer" sourceRef="phase_sicherung" targetRef="phase_transfer" />
<bpmn:sequenceFlow id="flow_to_reflexion" sourceRef="phase_transfer" targetRef="phase_reflexion" />
<bpmn:sequenceFlow id="flow_to_protokoll" sourceRef="phase_reflexion" targetRef="create_protokoll" />
<bpmn:sequenceFlow id="flow_to_end" sourceRef="create_protokoll" targetRef="end" />
</bpmn:process>
<!-- BPMN Diagram -->
<bpmndi:BPMNDiagram id="BPMNDiagram_1">
<bpmndi:BPMNPlane id="BPMNPlane_1" bpmnElement="ClassroomLessonProcess">
<bpmndi:BPMNShape id="start_di" bpmnElement="start">
<dc:Bounds x="152" y="102" width="36" height="36" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="phase_einstieg_di" bpmnElement="phase_einstieg">
<dc:Bounds x="240" y="80" width="100" height="80" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="phase_erarbeitung1_di" bpmnElement="phase_erarbeitung1">
<dc:Bounds x="390" y="80" width="100" height="80" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="erarbeitung_gateway_di" bpmnElement="erarbeitung_gateway" isMarkerVisible="true">
<dc:Bounds x="545" y="95" width="50" height="50" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="phase_erarbeitung2_di" bpmnElement="phase_erarbeitung2">
<dc:Bounds x="520" y="200" width="100" height="80" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="phase_sicherung_di" bpmnElement="phase_sicherung">
<dc:Bounds x="670" y="80" width="100" height="80" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="phase_transfer_di" bpmnElement="phase_transfer">
<dc:Bounds x="820" y="80" width="100" height="80" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="phase_reflexion_di" bpmnElement="phase_reflexion">
<dc:Bounds x="970" y="80" width="100" height="80" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="create_protokoll_di" bpmnElement="create_protokoll">
<dc:Bounds x="1120" y="80" width="100" height="80" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="end_di" bpmnElement="end">
<dc:Bounds x="1272" y="102" width="36" height="36" />
</bpmndi:BPMNShape>
</bpmndi:BPMNPlane>
</bpmndi:BPMNDiagram>
</bpmn:definitions>
+206
View File
@@ -0,0 +1,206 @@
<?xml version="1.0" encoding="UTF-8"?>
<bpmn:definitions xmlns:bpmn="http://www.omg.org/spec/BPMN/20100524/MODEL"
xmlns:bpmndi="http://www.omg.org/spec/BPMN/20100524/DI"
xmlns:dc="http://www.omg.org/spec/DD/20100524/DC"
xmlns:di="http://www.omg.org/spec/DD/20100524/DI"
xmlns:camunda="http://camunda.org/schema/1.0/bpmn"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
id="Definitions_Consent"
targetNamespace="http://breakpilot.de/bpmn/consent">
<bpmn:process id="ConsentDocumentProcess" name="Consent-Dokument Workflow" isExecutable="true">
<!-- Start Event -->
<bpmn:startEvent id="start" name="Dokument erstellt">
<bpmn:outgoing>flow_to_edit</bpmn:outgoing>
</bpmn:startEvent>
<!-- User Task: Dokument bearbeiten -->
<bpmn:userTask id="edit_document" name="Dokument bearbeiten" camunda:assignee="${authorId}">
<bpmn:extensionElements>
<camunda:formData>
<camunda:formField id="title" label="Titel" type="string" />
<camunda:formField id="type" label="Dokumenttyp" type="enum">
<camunda:value id="terms" name="AGB" />
<camunda:value id="privacy" name="Datenschutzerklaerung" />
<camunda:value id="cookies" name="Cookie-Richtlinie" />
<camunda:value id="consent_form" name="Einwilligungserklaerung" />
</camunda:formField>
<camunda:formField id="content" label="Inhalt" type="string" />
<camunda:formField id="version" label="Version" type="string" />
</camunda:formData>
</bpmn:extensionElements>
<bpmn:incoming>flow_to_edit</bpmn:incoming>
<bpmn:incoming>flow_rejected_to_edit</bpmn:incoming>
<bpmn:outgoing>flow_to_review</bpmn:outgoing>
</bpmn:userTask>
<!-- User Task: DSB Review -->
<bpmn:userTask id="dsb_review" name="DSB Pruefung" camunda:candidateGroups="data_protection_officer">
<bpmn:extensionElements>
<camunda:formData>
<camunda:formField id="approved" label="Genehmigt" type="boolean" />
<camunda:formField id="comments" label="Kommentare" type="string" />
<camunda:formField id="legalCheck" label="Rechtliche Pruefung OK" type="boolean" />
</camunda:formData>
</bpmn:extensionElements>
<bpmn:incoming>flow_to_review</bpmn:incoming>
<bpmn:outgoing>flow_to_approval_gateway</bpmn:outgoing>
</bpmn:userTask>
<!-- Gateway: Genehmigt? -->
<bpmn:exclusiveGateway id="approval_gateway" name="Genehmigt?">
<bpmn:incoming>flow_to_approval_gateway</bpmn:incoming>
<bpmn:outgoing>flow_approved</bpmn:outgoing>
<bpmn:outgoing>flow_rejected</bpmn:outgoing>
</bpmn:exclusiveGateway>
<!-- Service Task: Veroeffentlichen -->
<bpmn:serviceTask id="publish_document" name="Dokument veroeffentlichen" camunda:delegateExpression="${publishConsentDocumentDelegate}">
<bpmn:incoming>flow_approved</bpmn:incoming>
<bpmn:outgoing>flow_to_notify</bpmn:outgoing>
</bpmn:serviceTask>
<!-- Service Task: Benutzer benachrichtigen -->
<bpmn:serviceTask id="notify_users" name="Benutzer benachrichtigen" camunda:delegateExpression="${notifyUsersDelegate}">
<bpmn:incoming>flow_to_notify</bpmn:incoming>
<bpmn:outgoing>flow_to_collect_consent</bpmn:outgoing>
</bpmn:serviceTask>
<!-- User Task: Consent sammeln (wartet auf Benutzer-Zustimmungen) -->
<bpmn:receiveTask id="collect_consent" name="Auf Zustimmungen warten">
<bpmn:incoming>flow_to_collect_consent</bpmn:incoming>
<bpmn:outgoing>flow_to_check_deadline</bpmn:outgoing>
</bpmn:receiveTask>
<!-- Boundary Timer: Deadline -->
<bpmn:boundaryEvent id="consent_deadline" attachedToRef="collect_consent" cancelActivity="false">
<bpmn:timerEventDefinition>
<bpmn:timeDuration>P${consentDeadlineDays}D</bpmn:timeDuration>
</bpmn:timerEventDefinition>
<bpmn:outgoing>flow_to_reminder</bpmn:outgoing>
</bpmn:boundaryEvent>
<!-- Service Task: Reminder senden -->
<bpmn:serviceTask id="send_reminder" name="Reminder senden" camunda:delegateExpression="${sendConsentReminderDelegate}">
<bpmn:incoming>flow_to_reminder</bpmn:incoming>
<bpmn:outgoing>flow_back_to_collect</bpmn:outgoing>
</bpmn:serviceTask>
<!-- Service Task: Consent-Status pruefen -->
<bpmn:serviceTask id="check_consent_status" name="Consent-Status pruefen" camunda:delegateExpression="${checkConsentStatusDelegate}">
<bpmn:incoming>flow_to_check_deadline</bpmn:incoming>
<bpmn:outgoing>flow_to_active</bpmn:outgoing>
</bpmn:serviceTask>
<!-- Intermediate Event: Dokument aktiv -->
<bpmn:intermediateThrowEvent id="document_active" name="Dokument aktiv">
<bpmn:incoming>flow_to_active</bpmn:incoming>
<bpmn:outgoing>flow_to_monitor</bpmn:outgoing>
</bpmn:intermediateThrowEvent>
<!-- Sub-Process: Monitoring -->
<bpmn:subProcess id="monitoring_subprocess" name="Consent Monitoring">
<bpmn:incoming>flow_to_monitor</bpmn:incoming>
<bpmn:outgoing>flow_to_archive</bpmn:outgoing>
<bpmn:startEvent id="monitoring_start" />
<!-- Event-based Gateway: Warte auf Events -->
<bpmn:eventBasedGateway id="event_gateway">
<bpmn:incoming>flow_from_monitoring_start</bpmn:incoming>
<bpmn:outgoing>flow_to_renewal_timer</bpmn:outgoing>
<bpmn:outgoing>flow_to_supersede_event</bpmn:outgoing>
</bpmn:eventBasedGateway>
<!-- Timer: Jaehrliche Erneuerung -->
<bpmn:intermediateCatchEvent id="renewal_timer" name="Erneuerungsdatum">
<bpmn:incoming>flow_to_renewal_timer</bpmn:incoming>
<bpmn:outgoing>flow_to_renewal_task</bpmn:outgoing>
<bpmn:timerEventDefinition>
<bpmn:timeDuration>P1Y</bpmn:timeDuration>
</bpmn:timerEventDefinition>
</bpmn:intermediateCatchEvent>
<!-- Message: Dokument ersetzt -->
<bpmn:intermediateCatchEvent id="supersede_event" name="Neue Version">
<bpmn:incoming>flow_to_supersede_event</bpmn:incoming>
<bpmn:outgoing>flow_to_monitoring_end</bpmn:outgoing>
<bpmn:messageEventDefinition messageRef="Message_Supersede" />
</bpmn:intermediateCatchEvent>
<bpmn:serviceTask id="trigger_renewal" name="Erneuerung anfordern" camunda:delegateExpression="${triggerRenewalDelegate}">
<bpmn:incoming>flow_to_renewal_task</bpmn:incoming>
<bpmn:outgoing>flow_back_to_gateway</bpmn:outgoing>
</bpmn:serviceTask>
<bpmn:endEvent id="monitoring_end" />
</bpmn:subProcess>
<!-- Service Task: Archivieren -->
<bpmn:serviceTask id="archive_document" name="Dokument archivieren" camunda:delegateExpression="${archiveDocumentDelegate}">
<bpmn:incoming>flow_to_archive</bpmn:incoming>
<bpmn:outgoing>flow_to_end</bpmn:outgoing>
</bpmn:serviceTask>
<!-- End Event -->
<bpmn:endEvent id="end" name="Workflow beendet">
<bpmn:incoming>flow_to_end</bpmn:incoming>
</bpmn:endEvent>
<!-- Sequence Flows -->
<bpmn:sequenceFlow id="flow_to_edit" sourceRef="start" targetRef="edit_document" />
<bpmn:sequenceFlow id="flow_to_review" sourceRef="edit_document" targetRef="dsb_review" />
<bpmn:sequenceFlow id="flow_to_approval_gateway" sourceRef="dsb_review" targetRef="approval_gateway" />
<bpmn:sequenceFlow id="flow_approved" sourceRef="approval_gateway" targetRef="publish_document">
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${approved == true}</bpmn:conditionExpression>
</bpmn:sequenceFlow>
<bpmn:sequenceFlow id="flow_rejected" sourceRef="approval_gateway" targetRef="edit_document">
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${approved == false}</bpmn:conditionExpression>
</bpmn:sequenceFlow>
<bpmn:sequenceFlow id="flow_rejected_to_edit" sourceRef="approval_gateway" targetRef="edit_document" />
<bpmn:sequenceFlow id="flow_to_notify" sourceRef="publish_document" targetRef="notify_users" />
<bpmn:sequenceFlow id="flow_to_collect_consent" sourceRef="notify_users" targetRef="collect_consent" />
<bpmn:sequenceFlow id="flow_to_reminder" sourceRef="consent_deadline" targetRef="send_reminder" />
<bpmn:sequenceFlow id="flow_to_check_deadline" sourceRef="collect_consent" targetRef="check_consent_status" />
<bpmn:sequenceFlow id="flow_to_active" sourceRef="check_consent_status" targetRef="document_active" />
<bpmn:sequenceFlow id="flow_to_monitor" sourceRef="document_active" targetRef="monitoring_subprocess" />
<bpmn:sequenceFlow id="flow_to_archive" sourceRef="monitoring_subprocess" targetRef="archive_document" />
<bpmn:sequenceFlow id="flow_to_end" sourceRef="archive_document" targetRef="end" />
</bpmn:process>
<!-- Messages -->
<bpmn:message id="Message_Supersede" name="DocumentSuperseded" />
<!-- BPMN Diagram -->
<bpmndi:BPMNDiagram id="BPMNDiagram_1">
<bpmndi:BPMNPlane id="BPMNPlane_1" bpmnElement="ConsentDocumentProcess">
<bpmndi:BPMNShape id="start_di" bpmnElement="start">
<dc:Bounds x="152" y="102" width="36" height="36" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="edit_document_di" bpmnElement="edit_document">
<dc:Bounds x="240" y="80" width="100" height="80" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="dsb_review_di" bpmnElement="dsb_review">
<dc:Bounds x="390" y="80" width="100" height="80" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="approval_gateway_di" bpmnElement="approval_gateway" isMarkerVisible="true">
<dc:Bounds x="545" y="95" width="50" height="50" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="publish_document_di" bpmnElement="publish_document">
<dc:Bounds x="650" y="80" width="100" height="80" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="notify_users_di" bpmnElement="notify_users">
<dc:Bounds x="800" y="80" width="100" height="80" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="collect_consent_di" bpmnElement="collect_consent">
<dc:Bounds x="950" y="80" width="100" height="80" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="end_di" bpmnElement="end">
<dc:Bounds x="1502" y="102" width="36" height="36" />
</bpmndi:BPMNShape>
</bpmndi:BPMNPlane>
</bpmndi:BPMNDiagram>
</bpmn:definitions>
+222
View File
@@ -0,0 +1,222 @@
<?xml version="1.0" encoding="UTF-8"?>
<bpmn:definitions xmlns:bpmn="http://www.omg.org/spec/BPMN/20100524/MODEL"
xmlns:bpmndi="http://www.omg.org/spec/BPMN/20100524/DI"
xmlns:dc="http://www.omg.org/spec/DD/20100524/DC"
xmlns:di="http://www.omg.org/spec/DD/20100524/DI"
xmlns:camunda="http://camunda.org/schema/1.0/bpmn"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
id="Definitions_DSR"
targetNamespace="http://breakpilot.de/bpmn/dsr">
<bpmn:process id="DSRRequestProcess" name="Data Subject Request (GDPR)" isExecutable="true">
<!-- Start Event -->
<bpmn:startEvent id="start" name="DSR eingereicht">
<bpmn:outgoing>flow_to_validate</bpmn:outgoing>
</bpmn:startEvent>
<!-- Service Task: Anfrage validieren -->
<bpmn:serviceTask id="validate_request" name="Anfrage validieren" camunda:delegateExpression="${validateDSRDelegate}">
<bpmn:incoming>flow_to_validate</bpmn:incoming>
<bpmn:outgoing>flow_to_validation_gateway</bpmn:outgoing>
</bpmn:serviceTask>
<!-- Gateway: Anfrage gueltig? -->
<bpmn:exclusiveGateway id="validation_gateway" name="Anfrage gueltig?">
<bpmn:incoming>flow_to_validation_gateway</bpmn:incoming>
<bpmn:outgoing>flow_valid</bpmn:outgoing>
<bpmn:outgoing>flow_invalid</bpmn:outgoing>
</bpmn:exclusiveGateway>
<!-- Service Task: Anfrage ablehnen -->
<bpmn:serviceTask id="reject_request" name="Anfrage ablehnen" camunda:delegateExpression="${rejectDSRDelegate}">
<bpmn:incoming>flow_invalid</bpmn:incoming>
<bpmn:outgoing>flow_to_reject_end</bpmn:outgoing>
</bpmn:serviceTask>
<!-- End Event: Abgelehnt -->
<bpmn:endEvent id="end_rejected" name="DSR abgelehnt">
<bpmn:incoming>flow_to_reject_end</bpmn:incoming>
</bpmn:endEvent>
<!-- Gateway: Request-Typ -->
<bpmn:exclusiveGateway id="type_gateway" name="Request-Typ?">
<bpmn:incoming>flow_valid</bpmn:incoming>
<bpmn:outgoing>flow_access</bpmn:outgoing>
<bpmn:outgoing>flow_deletion</bpmn:outgoing>
<bpmn:outgoing>flow_portability</bpmn:outgoing>
<bpmn:outgoing>flow_rectification</bpmn:outgoing>
</bpmn:exclusiveGateway>
<!-- Sub-Process: Daten-Zugang (Art. 15) -->
<bpmn:subProcess id="access_subprocess" name="Daten-Zugang (Art. 15)">
<bpmn:incoming>flow_access</bpmn:incoming>
<bpmn:outgoing>flow_access_done</bpmn:outgoing>
<bpmn:startEvent id="access_start" />
<bpmn:serviceTask id="collect_data" name="Daten sammeln" camunda:delegateExpression="${collectUserDataDelegate}" />
<bpmn:serviceTask id="anonymize_data" name="Daten anonymisieren" camunda:delegateExpression="${anonymizeDataDelegate}" />
<bpmn:userTask id="review_data" name="Daten pruefen" camunda:candidateGroups="data_protection_officer">
<bpmn:extensionElements>
<camunda:formData>
<camunda:formField id="dataComplete" label="Daten vollstaendig" type="boolean" />
<camunda:formField id="sensitivePII" label="Sensible PII entfernt" type="boolean" />
</camunda:formData>
</bpmn:extensionElements>
</bpmn:userTask>
<bpmn:serviceTask id="prepare_export" name="Export vorbereiten" camunda:delegateExpression="${prepareExportDelegate}" />
<bpmn:endEvent id="access_end" />
</bpmn:subProcess>
<!-- Sub-Process: Daten-Loeschung (Art. 17) -->
<bpmn:subProcess id="deletion_subprocess" name="Daten-Loeschung (Art. 17)">
<bpmn:incoming>flow_deletion</bpmn:incoming>
<bpmn:outgoing>flow_deletion_done</bpmn:outgoing>
<bpmn:startEvent id="deletion_start" />
<bpmn:serviceTask id="identify_data" name="Daten identifizieren" camunda:delegateExpression="${identifyUserDataDelegate}" />
<bpmn:userTask id="approve_deletion" name="Loeschung genehmigen" camunda:candidateGroups="data_protection_officer">
<bpmn:extensionElements>
<camunda:formData>
<camunda:formField id="legalRetention" label="Aufbewahrungspflicht?" type="boolean" />
<camunda:formField id="deletionApproved" label="Loeschung genehmigt" type="boolean" />
</camunda:formData>
</bpmn:extensionElements>
</bpmn:userTask>
<bpmn:serviceTask id="execute_deletion" name="Daten loeschen" camunda:delegateExpression="${executeDataDeletionDelegate}" />
<bpmn:serviceTask id="verify_deletion" name="Loeschung verifizieren" camunda:delegateExpression="${verifyDeletionDelegate}" />
<bpmn:endEvent id="deletion_end" />
</bpmn:subProcess>
<!-- Sub-Process: Daten-Portabilitaet (Art. 20) -->
<bpmn:subProcess id="portability_subprocess" name="Daten-Portabilitaet (Art. 20)">
<bpmn:incoming>flow_portability</bpmn:incoming>
<bpmn:outgoing>flow_portability_done</bpmn:outgoing>
<bpmn:startEvent id="portability_start" />
<bpmn:serviceTask id="collect_portable_data" name="Portable Daten sammeln" camunda:delegateExpression="${collectPortableDataDelegate}" />
<bpmn:serviceTask id="format_data" name="Daten formatieren (JSON)" camunda:delegateExpression="${formatPortableDataDelegate}" />
<bpmn:endEvent id="portability_end" />
</bpmn:subProcess>
<!-- Sub-Process: Berichtigung (Art. 16) -->
<bpmn:subProcess id="rectification_subprocess" name="Berichtigung (Art. 16)">
<bpmn:incoming>flow_rectification</bpmn:incoming>
<bpmn:outgoing>flow_rectification_done</bpmn:outgoing>
<bpmn:startEvent id="rectification_start" />
<bpmn:userTask id="review_rectification" name="Berichtigung pruefen" camunda:candidateGroups="data_protection_officer" />
<bpmn:serviceTask id="apply_rectification" name="Daten berichtigen" camunda:delegateExpression="${applyRectificationDelegate}" />
<bpmn:endEvent id="rectification_end" />
</bpmn:subProcess>
<!-- Gateway: Zusammenfuehrung -->
<bpmn:exclusiveGateway id="merge_gateway">
<bpmn:incoming>flow_access_done</bpmn:incoming>
<bpmn:incoming>flow_deletion_done</bpmn:incoming>
<bpmn:incoming>flow_portability_done</bpmn:incoming>
<bpmn:incoming>flow_rectification_done</bpmn:incoming>
<bpmn:outgoing>flow_to_notify</bpmn:outgoing>
</bpmn:exclusiveGateway>
<!-- Service Task: Betroffenen benachrichtigen -->
<bpmn:serviceTask id="notify_subject" name="Betroffenen benachrichtigen" camunda:delegateExpression="${notifyDataSubjectDelegate}">
<bpmn:incoming>flow_to_notify</bpmn:incoming>
<bpmn:outgoing>flow_to_audit</bpmn:outgoing>
</bpmn:serviceTask>
<!-- Service Task: Audit Log -->
<bpmn:serviceTask id="create_audit" name="Audit Log erstellen" camunda:delegateExpression="${createAuditLogDelegate}">
<bpmn:incoming>flow_to_audit</bpmn:incoming>
<bpmn:outgoing>flow_to_end</bpmn:outgoing>
</bpmn:serviceTask>
<!-- End Event -->
<bpmn:endEvent id="end" name="DSR abgeschlossen">
<bpmn:incoming>flow_to_end</bpmn:incoming>
</bpmn:endEvent>
<!-- Boundary Timer: 30-Tage GDPR Frist -->
<bpmn:boundaryEvent id="gdpr_deadline" attachedToRef="access_subprocess" cancelActivity="false">
<bpmn:timerEventDefinition>
<bpmn:timeDuration>P25D</bpmn:timeDuration>
</bpmn:timerEventDefinition>
<bpmn:outgoing>flow_deadline_escalation</bpmn:outgoing>
</bpmn:boundaryEvent>
<!-- Service Task: Eskalation an DSB -->
<bpmn:serviceTask id="escalate_dsb" name="Eskalation an DSB" camunda:delegateExpression="${escalateToDSBDelegate}">
<bpmn:incoming>flow_deadline_escalation</bpmn:incoming>
</bpmn:serviceTask>
<!-- Sequence Flows -->
<bpmn:sequenceFlow id="flow_to_validate" sourceRef="start" targetRef="validate_request" />
<bpmn:sequenceFlow id="flow_to_validation_gateway" sourceRef="validate_request" targetRef="validation_gateway" />
<bpmn:sequenceFlow id="flow_valid" sourceRef="validation_gateway" targetRef="type_gateway">
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${valid == true}</bpmn:conditionExpression>
</bpmn:sequenceFlow>
<bpmn:sequenceFlow id="flow_invalid" sourceRef="validation_gateway" targetRef="reject_request">
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${valid == false}</bpmn:conditionExpression>
</bpmn:sequenceFlow>
<bpmn:sequenceFlow id="flow_to_reject_end" sourceRef="reject_request" targetRef="end_rejected" />
<bpmn:sequenceFlow id="flow_access" sourceRef="type_gateway" targetRef="access_subprocess">
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${requestType == 'access'}</bpmn:conditionExpression>
</bpmn:sequenceFlow>
<bpmn:sequenceFlow id="flow_deletion" sourceRef="type_gateway" targetRef="deletion_subprocess">
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${requestType == 'deletion'}</bpmn:conditionExpression>
</bpmn:sequenceFlow>
<bpmn:sequenceFlow id="flow_portability" sourceRef="type_gateway" targetRef="portability_subprocess">
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${requestType == 'portability'}</bpmn:conditionExpression>
</bpmn:sequenceFlow>
<bpmn:sequenceFlow id="flow_rectification" sourceRef="type_gateway" targetRef="rectification_subprocess">
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${requestType == 'rectification'}</bpmn:conditionExpression>
</bpmn:sequenceFlow>
<bpmn:sequenceFlow id="flow_access_done" sourceRef="access_subprocess" targetRef="merge_gateway" />
<bpmn:sequenceFlow id="flow_deletion_done" sourceRef="deletion_subprocess" targetRef="merge_gateway" />
<bpmn:sequenceFlow id="flow_portability_done" sourceRef="portability_subprocess" targetRef="merge_gateway" />
<bpmn:sequenceFlow id="flow_rectification_done" sourceRef="rectification_subprocess" targetRef="merge_gateway" />
<bpmn:sequenceFlow id="flow_to_notify" sourceRef="merge_gateway" targetRef="notify_subject" />
<bpmn:sequenceFlow id="flow_to_audit" sourceRef="notify_subject" targetRef="create_audit" />
<bpmn:sequenceFlow id="flow_to_end" sourceRef="create_audit" targetRef="end" />
<bpmn:sequenceFlow id="flow_deadline_escalation" sourceRef="gdpr_deadline" targetRef="escalate_dsb" />
</bpmn:process>
<!-- BPMN Diagram -->
<bpmndi:BPMNDiagram id="BPMNDiagram_1">
<bpmndi:BPMNPlane id="BPMNPlane_1" bpmnElement="DSRRequestProcess">
<bpmndi:BPMNShape id="start_di" bpmnElement="start">
<dc:Bounds x="152" y="252" width="36" height="36" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="validate_request_di" bpmnElement="validate_request">
<dc:Bounds x="240" y="230" width="100" height="80" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="validation_gateway_di" bpmnElement="validation_gateway" isMarkerVisible="true">
<dc:Bounds x="395" y="245" width="50" height="50" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="type_gateway_di" bpmnElement="type_gateway" isMarkerVisible="true">
<dc:Bounds x="545" y="245" width="50" height="50" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="end_di" bpmnElement="end">
<dc:Bounds x="1502" y="252" width="36" height="36" />
</bpmndi:BPMNShape>
</bpmndi:BPMNPlane>
</bpmndi:BPMNDiagram>
</bpmn:definitions>
+215
View File
@@ -0,0 +1,215 @@
<?xml version="1.0" encoding="UTF-8"?>
<bpmn:definitions xmlns:bpmn="http://www.omg.org/spec/BPMN/20100524/MODEL"
xmlns:bpmndi="http://www.omg.org/spec/BPMN/20100524/DI"
xmlns:dc="http://www.omg.org/spec/DD/20100524/DC"
xmlns:di="http://www.omg.org/spec/DD/20100524/DI"
xmlns:camunda="http://camunda.org/schema/1.0/bpmn"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
id="Definitions_Klausur"
targetNamespace="http://breakpilot.de/bpmn/klausur">
<bpmn:process id="KlausurKorrekturProcess" name="Klausurkorrektur Workflow" isExecutable="true">
<!-- Start Event -->
<bpmn:startEvent id="start" name="Klausuren hochgeladen">
<bpmn:outgoing>flow_to_ocr</bpmn:outgoing>
</bpmn:startEvent>
<!-- Service Task: OCR Verarbeitung -->
<bpmn:serviceTask id="ocr_processing" name="OCR Verarbeitung" camunda:delegateExpression="${ocrProcessingDelegate}">
<bpmn:incoming>flow_to_ocr</bpmn:incoming>
<bpmn:outgoing>flow_to_quality_check</bpmn:outgoing>
</bpmn:serviceTask>
<!-- Service Task: Qualitaets-Check -->
<bpmn:serviceTask id="quality_check" name="OCR Qualitaets-Check" camunda:delegateExpression="${ocrQualityCheckDelegate}">
<bpmn:incoming>flow_to_quality_check</bpmn:incoming>
<bpmn:outgoing>flow_to_quality_gateway</bpmn:outgoing>
</bpmn:serviceTask>
<!-- Gateway: OCR Qualitaet ausreichend? -->
<bpmn:exclusiveGateway id="quality_gateway" name="OCR Qualitaet OK?">
<bpmn:incoming>flow_to_quality_gateway</bpmn:incoming>
<bpmn:outgoing>flow_quality_ok</bpmn:outgoing>
<bpmn:outgoing>flow_quality_bad</bpmn:outgoing>
</bpmn:exclusiveGateway>
<!-- User Task: Manuelle Nachbearbeitung -->
<bpmn:userTask id="manual_ocr_fix" name="Manuelle OCR-Korrektur" camunda:assignee="${teacherId}">
<bpmn:extensionElements>
<camunda:formData>
<camunda:formField id="correctedText" label="Korrigierter Text" type="string" />
</camunda:formData>
</bpmn:extensionElements>
<bpmn:incoming>flow_quality_bad</bpmn:incoming>
<bpmn:outgoing>flow_from_manual_fix</bpmn:outgoing>
</bpmn:userTask>
<!-- User Task: Erwartungshorizont definieren -->
<bpmn:userTask id="define_expectations" name="Erwartungshorizont" camunda:assignee="${teacherId}">
<bpmn:extensionElements>
<camunda:formData>
<camunda:formField id="criteria" label="Bewertungskriterien" type="string" />
<camunda:formField id="maxPoints" label="Maximale Punktzahl" type="long" />
<camunda:formField id="useTemplate" label="Template verwenden" type="boolean" />
</camunda:formData>
</bpmn:extensionElements>
<bpmn:incoming>flow_quality_ok</bpmn:incoming>
<bpmn:incoming>flow_from_manual_fix</bpmn:incoming>
<bpmn:outgoing>flow_to_ai_grading</bpmn:outgoing>
</bpmn:userTask>
<!-- Service Task: AI Bewertung -->
<bpmn:serviceTask id="ai_grading" name="AI-Bewertung (Claude)" camunda:delegateExpression="${aiGradingDelegate}">
<bpmn:incoming>flow_to_ai_grading</bpmn:incoming>
<bpmn:outgoing>flow_to_teacher_review</bpmn:outgoing>
</bpmn:serviceTask>
<!-- User Task: Lehrer-Review -->
<bpmn:userTask id="teacher_review" name="Lehrer-Review" camunda:assignee="${teacherId}">
<bpmn:extensionElements>
<camunda:formData>
<camunda:formField id="adjustedGrade" label="Angepasste Bewertung" type="long" />
<camunda:formField id="comments" label="Kommentare" type="string" />
<camunda:formField id="approved" label="Bewertung final" type="boolean" />
</camunda:formData>
</bpmn:extensionElements>
<bpmn:incoming>flow_to_teacher_review</bpmn:incoming>
<bpmn:outgoing>flow_to_review_gateway</bpmn:outgoing>
</bpmn:userTask>
<!-- Gateway: Review abgeschlossen? -->
<bpmn:exclusiveGateway id="review_gateway" name="Review OK?">
<bpmn:incoming>flow_to_review_gateway</bpmn:incoming>
<bpmn:outgoing>flow_review_ok</bpmn:outgoing>
<bpmn:outgoing>flow_review_adjust</bpmn:outgoing>
</bpmn:exclusiveGateway>
<!-- Service Task: Noten berechnen -->
<bpmn:serviceTask id="calculate_grades" name="Noten berechnen" camunda:delegateExpression="${calculateGradesDelegate}">
<bpmn:incoming>flow_review_ok</bpmn:incoming>
<bpmn:outgoing>flow_to_gradebook</bpmn:outgoing>
</bpmn:serviceTask>
<!-- Service Task: In Notenbuch uebertragen -->
<bpmn:serviceTask id="update_gradebook" name="Notenbuch aktualisieren" camunda:delegateExpression="${updateGradebookDelegate}">
<bpmn:incoming>flow_to_gradebook</bpmn:incoming>
<bpmn:outgoing>flow_to_export</bpmn:outgoing>
</bpmn:serviceTask>
<!-- Service Task: Export generieren -->
<bpmn:serviceTask id="generate_export" name="Export generieren" camunda:delegateExpression="${generateExportDelegate}">
<bpmn:incoming>flow_to_export</bpmn:incoming>
<bpmn:outgoing>flow_to_notify_gateway</bpmn:outgoing>
</bpmn:serviceTask>
<!-- Gateway: Eltern benachrichtigen? -->
<bpmn:exclusiveGateway id="notify_gateway" name="Eltern benachrichtigen?">
<bpmn:incoming>flow_to_notify_gateway</bpmn:incoming>
<bpmn:outgoing>flow_notify_yes</bpmn:outgoing>
<bpmn:outgoing>flow_notify_no</bpmn:outgoing>
</bpmn:exclusiveGateway>
<!-- Service Task: Eltern benachrichtigen -->
<bpmn:serviceTask id="notify_parents" name="Eltern benachrichtigen" camunda:delegateExpression="${notifyParentsDelegate}">
<bpmn:incoming>flow_notify_yes</bpmn:incoming>
<bpmn:outgoing>flow_from_notify</bpmn:outgoing>
</bpmn:serviceTask>
<!-- Service Task: Archivieren -->
<bpmn:serviceTask id="archive_exam" name="Klausur archivieren" camunda:delegateExpression="${archiveExamDelegate}">
<bpmn:incoming>flow_notify_no</bpmn:incoming>
<bpmn:incoming>flow_from_notify</bpmn:incoming>
<bpmn:outgoing>flow_to_end</bpmn:outgoing>
</bpmn:serviceTask>
<!-- End Event -->
<bpmn:endEvent id="end" name="Korrektur abgeschlossen">
<bpmn:incoming>flow_to_end</bpmn:incoming>
</bpmn:endEvent>
<!-- Boundary Timer: Korrektur-Deadline -->
<bpmn:boundaryEvent id="correction_deadline" attachedToRef="teacher_review" cancelActivity="false">
<bpmn:timerEventDefinition>
<bpmn:timeDuration>P${correctionDeadlineDays}D</bpmn:timeDuration>
</bpmn:timerEventDefinition>
<bpmn:outgoing>flow_deadline_warning</bpmn:outgoing>
</bpmn:boundaryEvent>
<!-- Service Task: Deadline-Warnung -->
<bpmn:serviceTask id="deadline_warning" name="Deadline-Warnung" camunda:delegateExpression="${deadlineWarningDelegate}">
<bpmn:incoming>flow_deadline_warning</bpmn:incoming>
</bpmn:serviceTask>
<!-- Sequence Flows -->
<bpmn:sequenceFlow id="flow_to_ocr" sourceRef="start" targetRef="ocr_processing" />
<bpmn:sequenceFlow id="flow_to_quality_check" sourceRef="ocr_processing" targetRef="quality_check" />
<bpmn:sequenceFlow id="flow_to_quality_gateway" sourceRef="quality_check" targetRef="quality_gateway" />
<bpmn:sequenceFlow id="flow_quality_ok" sourceRef="quality_gateway" targetRef="define_expectations">
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${ocrConfidence >= 0.85}</bpmn:conditionExpression>
</bpmn:sequenceFlow>
<bpmn:sequenceFlow id="flow_quality_bad" sourceRef="quality_gateway" targetRef="manual_ocr_fix">
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${ocrConfidence &lt; 0.85}</bpmn:conditionExpression>
</bpmn:sequenceFlow>
<bpmn:sequenceFlow id="flow_from_manual_fix" sourceRef="manual_ocr_fix" targetRef="define_expectations" />
<bpmn:sequenceFlow id="flow_to_ai_grading" sourceRef="define_expectations" targetRef="ai_grading" />
<bpmn:sequenceFlow id="flow_to_teacher_review" sourceRef="ai_grading" targetRef="teacher_review" />
<bpmn:sequenceFlow id="flow_to_review_gateway" sourceRef="teacher_review" targetRef="review_gateway" />
<bpmn:sequenceFlow id="flow_review_ok" sourceRef="review_gateway" targetRef="calculate_grades">
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${approved == true}</bpmn:conditionExpression>
</bpmn:sequenceFlow>
<bpmn:sequenceFlow id="flow_review_adjust" sourceRef="review_gateway" targetRef="ai_grading">
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${approved == false}</bpmn:conditionExpression>
</bpmn:sequenceFlow>
<bpmn:sequenceFlow id="flow_to_gradebook" sourceRef="calculate_grades" targetRef="update_gradebook" />
<bpmn:sequenceFlow id="flow_to_export" sourceRef="update_gradebook" targetRef="generate_export" />
<bpmn:sequenceFlow id="flow_to_notify_gateway" sourceRef="generate_export" targetRef="notify_gateway" />
<bpmn:sequenceFlow id="flow_notify_yes" sourceRef="notify_gateway" targetRef="notify_parents">
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${notifyParents == true}</bpmn:conditionExpression>
</bpmn:sequenceFlow>
<bpmn:sequenceFlow id="flow_notify_no" sourceRef="notify_gateway" targetRef="archive_exam">
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${notifyParents == false}</bpmn:conditionExpression>
</bpmn:sequenceFlow>
<bpmn:sequenceFlow id="flow_from_notify" sourceRef="notify_parents" targetRef="archive_exam" />
<bpmn:sequenceFlow id="flow_to_end" sourceRef="archive_exam" targetRef="end" />
<bpmn:sequenceFlow id="flow_deadline_warning" sourceRef="correction_deadline" targetRef="deadline_warning" />
</bpmn:process>
<!-- BPMN Diagram -->
<bpmndi:BPMNDiagram id="BPMNDiagram_1">
<bpmndi:BPMNPlane id="BPMNPlane_1" bpmnElement="KlausurKorrekturProcess">
<bpmndi:BPMNShape id="start_di" bpmnElement="start">
<dc:Bounds x="152" y="102" width="36" height="36" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="ocr_processing_di" bpmnElement="ocr_processing">
<dc:Bounds x="240" y="80" width="100" height="80" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="quality_check_di" bpmnElement="quality_check">
<dc:Bounds x="390" y="80" width="100" height="80" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="quality_gateway_di" bpmnElement="quality_gateway" isMarkerVisible="true">
<dc:Bounds x="545" y="95" width="50" height="50" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="manual_ocr_fix_di" bpmnElement="manual_ocr_fix">
<dc:Bounds x="520" y="200" width="100" height="80" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="define_expectations_di" bpmnElement="define_expectations">
<dc:Bounds x="650" y="80" width="100" height="80" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="ai_grading_di" bpmnElement="ai_grading">
<dc:Bounds x="800" y="80" width="100" height="80" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="teacher_review_di" bpmnElement="teacher_review">
<dc:Bounds x="950" y="80" width="100" height="80" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="calculate_grades_di" bpmnElement="calculate_grades">
<dc:Bounds x="1100" y="80" width="100" height="80" />
</bpmndi:BPMNShape>
<bpmndi:BPMNShape id="end_di" bpmnElement="end">
<dc:Bounds x="1702" y="102" width="36" height="36" />
</bpmndi:BPMNShape>
</bpmndi:BPMNPlane>
</bpmndi:BPMNDiagram>
</bpmn:definitions>
+48
View File
@@ -0,0 +1,48 @@
# Binaries
*.exe
*.exe~
*.dll
*.so
*.dylib
server
# Test binary
*.test
# Output of go coverage tool
*.out
# Go workspace file
go.work
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
# Local config
.env
.env.local
*.local
# Logs
*.log
logs/
# Temp files
*.tmp
*.temp
.DS_Store
# Git
.git/
.gitignore
# Docker
Dockerfile
docker-compose*.yml
# Vendor (if using)
vendor/
+21
View File
@@ -0,0 +1,21 @@
# Server Configuration
PORT=8081
ENVIRONMENT=development
# Database Configuration
# PostgreSQL connection string
DATABASE_URL=postgres://user:password@localhost:5432/consent_db?sslmode=disable
# JWT Configuration (should match BreakPilot's JWT secret for token validation)
JWT_SECRET=your-jwt-secret-here
JWT_REFRESH_SECRET=your-refresh-secret-here
# CORS Configuration
ALLOWED_ORIGINS=http://localhost:3000,http://localhost:8000,https://breakpilot.app
# Rate Limiting
RATE_LIMIT_REQUESTS=100
RATE_LIMIT_WINDOW=60
# BreakPilot Integration
BREAKPILOT_API_URL=http://localhost:8000
+42
View File
@@ -0,0 +1,42 @@
# Build stage
FROM golang:1.23-alpine AS builder
WORKDIR /app
# Install build dependencies
RUN apk add --no-cache git ca-certificates
# Copy go mod files
COPY go.mod go.sum ./
RUN go mod download
# Copy source code
COPY . .
# Build the binary
RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo -o consent-service ./cmd/server
# Runtime stage
FROM alpine:3.19
WORKDIR /app
# Install runtime dependencies
RUN apk --no-cache add ca-certificates tzdata
# Copy binary from builder
COPY --from=builder /app/consent-service .
# Create non-root user
RUN adduser -D -g '' appuser
USER appuser
# Expose port
EXPOSE 8081
# Health check
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD wget --no-verbose --tries=1 --spider http://localhost:8081/health || exit 1
# Run the binary
CMD ["./consent-service"]
+471
View File
@@ -0,0 +1,471 @@
package main
import (
"context"
"fmt"
"log"
"time"
"github.com/breakpilot/consent-service/internal/config"
"github.com/breakpilot/consent-service/internal/database"
"github.com/breakpilot/consent-service/internal/handlers"
"github.com/breakpilot/consent-service/internal/middleware"
"github.com/breakpilot/consent-service/internal/services"
"github.com/breakpilot/consent-service/internal/services/jitsi"
"github.com/breakpilot/consent-service/internal/services/matrix"
"github.com/gin-gonic/gin"
)
func main() {
// Load configuration
cfg, err := config.Load()
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
// Initialize database
db, err := database.Connect(cfg.DatabaseURL)
if err != nil {
log.Fatalf("Failed to connect to database: %v", err)
}
defer db.Close()
// Run migrations
if err := database.Migrate(db); err != nil {
log.Fatalf("Failed to run migrations: %v", err)
}
// Setup Gin router
if cfg.Environment == "production" {
gin.SetMode(gin.ReleaseMode)
}
router := gin.Default()
// Global middleware
router.Use(middleware.CORS())
router.Use(middleware.RequestLogger())
router.Use(middleware.RateLimiter())
// Health check
router.GET("/health", func(c *gin.Context) {
c.JSON(200, gin.H{
"status": "healthy",
"service": "consent-service",
"version": "1.0.0",
})
})
// Initialize services
authService := services.NewAuthService(db.Pool, cfg.JWTSecret, cfg.JWTRefreshSecret)
oauthService := services.NewOAuthService(db.Pool, cfg.JWTSecret)
totpService := services.NewTOTPService(db.Pool, "BreakPilot")
emailService := services.NewEmailService(services.EmailConfig{
Host: cfg.SMTPHost,
Port: cfg.SMTPPort,
Username: cfg.SMTPUsername,
Password: cfg.SMTPPassword,
FromName: cfg.SMTPFromName,
FromAddr: cfg.SMTPFromAddr,
BaseURL: cfg.FrontendURL,
})
notificationService := services.NewNotificationService(db.Pool, emailService)
deadlineService := services.NewDeadlineService(db.Pool, notificationService)
emailTemplateService := services.NewEmailTemplateService(db.Pool)
dsrService := services.NewDSRService(db.Pool, notificationService, emailService)
// Initialize handlers
h := handlers.New(db)
authHandler := handlers.NewAuthHandler(authService, emailService)
oauthHandler := handlers.NewOAuthHandler(oauthService, totpService, authService)
notificationHandler := handlers.NewNotificationHandler(notificationService)
deadlineHandler := handlers.NewDeadlineHandler(deadlineService)
emailTemplateHandler := handlers.NewEmailTemplateHandler(emailTemplateService)
dsrHandler := handlers.NewDSRHandler(dsrService)
// Initialize Matrix service (if enabled)
var matrixService *matrix.MatrixService
if cfg.MatrixEnabled && cfg.MatrixAccessToken != "" {
matrixService = matrix.NewMatrixService(matrix.Config{
HomeserverURL: cfg.MatrixHomeserverURL,
AccessToken: cfg.MatrixAccessToken,
ServerName: cfg.MatrixServerName,
})
log.Println("Matrix service initialized")
} else {
log.Println("Matrix service disabled or not configured")
}
// Initialize Jitsi service (if enabled)
var jitsiService *jitsi.JitsiService
if cfg.JitsiEnabled {
jitsiService = jitsi.NewJitsiService(jitsi.Config{
BaseURL: cfg.JitsiBaseURL,
AppID: cfg.JitsiAppID,
AppSecret: cfg.JitsiAppSecret,
})
log.Println("Jitsi service initialized")
} else {
log.Println("Jitsi service disabled")
}
// Initialize communication handlers
communicationHandler := handlers.NewCommunicationHandlers(matrixService, jitsiService)
// Initialize default email templates (runs only once)
if err := emailTemplateService.InitDefaultTemplates(context.Background()); err != nil {
log.Printf("Warning: Failed to initialize default email templates: %v", err)
}
// API v1 routes
v1 := router.Group("/api/v1")
{
// =============================================
// OAuth 2.0 Endpoints (RFC 6749)
// =============================================
oauth := v1.Group("/oauth")
{
// Authorization endpoint (requires user auth for consent)
oauth.GET("/authorize", middleware.AuthMiddleware(cfg.JWTSecret), oauthHandler.Authorize)
// Token endpoint (public)
oauth.POST("/token", oauthHandler.Token)
// Revocation endpoint (RFC 7009)
oauth.POST("/revoke", oauthHandler.Revoke)
// Introspection endpoint (RFC 7662)
oauth.POST("/introspect", oauthHandler.Introspect)
}
// =============================================
// Authentication Routes (with 2FA support)
// =============================================
auth := v1.Group("/auth")
{
// Registration with mandatory 2FA setup
auth.POST("/register", oauthHandler.RegisterWith2FA)
// Login with 2FA support
auth.POST("/login", oauthHandler.LoginWith2FA)
// 2FA challenge verification (during login)
auth.POST("/2fa/verify", oauthHandler.Verify2FAChallenge)
// Legacy endpoints (kept for compatibility)
auth.POST("/logout", authHandler.Logout)
auth.POST("/refresh", authHandler.RefreshToken)
auth.POST("/verify-email", authHandler.VerifyEmail)
auth.POST("/resend-verification", authHandler.ResendVerification)
auth.POST("/forgot-password", authHandler.ForgotPassword)
auth.POST("/reset-password", authHandler.ResetPassword)
}
// =============================================
// 2FA Management Routes (require auth)
// =============================================
twoFA := v1.Group("/auth/2fa")
twoFA.Use(middleware.AuthMiddleware(cfg.JWTSecret))
{
twoFA.GET("/status", oauthHandler.Get2FAStatus)
twoFA.POST("/setup", oauthHandler.Setup2FA)
twoFA.POST("/verify-setup", oauthHandler.Verify2FASetup)
twoFA.POST("/disable", oauthHandler.Disable2FA)
twoFA.POST("/recovery-codes", oauthHandler.RegenerateRecoveryCodes)
}
// =============================================
// Profile Routes (require auth)
// =============================================
profile := v1.Group("/profile")
profile.Use(middleware.AuthMiddleware(cfg.JWTSecret))
{
profile.GET("", authHandler.GetProfile)
profile.PUT("", authHandler.UpdateProfile)
profile.PUT("/password", authHandler.ChangePassword)
profile.GET("/sessions", authHandler.GetActiveSessions)
profile.DELETE("/sessions/:id", authHandler.RevokeSession)
}
// =============================================
// Public consent routes (require user auth)
// =============================================
public := v1.Group("")
public.Use(middleware.AuthMiddleware(cfg.JWTSecret))
{
// Documents
public.GET("/documents", h.GetDocuments)
public.GET("/documents/:type", h.GetDocumentByType)
public.GET("/documents/:type/latest", h.GetLatestDocumentVersion)
// User Consent
public.POST("/consent", h.CreateConsent)
public.GET("/consent/my", h.GetMyConsents)
public.GET("/consent/check/:documentType", h.CheckConsent)
public.DELETE("/consent/:id", h.WithdrawConsent)
// Cookie Consent
public.GET("/cookies/categories", h.GetCookieCategories)
public.POST("/cookies/consent", h.SetCookieConsent)
public.GET("/cookies/consent/my", h.GetMyCookieConsent)
// GDPR / Data Subject Rights
public.GET("/privacy/my-data", h.GetMyData)
public.POST("/privacy/export", h.RequestDataExport)
public.POST("/privacy/delete", h.RequestDataDeletion)
// Data Subject Requests (User-facing)
public.POST("/dsr", dsrHandler.CreateDSR)
public.GET("/dsr", dsrHandler.GetMyDSRs)
public.GET("/dsr/:id", dsrHandler.GetMyDSR)
public.POST("/dsr/:id/cancel", dsrHandler.CancelMyDSR)
// Notifications
public.GET("/notifications", notificationHandler.GetNotifications)
public.GET("/notifications/unread-count", notificationHandler.GetUnreadCount)
public.PUT("/notifications/:id/read", notificationHandler.MarkAsRead)
public.PUT("/notifications/read-all", notificationHandler.MarkAllAsRead)
public.DELETE("/notifications/:id", notificationHandler.DeleteNotification)
public.GET("/notifications/preferences", notificationHandler.GetPreferences)
public.PUT("/notifications/preferences", notificationHandler.UpdatePreferences)
// Consent Deadlines & Suspension Status
public.GET("/consent/deadlines", deadlineHandler.GetPendingDeadlines)
public.GET("/account/suspension-status", deadlineHandler.GetSuspensionStatus)
}
// Admin routes (require admin auth)
admin := v1.Group("/admin")
admin.Use(middleware.AuthMiddleware(cfg.JWTSecret))
admin.Use(middleware.AdminOnly())
{
// Document Management
admin.GET("/documents", h.AdminGetDocuments)
admin.POST("/documents", h.AdminCreateDocument)
admin.PUT("/documents/:id", h.AdminUpdateDocument)
admin.DELETE("/documents/:id", h.AdminDeleteDocument)
admin.GET("/documents/:docId/versions", h.AdminGetVersions)
// Version Management
admin.POST("/versions", h.AdminCreateVersion)
admin.PUT("/versions/:id", h.AdminUpdateVersion)
admin.DELETE("/versions/:id", h.AdminDeleteVersion)
admin.POST("/versions/:id/archive", h.AdminArchiveVersion)
admin.POST("/versions/:id/submit-review", h.AdminSubmitForReview)
admin.POST("/versions/:id/approve", h.AdminApproveVersion)
admin.POST("/versions/:id/reject", h.AdminRejectVersion)
admin.GET("/versions/:id/compare", h.AdminCompareVersions)
admin.GET("/versions/:id/approval-history", h.AdminGetApprovalHistory)
// Publishing (DSB role recommended but Admin can also do it in dev)
admin.POST("/versions/:id/publish", h.AdminPublishVersion)
// Cookie Categories
admin.GET("/cookies/categories", h.AdminGetCookieCategories)
admin.POST("/cookies/categories", h.AdminCreateCookieCategory)
admin.PUT("/cookies/categories/:id", h.AdminUpdateCookieCategory)
admin.DELETE("/cookies/categories/:id", h.AdminDeleteCookieCategory)
// Statistics & Audit
admin.GET("/stats/consents", h.GetConsentStats)
admin.GET("/stats/cookies", h.GetCookieStats)
admin.GET("/audit-log", h.GetAuditLog)
// Deadline Management (for testing/manual trigger)
admin.POST("/deadlines/process", deadlineHandler.TriggerDeadlineProcessing)
// Scheduled Publishing
admin.GET("/scheduled-versions", h.GetScheduledVersions)
admin.POST("/scheduled-publishing/process", h.ProcessScheduledPublishing)
// OAuth Client Management
admin.GET("/oauth/clients", oauthHandler.AdminGetClients)
admin.POST("/oauth/clients", oauthHandler.AdminCreateClient)
// =============================================
// E-Mail Template Management
// =============================================
admin.GET("/email-templates/types", emailTemplateHandler.GetAllTemplateTypes)
admin.GET("/email-templates", emailTemplateHandler.GetAllTemplates)
admin.GET("/email-templates/settings", emailTemplateHandler.GetSettings)
admin.PUT("/email-templates/settings", emailTemplateHandler.UpdateSettings)
admin.GET("/email-templates/stats", emailTemplateHandler.GetEmailStats)
admin.GET("/email-templates/logs", emailTemplateHandler.GetSendLogs)
admin.GET("/email-templates/default/:type", emailTemplateHandler.GetDefaultContent)
admin.POST("/email-templates/initialize", emailTemplateHandler.InitializeTemplates)
admin.GET("/email-templates/:id", emailTemplateHandler.GetTemplate)
admin.POST("/email-templates", emailTemplateHandler.CreateTemplate)
admin.GET("/email-templates/:id/versions", emailTemplateHandler.GetTemplateVersions)
// E-Mail Template Versions
admin.GET("/email-template-versions/:id", emailTemplateHandler.GetVersion)
admin.POST("/email-template-versions", emailTemplateHandler.CreateVersion)
admin.PUT("/email-template-versions/:id", emailTemplateHandler.UpdateVersion)
admin.POST("/email-template-versions/:id/submit", emailTemplateHandler.SubmitForReview)
admin.POST("/email-template-versions/:id/approve", emailTemplateHandler.ApproveVersion)
admin.POST("/email-template-versions/:id/reject", emailTemplateHandler.RejectVersion)
admin.POST("/email-template-versions/:id/publish", emailTemplateHandler.PublishVersion)
admin.GET("/email-template-versions/:id/approvals", emailTemplateHandler.GetApprovals)
admin.POST("/email-template-versions/:id/preview", emailTemplateHandler.PreviewVersion)
admin.POST("/email-template-versions/:id/send-test", emailTemplateHandler.SendTestEmail)
// =============================================
// Data Subject Requests (DSR) Management
// =============================================
admin.GET("/dsr", dsrHandler.AdminListDSR)
admin.GET("/dsr/stats", dsrHandler.AdminGetDSRStats)
admin.POST("/dsr", dsrHandler.AdminCreateDSR)
admin.GET("/dsr/:id", dsrHandler.AdminGetDSR)
admin.PUT("/dsr/:id", dsrHandler.AdminUpdateDSR)
admin.POST("/dsr/:id/status", dsrHandler.AdminUpdateDSRStatus)
admin.POST("/dsr/:id/verify-identity", dsrHandler.AdminVerifyIdentity)
admin.POST("/dsr/:id/assign", dsrHandler.AdminAssignDSR)
admin.POST("/dsr/:id/extend", dsrHandler.AdminExtendDSRDeadline)
admin.POST("/dsr/:id/complete", dsrHandler.AdminCompleteDSR)
admin.POST("/dsr/:id/reject", dsrHandler.AdminRejectDSR)
admin.GET("/dsr/:id/history", dsrHandler.AdminGetDSRHistory)
admin.GET("/dsr/:id/communications", dsrHandler.AdminGetDSRCommunications)
admin.POST("/dsr/:id/communicate", dsrHandler.AdminSendDSRCommunication)
admin.GET("/dsr/:id/exception-checks", dsrHandler.AdminGetExceptionChecks)
admin.POST("/dsr/:id/exception-checks/init", dsrHandler.AdminInitExceptionChecks)
admin.PUT("/dsr/:id/exception-checks/:checkId", dsrHandler.AdminUpdateExceptionCheck)
admin.POST("/dsr/deadlines/process", dsrHandler.ProcessDeadlines)
// DSR Templates
admin.GET("/dsr-templates", dsrHandler.AdminGetDSRTemplates)
admin.GET("/dsr-templates/published", dsrHandler.AdminGetPublishedDSRTemplates)
admin.GET("/dsr-templates/:id/versions", dsrHandler.AdminGetDSRTemplateVersions)
admin.POST("/dsr-templates/:id/versions", dsrHandler.AdminCreateDSRTemplateVersion)
admin.POST("/dsr-template-versions/:versionId/publish", dsrHandler.AdminPublishDSRTemplateVersion)
}
// =============================================
// Communication Routes (Matrix + Jitsi)
// =============================================
communicationHandler.RegisterRoutes(v1, cfg.JWTSecret, middleware.AuthMiddleware(cfg.JWTSecret))
// =============================================
// Cookie Banner SDK Routes (Public - Anonymous)
// =============================================
// Diese Endpoints werden vom @breakpilot/consent-sdk verwendet
// für anonyme (device-basierte) Cookie-Einwilligungen.
banner := v1.Group("/banner")
{
// Public Endpoints (keine Auth erforderlich)
banner.POST("/consent", h.CreateBannerConsent)
banner.GET("/consent", h.GetBannerConsent)
banner.DELETE("/consent/:consentId", h.RevokeBannerConsent)
banner.GET("/config/:siteId", h.GetSiteConfig)
banner.GET("/consent/export", h.ExportBannerConsent)
}
// Banner Admin Routes (require admin auth)
bannerAdmin := v1.Group("/banner/admin")
bannerAdmin.Use(middleware.AuthMiddleware(cfg.JWTSecret))
bannerAdmin.Use(middleware.AdminOnly())
{
bannerAdmin.GET("/stats/:siteId", h.GetBannerStats)
}
}
// Start background scheduler for scheduled publishing
go startScheduledPublishingWorker(db)
// Start DSR deadline monitoring worker
go startDSRDeadlineWorker(dsrService)
// Start server
port := cfg.Port
if port == "" {
port = "8080"
}
log.Printf("Starting Consent Service on port %s", port)
if err := router.Run(":" + port); err != nil {
log.Fatalf("Failed to start server: %v", err)
}
}
// startScheduledPublishingWorker runs every minute to check for scheduled versions
func startScheduledPublishingWorker(db *database.DB) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
log.Println("Scheduled publishing worker started (checking every minute)")
for range ticker.C {
processScheduledVersions(db)
}
}
func processScheduledVersions(db *database.DB) {
ctx := context.Background()
// Find all scheduled versions that are due
rows, err := db.Pool.Query(ctx, `
SELECT id, document_id, version
FROM document_versions
WHERE status = 'scheduled'
AND scheduled_publish_at IS NOT NULL
AND scheduled_publish_at <= NOW()
`)
if err != nil {
log.Printf("Scheduler: Error fetching scheduled versions: %v", err)
return
}
defer rows.Close()
var publishedCount int
for rows.Next() {
var versionID, docID string
var version string
if err := rows.Scan(&versionID, &docID, &version); err != nil {
continue
}
// Publish this version
_, err := db.Pool.Exec(ctx, `
UPDATE document_versions
SET status = 'published', published_at = NOW(), updated_at = NOW()
WHERE id = $1
`, versionID)
if err == nil {
// Archive previous published versions for this document
db.Pool.Exec(ctx, `
UPDATE document_versions
SET status = 'archived', updated_at = NOW()
WHERE document_id = $1 AND id != $2 AND status = 'published'
`, docID, versionID)
// Log the publishing
details := fmt.Sprintf("Version %s automatically published by scheduler", version)
db.Pool.Exec(ctx, `
INSERT INTO consent_audit_log (action, entity_type, entity_id, details, user_agent)
VALUES ('version_scheduled_published', 'document_version', $1, $2, 'scheduler')
`, versionID, details)
publishedCount++
log.Printf("Scheduler: Published version %s (ID: %s)", version, versionID)
}
}
if publishedCount > 0 {
log.Printf("Scheduler: Published %d version(s)", publishedCount)
}
}
// startDSRDeadlineWorker monitors DSR deadlines and sends notifications
func startDSRDeadlineWorker(dsrService *services.DSRService) {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
log.Println("DSR deadline monitoring worker started (checking every hour)")
// Run immediately on startup
ctx := context.Background()
if err := dsrService.ProcessDeadlines(ctx); err != nil {
log.Printf("DSR Worker: Error processing deadlines: %v", err)
}
for range ticker.C {
ctx := context.Background()
if err := dsrService.ProcessDeadlines(ctx); err != nil {
log.Printf("DSR Worker: Error processing deadlines: %v", err)
}
}
}
+41
View File
@@ -0,0 +1,41 @@
version: '3.8'
services:
consent-service:
build: .
ports:
- "8081:8081"
env_file:
- .env
environment:
- DATABASE_URL=postgres://consent:consent123@postgres:5432/consent_db?sslmode=disable
depends_on:
postgres:
condition: service_healthy
networks:
- consent-network
postgres:
image: postgres:16-alpine
ports:
- "5433:5432"
environment:
- POSTGRES_USER=consent
- POSTGRES_PASSWORD=consent123
- POSTGRES_DB=consent_db
volumes:
- postgres_data:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U consent -d consent_db"]
interval: 5s
timeout: 5s
retries: 5
networks:
- consent-network
volumes:
postgres_data:
networks:
consent-network:
driver: bridge
+49
View File
@@ -0,0 +1,49 @@
module github.com/breakpilot/consent-service
go 1.23.0
require (
github.com/gin-gonic/gin v1.11.0
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.7.6
github.com/joho/godotenv v1.5.1
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
golang.org/x/crypto v0.40.0
)
require (
github.com/bytedance/sonic v1.14.0 // indirect
github.com/bytedance/sonic/loader v0.3.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
github.com/gin-contrib/sse v1.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.27.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/goccy/go-yaml v1.18.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/quic-go/qpack v0.5.1 // indirect
github.com/quic-go/quic-go v0.54.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.3.0 // indirect
go.uber.org/mock v0.5.0 // indirect
golang.org/x/arch v0.20.0 // indirect
golang.org/x/mod v0.25.0 // indirect
golang.org/x/net v0.42.0 // indirect
golang.org/x/sync v0.16.0 // indirect
golang.org/x/sys v0.35.0 // indirect
golang.org/x/text v0.27.0 // indirect
golang.org/x/tools v0.34.0 // indirect
google.golang.org/protobuf v1.36.9 // indirect
)
+105
View File
@@ -0,0 +1,105 @@
github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ=
github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA=
github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA=
github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI=
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
github.com/gin-gonic/gin v1.11.0 h1:OW/6PLjyusp2PPXtyxKHU0RbX6I/l28FTdDlae5ueWk=
github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.27.0 h1:w8+XrWVMhGkxOaaowyKH35gFydVHOvC0/uWoy2Fzwn4=
github.com/go-playground/validator/v10 v10.27.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk=
github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
github.com/quic-go/quic-go v0.54.0 h1:6s1YB9QotYI6Ospeiguknbp2Znb/jZYjZLRXn9kMQBg=
github.com/quic-go/quic-go v0.54.0/go.mod h1:e68ZEaCdyviluZmy44P6Iey98v/Wfz6HCjQEm+l8zTY=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.3.0 h1:Qd2W2sQawAfG8XSvzwhBeoGq71zXOC/Q1E9y/wUcsUA=
github.com/ugorji/go/codec v1.3.0/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c=
golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk=
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=
google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw=
google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+170
View File
@@ -0,0 +1,170 @@
package config
import (
"fmt"
"os"
"github.com/joho/godotenv"
)
// Config holds all configuration for the service
type Config struct {
// Server
Port string
Environment string
// Database
DatabaseURL string
// JWT
JWTSecret string
JWTRefreshSecret string
// CORS
AllowedOrigins []string
// Rate Limiting
RateLimitRequests int
RateLimitWindow int // in seconds
// BreakPilot Integration
BreakPilotAPIURL string
FrontendURL string
// SMTP Email Configuration
SMTPHost string
SMTPPort int
SMTPUsername string
SMTPPassword string
SMTPFromName string
SMTPFromAddr string
// Consent Settings
ConsentDeadlineDays int
ConsentReminderEnabled bool
// VAPID Keys for Web Push
VAPIDPublicKey string
VAPIDPrivateKey string
// Matrix (Synapse) Configuration
MatrixHomeserverURL string
MatrixAccessToken string
MatrixServerName string
MatrixEnabled bool
// Jitsi Configuration
JitsiBaseURL string
JitsiAppID string
JitsiAppSecret string
JitsiEnabled bool
}
// Load loads configuration from environment variables
func Load() (*Config, error) {
// Load .env file if exists (for development)
_ = godotenv.Load()
cfg := &Config{
Port: getEnv("PORT", "8080"),
Environment: getEnv("ENVIRONMENT", "development"),
DatabaseURL: getEnv("DATABASE_URL", ""),
JWTSecret: getEnv("JWT_SECRET", ""),
JWTRefreshSecret: getEnv("JWT_REFRESH_SECRET", ""),
RateLimitRequests: getEnvInt("RATE_LIMIT_REQUESTS", 100),
RateLimitWindow: getEnvInt("RATE_LIMIT_WINDOW", 60),
BreakPilotAPIURL: getEnv("BREAKPILOT_API_URL", "http://localhost:8000"),
FrontendURL: getEnv("FRONTEND_URL", "http://localhost:8000"),
// SMTP Configuration
SMTPHost: getEnv("SMTP_HOST", ""),
SMTPPort: getEnvInt("SMTP_PORT", 587),
SMTPUsername: getEnv("SMTP_USERNAME", ""),
SMTPPassword: getEnv("SMTP_PASSWORD", ""),
SMTPFromName: getEnv("SMTP_FROM_NAME", "BreakPilot"),
SMTPFromAddr: getEnv("SMTP_FROM_ADDR", "noreply@breakpilot.app"),
// Consent Settings
ConsentDeadlineDays: getEnvInt("CONSENT_DEADLINE_DAYS", 30),
ConsentReminderEnabled: getEnvBool("CONSENT_REMINDER_ENABLED", true),
// VAPID Keys
VAPIDPublicKey: getEnv("VAPID_PUBLIC_KEY", ""),
VAPIDPrivateKey: getEnv("VAPID_PRIVATE_KEY", ""),
// Matrix Configuration
MatrixHomeserverURL: getEnv("MATRIX_HOMESERVER_URL", "http://synapse:8008"),
MatrixAccessToken: getEnv("MATRIX_ACCESS_TOKEN", ""),
MatrixServerName: getEnv("MATRIX_SERVER_NAME", "breakpilot.local"),
MatrixEnabled: getEnvBool("MATRIX_ENABLED", true),
// Jitsi Configuration
JitsiBaseURL: getEnv("JITSI_BASE_URL", "http://localhost:8443"),
JitsiAppID: getEnv("JITSI_APP_ID", "breakpilot"),
JitsiAppSecret: getEnv("JITSI_APP_SECRET", ""),
JitsiEnabled: getEnvBool("JITSI_ENABLED", true),
}
// Parse allowed origins
originsStr := getEnv("ALLOWED_ORIGINS", "http://localhost:3000,http://localhost:8000")
cfg.AllowedOrigins = parseCommaSeparated(originsStr)
// Validate required fields
if cfg.DatabaseURL == "" {
return nil, fmt.Errorf("DATABASE_URL is required")
}
if cfg.JWTSecret == "" {
return nil, fmt.Errorf("JWT_SECRET is required")
}
return cfg, nil
}
func getEnv(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
func getEnvInt(key string, defaultValue int) int {
if value := os.Getenv(key); value != "" {
var result int
fmt.Sscanf(value, "%d", &result)
return result
}
return defaultValue
}
func getEnvBool(key string, defaultValue bool) bool {
if value := os.Getenv(key); value != "" {
return value == "true" || value == "1" || value == "yes"
}
return defaultValue
}
func parseCommaSeparated(s string) []string {
if s == "" {
return []string{}
}
var result []string
start := 0
for i := 0; i <= len(s); i++ {
if i == len(s) || s[i] == ',' {
item := s[start:i]
// Trim whitespace
for len(item) > 0 && item[0] == ' ' {
item = item[1:]
}
for len(item) > 0 && item[len(item)-1] == ' ' {
item = item[:len(item)-1]
}
if item != "" {
result = append(result, item)
}
start = i + 1
}
}
return result
}
@@ -0,0 +1,322 @@
package config
import (
"os"
"testing"
)
// TestGetEnv tests the getEnv helper function
func TestGetEnv(t *testing.T) {
// Test with default value when env var not set
result := getEnv("TEST_NONEXISTENT_VAR_12345", "default")
if result != "default" {
t.Errorf("Expected 'default', got '%s'", result)
}
// Test with set env var
os.Setenv("TEST_ENV_VAR", "custom_value")
defer os.Unsetenv("TEST_ENV_VAR")
result = getEnv("TEST_ENV_VAR", "default")
if result != "custom_value" {
t.Errorf("Expected 'custom_value', got '%s'", result)
}
}
// TestGetEnvInt tests the getEnvInt helper function
func TestGetEnvInt(t *testing.T) {
tests := []struct {
name string
envValue string
defaultValue int
expected int
}{
{"default when not set", "", 100, 100},
{"parse valid int", "42", 0, 42},
{"parse zero", "0", 100, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.envValue != "" {
os.Setenv("TEST_INT_VAR", tt.envValue)
defer os.Unsetenv("TEST_INT_VAR")
} else {
os.Unsetenv("TEST_INT_VAR")
}
result := getEnvInt("TEST_INT_VAR", tt.defaultValue)
if result != tt.expected {
t.Errorf("Expected %d, got %d", tt.expected, result)
}
})
}
}
// TestGetEnvBool tests the getEnvBool helper function
func TestGetEnvBool(t *testing.T) {
tests := []struct {
name string
envValue string
defaultValue bool
expected bool
}{
{"default when not set", "", true, true},
{"default false when not set", "", false, false},
{"parse true", "true", false, true},
{"parse 1", "1", false, true},
{"parse yes", "yes", false, true},
{"parse false", "false", true, false},
{"parse 0", "0", true, false},
{"parse no", "no", true, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.envValue != "" {
os.Setenv("TEST_BOOL_VAR", tt.envValue)
defer os.Unsetenv("TEST_BOOL_VAR")
} else {
os.Unsetenv("TEST_BOOL_VAR")
}
result := getEnvBool("TEST_BOOL_VAR", tt.defaultValue)
if result != tt.expected {
t.Errorf("Expected %v, got %v", tt.expected, result)
}
})
}
}
// TestParseCommaSeparated tests the parseCommaSeparated helper function
func TestParseCommaSeparated(t *testing.T) {
tests := []struct {
name string
input string
expected []string
}{
{"empty string", "", []string{}},
{"single value", "value1", []string{"value1"}},
{"multiple values", "value1,value2,value3", []string{"value1", "value2", "value3"}},
{"with spaces", "value1, value2, value3", []string{"value1", "value2", "value3"}},
{"with trailing comma", "value1,value2,", []string{"value1", "value2"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := parseCommaSeparated(tt.input)
if len(result) != len(tt.expected) {
t.Errorf("Expected length %d, got %d", len(tt.expected), len(result))
return
}
for i := range result {
if result[i] != tt.expected[i] {
t.Errorf("At index %d: expected '%s', got '%s'", i, tt.expected[i], result[i])
}
}
})
}
}
// TestConfigEnvironmentDefaults tests default environment values
func TestConfigEnvironmentDefaults(t *testing.T) {
// Clear any existing env vars that might interfere
varsToUnset := []string{
"PORT", "ENVIRONMENT", "DATABASE_URL", "JWT_SECRET", "JWT_REFRESH_SECRET",
}
for _, v := range varsToUnset {
os.Unsetenv(v)
}
// Set required vars
os.Setenv("DATABASE_URL", "postgres://test:test@localhost:5432/test")
os.Setenv("JWT_SECRET", "test-secret-32-chars-minimum-here")
defer func() {
os.Unsetenv("DATABASE_URL")
os.Unsetenv("JWT_SECRET")
}()
cfg, err := Load()
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
// Test defaults
if cfg.Port != "8080" {
t.Errorf("Expected default port '8080', got '%s'", cfg.Port)
}
if cfg.Environment != "development" {
t.Errorf("Expected default environment 'development', got '%s'", cfg.Environment)
}
}
// TestConfigLoadWithEnvironment tests loading config with different environments
func TestConfigLoadWithEnvironment(t *testing.T) {
// Set required vars
os.Setenv("DATABASE_URL", "postgres://test:test@localhost:5432/test")
os.Setenv("JWT_SECRET", "test-secret-32-chars-minimum-here")
defer func() {
os.Unsetenv("DATABASE_URL")
os.Unsetenv("JWT_SECRET")
}()
tests := []struct {
name string
environment string
}{
{"development", "development"},
{"staging", "staging"},
{"production", "production"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
os.Setenv("ENVIRONMENT", tt.environment)
defer os.Unsetenv("ENVIRONMENT")
cfg, err := Load()
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
if cfg.Environment != tt.environment {
t.Errorf("Expected environment '%s', got '%s'", tt.environment, cfg.Environment)
}
})
}
}
// TestConfigMissingRequiredVars tests that missing required vars return errors
func TestConfigMissingRequiredVars(t *testing.T) {
// Clear all env vars
os.Unsetenv("DATABASE_URL")
os.Unsetenv("JWT_SECRET")
_, err := Load()
if err == nil {
t.Error("Expected error when DATABASE_URL is missing")
}
// Set DATABASE_URL but not JWT_SECRET
os.Setenv("DATABASE_URL", "postgres://test:test@localhost:5432/test")
defer os.Unsetenv("DATABASE_URL")
_, err = Load()
if err == nil {
t.Error("Expected error when JWT_SECRET is missing")
}
}
// TestConfigAllowedOrigins tests that allowed origins are parsed correctly
func TestConfigAllowedOrigins(t *testing.T) {
// Set required vars
os.Setenv("DATABASE_URL", "postgres://test:test@localhost:5432/test")
os.Setenv("JWT_SECRET", "test-secret-32-chars-minimum-here")
os.Setenv("ALLOWED_ORIGINS", "http://localhost:3000,http://localhost:8000,http://localhost:8001")
defer func() {
os.Unsetenv("DATABASE_URL")
os.Unsetenv("JWT_SECRET")
os.Unsetenv("ALLOWED_ORIGINS")
}()
cfg, err := Load()
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
expected := []string{"http://localhost:3000", "http://localhost:8000", "http://localhost:8001"}
if len(cfg.AllowedOrigins) != len(expected) {
t.Errorf("Expected %d origins, got %d", len(expected), len(cfg.AllowedOrigins))
}
for i, origin := range cfg.AllowedOrigins {
if origin != expected[i] {
t.Errorf("At index %d: expected '%s', got '%s'", i, expected[i], origin)
}
}
}
// TestConfigDebugSettings tests debug-related settings for different environments
func TestConfigDebugSettings(t *testing.T) {
// Set required vars
os.Setenv("DATABASE_URL", "postgres://test:test@localhost:5432/test")
os.Setenv("JWT_SECRET", "test-secret-32-chars-minimum-here")
defer func() {
os.Unsetenv("DATABASE_URL")
os.Unsetenv("JWT_SECRET")
}()
// Test development environment
t.Run("development", func(t *testing.T) {
os.Setenv("ENVIRONMENT", "development")
defer os.Unsetenv("ENVIRONMENT")
cfg, err := Load()
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
if cfg.Environment != "development" {
t.Errorf("Expected 'development', got '%s'", cfg.Environment)
}
})
// Test staging environment
t.Run("staging", func(t *testing.T) {
os.Setenv("ENVIRONMENT", "staging")
defer os.Unsetenv("ENVIRONMENT")
cfg, err := Load()
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
if cfg.Environment != "staging" {
t.Errorf("Expected 'staging', got '%s'", cfg.Environment)
}
})
// Test production environment
t.Run("production", func(t *testing.T) {
os.Setenv("ENVIRONMENT", "production")
defer os.Unsetenv("ENVIRONMENT")
cfg, err := Load()
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
if cfg.Environment != "production" {
t.Errorf("Expected 'production', got '%s'", cfg.Environment)
}
})
}
// TestConfigStagingPorts tests that staging uses different ports
func TestConfigStagingPorts(t *testing.T) {
// Set required vars
os.Setenv("DATABASE_URL", "postgres://test:test@localhost:5433/breakpilot_staging")
os.Setenv("JWT_SECRET", "test-secret-32-chars-minimum-here")
os.Setenv("ENVIRONMENT", "staging")
os.Setenv("PORT", "8081")
defer func() {
os.Unsetenv("DATABASE_URL")
os.Unsetenv("JWT_SECRET")
os.Unsetenv("ENVIRONMENT")
os.Unsetenv("PORT")
}()
cfg, err := Load()
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
if cfg.Port != "8081" {
t.Errorf("Expected staging port '8081', got '%s'", cfg.Port)
}
if cfg.Environment != "staging" {
t.Errorf("Expected 'staging', got '%s'", cfg.Environment)
}
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,442 @@
package handlers
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/breakpilot/consent-service/internal/models"
"github.com/breakpilot/consent-service/internal/services"
)
// AuthHandler handles authentication endpoints
type AuthHandler struct {
authService *services.AuthService
emailService *services.EmailService
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(authService *services.AuthService, emailService *services.EmailService) *AuthHandler {
return &AuthHandler{
authService: authService,
emailService: emailService,
}
}
// Register handles user registration
// @Summary Register a new user
// @Tags auth
// @Accept json
// @Produce json
// @Param request body models.RegisterRequest true "Registration data"
// @Success 201 {object} map[string]interface{}
// @Failure 400 {object} map[string]string
// @Failure 409 {object} map[string]string
// @Router /auth/register [post]
func (h *AuthHandler) Register(c *gin.Context) {
var req models.RegisterRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request", "details": err.Error()})
return
}
user, verificationToken, err := h.authService.Register(c.Request.Context(), &req)
if err != nil {
if err == services.ErrUserExists {
c.JSON(http.StatusConflict, gin.H{"error": "User with this email already exists"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to register user"})
return
}
// Send verification email (async, don't block response)
go func() {
var name string
if user.Name != nil {
name = *user.Name
}
if err := h.emailService.SendVerificationEmail(user.Email, name, verificationToken); err != nil {
// Log error but don't fail registration
println("Failed to send verification email:", err.Error())
}
}()
c.JSON(http.StatusCreated, gin.H{
"message": "Registration successful. Please check your email to verify your account.",
"user": gin.H{
"id": user.ID,
"email": user.Email,
"name": user.Name,
},
})
}
// Login handles user login
// @Summary Login user
// @Tags auth
// @Accept json
// @Produce json
// @Param request body models.LoginRequest true "Login credentials"
// @Success 200 {object} models.LoginResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /auth/login [post]
func (h *AuthHandler) Login(c *gin.Context) {
var req models.LoginRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request", "details": err.Error()})
return
}
ipAddress := c.ClientIP()
userAgent := c.Request.UserAgent()
response, err := h.authService.Login(c.Request.Context(), &req, ipAddress, userAgent)
if err != nil {
switch err {
case services.ErrInvalidCredentials:
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid email or password"})
case services.ErrAccountLocked:
c.JSON(http.StatusForbidden, gin.H{"error": "Account is temporarily locked. Please try again later."})
case services.ErrAccountSuspended:
c.JSON(http.StatusForbidden, gin.H{
"error": "Account is suspended",
"reason": "consent_required",
"redirect": "/consent/pending",
})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "Login failed"})
}
return
}
c.JSON(http.StatusOK, response)
}
// Logout handles user logout
// @Summary Logout user
// @Tags auth
// @Accept json
// @Produce json
// @Param Authorization header string true "Bearer token"
// @Success 200 {object} map[string]string
// @Router /auth/logout [post]
func (h *AuthHandler) Logout(c *gin.Context) {
var req struct {
RefreshToken string `json:"refresh_token"`
}
if err := c.ShouldBindJSON(&req); err == nil && req.RefreshToken != "" {
_ = h.authService.Logout(c.Request.Context(), req.RefreshToken)
}
c.JSON(http.StatusOK, gin.H{"message": "Logged out successfully"})
}
// RefreshToken refreshes the access token
// @Summary Refresh access token
// @Tags auth
// @Accept json
// @Produce json
// @Param request body models.RefreshTokenRequest true "Refresh token"
// @Success 200 {object} models.LoginResponse
// @Failure 401 {object} map[string]string
// @Router /auth/refresh [post]
func (h *AuthHandler) RefreshToken(c *gin.Context) {
var req models.RefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
return
}
response, err := h.authService.RefreshToken(c.Request.Context(), req.RefreshToken)
if err != nil {
if err == services.ErrAccountSuspended {
c.JSON(http.StatusForbidden, gin.H{
"error": "Account is suspended",
"reason": "consent_required",
"redirect": "/consent/pending",
})
return
}
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or expired refresh token"})
return
}
c.JSON(http.StatusOK, response)
}
// VerifyEmail verifies user email
// @Summary Verify email address
// @Tags auth
// @Accept json
// @Produce json
// @Param request body models.VerifyEmailRequest true "Verification token"
// @Success 200 {object} map[string]string
// @Failure 400 {object} map[string]string
// @Router /auth/verify-email [post]
func (h *AuthHandler) VerifyEmail(c *gin.Context) {
var req models.VerifyEmailRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
return
}
if err := h.authService.VerifyEmail(c.Request.Context(), req.Token); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid or expired verification token"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Email verified successfully. You can now log in."})
}
// ResendVerification resends verification email
// @Summary Resend verification email
// @Tags auth
// @Accept json
// @Produce json
// @Param request body map[string]string true "Email"
// @Success 200 {object} map[string]string
// @Router /auth/resend-verification [post]
func (h *AuthHandler) ResendVerification(c *gin.Context) {
var req struct {
Email string `json:"email" binding:"required,email"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
return
}
// Always return success to prevent email enumeration
c.JSON(http.StatusOK, gin.H{"message": "If an account exists with this email, a verification email has been sent."})
}
// ForgotPassword initiates password reset
// @Summary Request password reset
// @Tags auth
// @Accept json
// @Produce json
// @Param request body models.ForgotPasswordRequest true "Email"
// @Success 200 {object} map[string]string
// @Router /auth/forgot-password [post]
func (h *AuthHandler) ForgotPassword(c *gin.Context) {
var req models.ForgotPasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
return
}
token, userID, err := h.authService.CreatePasswordResetToken(c.Request.Context(), req.Email, c.ClientIP())
if err == nil && userID != nil {
// Send email asynchronously
go func() {
_ = h.emailService.SendPasswordResetEmail(req.Email, "", token)
}()
}
// Always return success to prevent email enumeration
c.JSON(http.StatusOK, gin.H{"message": "If an account exists with this email, a password reset link has been sent."})
}
// ResetPassword resets password with token
// @Summary Reset password
// @Tags auth
// @Accept json
// @Produce json
// @Param request body models.ResetPasswordRequest true "Reset token and new password"
// @Success 200 {object} map[string]string
// @Failure 400 {object} map[string]string
// @Router /auth/reset-password [post]
func (h *AuthHandler) ResetPassword(c *gin.Context) {
var req models.ResetPasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request", "details": err.Error()})
return
}
if err := h.authService.ResetPassword(c.Request.Context(), req.Token, req.NewPassword); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid or expired reset token"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Password reset successfully. You can now log in with your new password."})
}
// GetProfile returns the current user's profile
// @Summary Get user profile
// @Tags profile
// @Accept json
// @Produce json
// @Security BearerAuth
// @Success 200 {object} models.User
// @Failure 401 {object} map[string]string
// @Router /profile [get]
func (h *AuthHandler) GetProfile(c *gin.Context) {
userIDStr, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
userID, err := uuid.Parse(userIDStr.(string))
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID"})
return
}
user, err := h.authService.GetUserByID(c.Request.Context(), userID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "User not found"})
return
}
c.JSON(http.StatusOK, user)
}
// UpdateProfile updates the current user's profile
// @Summary Update user profile
// @Tags profile
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body models.UpdateProfileRequest true "Profile data"
// @Success 200 {object} models.User
// @Failure 400 {object} map[string]string
// @Router /profile [put]
func (h *AuthHandler) UpdateProfile(c *gin.Context) {
userIDStr, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
userID, err := uuid.Parse(userIDStr.(string))
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID"})
return
}
var req models.UpdateProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
return
}
user, err := h.authService.UpdateProfile(c.Request.Context(), userID, &req)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update profile"})
return
}
c.JSON(http.StatusOK, user)
}
// ChangePassword changes the current user's password
// @Summary Change password
// @Tags profile
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body models.ChangePasswordRequest true "Password data"
// @Success 200 {object} map[string]string
// @Failure 400 {object} map[string]string
// @Router /profile/password [put]
func (h *AuthHandler) ChangePassword(c *gin.Context) {
userIDStr, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
userID, err := uuid.Parse(userIDStr.(string))
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID"})
return
}
var req models.ChangePasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request", "details": err.Error()})
return
}
if err := h.authService.ChangePassword(c.Request.Context(), userID, req.CurrentPassword, req.NewPassword); err != nil {
if err == services.ErrInvalidCredentials {
c.JSON(http.StatusBadRequest, gin.H{"error": "Current password is incorrect"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to change password"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Password changed successfully"})
}
// GetActiveSessions returns all active sessions for the current user
// @Summary Get active sessions
// @Tags profile
// @Accept json
// @Produce json
// @Security BearerAuth
// @Success 200 {array} models.UserSession
// @Router /profile/sessions [get]
func (h *AuthHandler) GetActiveSessions(c *gin.Context) {
userIDStr, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
userID, err := uuid.Parse(userIDStr.(string))
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID"})
return
}
sessions, err := h.authService.GetActiveSessions(c.Request.Context(), userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get sessions"})
return
}
c.JSON(http.StatusOK, gin.H{"sessions": sessions})
}
// RevokeSession revokes a specific session
// @Summary Revoke session
// @Tags profile
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param id path string true "Session ID"
// @Success 200 {object} map[string]string
// @Failure 404 {object} map[string]string
// @Router /profile/sessions/{id} [delete]
func (h *AuthHandler) RevokeSession(c *gin.Context) {
userIDStr, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
userID, err := uuid.Parse(userIDStr.(string))
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID"})
return
}
sessionID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid session ID"})
return
}
if err := h.authService.RevokeSession(c.Request.Context(), userID, sessionID); err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Session not found"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Session revoked successfully"})
}
@@ -0,0 +1,561 @@
package handlers
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// ========================================
// Cookie Banner SDK API Handlers
// ========================================
// Diese Endpoints werden vom @breakpilot/consent-sdk verwendet
// für anonyme (device-basierte) Cookie-Einwilligungen.
// BannerConsentRecord repräsentiert einen anonymen Consent-Eintrag
type BannerConsentRecord struct {
ID string `json:"id"`
SiteID string `json:"site_id"`
DeviceFingerprint string `json:"device_fingerprint"`
UserID *string `json:"user_id,omitempty"`
Categories map[string]bool `json:"categories"`
Vendors map[string]bool `json:"vendors,omitempty"`
TCFString *string `json:"tcf_string,omitempty"`
IPHash *string `json:"ip_hash,omitempty"`
UserAgent *string `json:"user_agent,omitempty"`
Language *string `json:"language,omitempty"`
Platform *string `json:"platform,omitempty"`
AppVersion *string `json:"app_version,omitempty"`
Version string `json:"version"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
RevokedAt *time.Time `json:"revoked_at,omitempty"`
}
// BannerConsentRequest ist der Request-Body für POST /consent
type BannerConsentRequest struct {
SiteID string `json:"siteId" binding:"required"`
UserID *string `json:"userId,omitempty"`
DeviceFingerprint string `json:"deviceFingerprint" binding:"required"`
Consent ConsentData `json:"consent" binding:"required"`
Metadata *ConsentMetadata `json:"metadata,omitempty"`
}
// ConsentData enthält die eigentlichen Consent-Daten
type ConsentData struct {
Categories map[string]bool `json:"categories" binding:"required"`
Vendors map[string]bool `json:"vendors,omitempty"`
}
// ConsentMetadata enthält optionale Metadaten
type ConsentMetadata struct {
UserAgent *string `json:"userAgent,omitempty"`
Language *string `json:"language,omitempty"`
ScreenResolution *string `json:"screenResolution,omitempty"`
Platform *string `json:"platform,omitempty"`
AppVersion *string `json:"appVersion,omitempty"`
}
// BannerConsentResponse ist die Antwort auf POST /consent
type BannerConsentResponse struct {
ConsentID string `json:"consentId"`
Timestamp string `json:"timestamp"`
ExpiresAt string `json:"expiresAt"`
Version string `json:"version"`
}
// SiteConfig repräsentiert die Konfiguration für eine Site
type SiteConfig struct {
SiteID string `json:"siteId"`
SiteName string `json:"siteName"`
Categories []CategoryConfig `json:"categories"`
UI UIConfig `json:"ui"`
Legal LegalConfig `json:"legal"`
TCF *TCFConfig `json:"tcf,omitempty"`
}
// CategoryConfig repräsentiert eine Consent-Kategorie
type CategoryConfig struct {
ID string `json:"id"`
Name map[string]string `json:"name"`
Description map[string]string `json:"description"`
Required bool `json:"required"`
Vendors []VendorConfig `json:"vendors"`
}
// VendorConfig repräsentiert einen Vendor (Third-Party)
type VendorConfig struct {
ID string `json:"id"`
Name string `json:"name"`
PrivacyPolicyURL string `json:"privacyPolicyUrl"`
Cookies []CookieInfo `json:"cookies"`
}
// CookieInfo repräsentiert ein Cookie
type CookieInfo struct {
Name string `json:"name"`
Expiration string `json:"expiration"`
Description string `json:"description"`
}
// UIConfig repräsentiert UI-Einstellungen
type UIConfig struct {
Theme string `json:"theme"`
Position string `json:"position"`
}
// LegalConfig repräsentiert rechtliche Informationen
type LegalConfig struct {
PrivacyPolicyURL string `json:"privacyPolicyUrl"`
ImprintURL string `json:"imprintUrl"`
}
// TCFConfig repräsentiert TCF 2.2 Einstellungen
type TCFConfig struct {
Enabled bool `json:"enabled"`
CmpID int `json:"cmpId"`
CmpVersion int `json:"cmpVersion"`
}
// ========================================
// Handler Methods
// ========================================
// CreateBannerConsent erstellt oder aktualisiert einen Consent-Eintrag
// POST /api/v1/banner/consent
func (h *Handler) CreateBannerConsent(c *gin.Context) {
var req BannerConsentRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_request",
"message": "Invalid request body: " + err.Error(),
})
return
}
ctx := context.Background()
// IP-Adresse anonymisieren
ipHash := anonymizeIP(c.ClientIP())
// Consent-ID generieren
consentID := uuid.New().String()
// Ablaufdatum (1 Jahr)
expiresAt := time.Now().AddDate(1, 0, 0)
// Categories und Vendors als JSON
categoriesJSON, _ := json.Marshal(req.Consent.Categories)
vendorsJSON, _ := json.Marshal(req.Consent.Vendors)
// Metadaten extrahieren
var userAgent, language, platform, appVersion *string
if req.Metadata != nil {
userAgent = req.Metadata.UserAgent
language = req.Metadata.Language
platform = req.Metadata.Platform
appVersion = req.Metadata.AppVersion
}
// In Datenbank speichern
_, err := h.db.Pool.Exec(ctx, `
INSERT INTO banner_consents (
id, site_id, device_fingerprint, user_id,
categories, vendors, ip_hash, user_agent,
language, platform, app_version, version,
expires_at, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, NOW(), NOW())
ON CONFLICT (site_id, device_fingerprint)
DO UPDATE SET
categories = $5,
vendors = $6,
ip_hash = $7,
user_agent = $8,
language = $9,
platform = $10,
app_version = $11,
version = $12,
expires_at = $13,
updated_at = NOW()
RETURNING id
`, consentID, req.SiteID, req.DeviceFingerprint, req.UserID,
categoriesJSON, vendorsJSON, ipHash, userAgent,
language, platform, appVersion, "1.0.0", expiresAt)
if err != nil {
// Fallback: Existierenden Consent abrufen
var existingID string
err2 := h.db.Pool.QueryRow(ctx, `
SELECT id FROM banner_consents
WHERE site_id = $1 AND device_fingerprint = $2
`, req.SiteID, req.DeviceFingerprint).Scan(&existingID)
if err2 == nil {
consentID = existingID
}
}
// Audit-Log schreiben
h.logBannerConsentAudit(ctx, consentID, "created", req, ipHash)
// Response
c.JSON(http.StatusCreated, BannerConsentResponse{
ConsentID: consentID,
Timestamp: time.Now().UTC().Format(time.RFC3339),
ExpiresAt: expiresAt.UTC().Format(time.RFC3339),
Version: "1.0.0",
})
}
// GetBannerConsent ruft einen bestehenden Consent ab
// GET /api/v1/banner/consent?siteId=xxx&deviceFingerprint=xxx
func (h *Handler) GetBannerConsent(c *gin.Context) {
siteID := c.Query("siteId")
deviceFingerprint := c.Query("deviceFingerprint")
if siteID == "" || deviceFingerprint == "" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "missing_parameters",
"message": "siteId and deviceFingerprint are required",
})
return
}
ctx := context.Background()
var record BannerConsentRecord
var categoriesJSON, vendorsJSON []byte
err := h.db.Pool.QueryRow(ctx, `
SELECT id, site_id, device_fingerprint, user_id,
categories, vendors, version,
created_at, updated_at, expires_at, revoked_at
FROM banner_consents
WHERE site_id = $1 AND device_fingerprint = $2 AND revoked_at IS NULL
`, siteID, deviceFingerprint).Scan(
&record.ID, &record.SiteID, &record.DeviceFingerprint, &record.UserID,
&categoriesJSON, &vendorsJSON, &record.Version,
&record.CreatedAt, &record.UpdatedAt, &record.ExpiresAt, &record.RevokedAt,
)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "consent_not_found",
"message": "No consent record found",
})
return
}
// JSON parsen
json.Unmarshal(categoriesJSON, &record.Categories)
json.Unmarshal(vendorsJSON, &record.Vendors)
c.JSON(http.StatusOK, gin.H{
"consentId": record.ID,
"consent": gin.H{
"categories": record.Categories,
"vendors": record.Vendors,
},
"createdAt": record.CreatedAt.UTC().Format(time.RFC3339),
"updatedAt": record.UpdatedAt.UTC().Format(time.RFC3339),
"expiresAt": record.ExpiresAt.UTC().Format(time.RFC3339),
"version": record.Version,
})
}
// RevokeBannerConsent widerruft einen Consent
// DELETE /api/v1/banner/consent/:consentId
func (h *Handler) RevokeBannerConsent(c *gin.Context) {
consentID := c.Param("consentId")
ctx := context.Background()
result, err := h.db.Pool.Exec(ctx, `
UPDATE banner_consents
SET revoked_at = NOW(), updated_at = NOW()
WHERE id = $1 AND revoked_at IS NULL
`, consentID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "revoke_failed",
"message": "Failed to revoke consent",
})
return
}
if result.RowsAffected() == 0 {
c.JSON(http.StatusNotFound, gin.H{
"error": "consent_not_found",
"message": "Consent not found or already revoked",
})
return
}
// Audit-Log
h.logBannerConsentAudit(ctx, consentID, "revoked", nil, anonymizeIP(c.ClientIP()))
c.JSON(http.StatusOK, gin.H{
"status": "revoked",
"revokedAt": time.Now().UTC().Format(time.RFC3339),
})
}
// GetSiteConfig gibt die Konfiguration für eine Site zurück
// GET /api/v1/banner/config/:siteId
func (h *Handler) GetSiteConfig(c *gin.Context) {
siteID := c.Param("siteId")
// Standard-Kategorien (aus Datenbank oder Default)
categories := []CategoryConfig{
{
ID: "essential",
Name: map[string]string{
"de": "Essentiell",
"en": "Essential",
},
Description: map[string]string{
"de": "Notwendig für die Grundfunktionen der Website.",
"en": "Required for basic website functionality.",
},
Required: true,
Vendors: []VendorConfig{},
},
{
ID: "functional",
Name: map[string]string{
"de": "Funktional",
"en": "Functional",
},
Description: map[string]string{
"de": "Ermöglicht Personalisierung und Komfortfunktionen.",
"en": "Enables personalization and comfort features.",
},
Required: false,
Vendors: []VendorConfig{},
},
{
ID: "analytics",
Name: map[string]string{
"de": "Statistik",
"en": "Analytics",
},
Description: map[string]string{
"de": "Hilft uns, die Website zu verbessern.",
"en": "Helps us improve the website.",
},
Required: false,
Vendors: []VendorConfig{},
},
{
ID: "marketing",
Name: map[string]string{
"de": "Marketing",
"en": "Marketing",
},
Description: map[string]string{
"de": "Ermöglicht personalisierte Werbung.",
"en": "Enables personalized advertising.",
},
Required: false,
Vendors: []VendorConfig{},
},
{
ID: "social",
Name: map[string]string{
"de": "Soziale Medien",
"en": "Social Media",
},
Description: map[string]string{
"de": "Ermöglicht Inhalte von sozialen Netzwerken.",
"en": "Enables content from social networks.",
},
Required: false,
Vendors: []VendorConfig{},
},
}
config := SiteConfig{
SiteID: siteID,
SiteName: "BreakPilot",
Categories: categories,
UI: UIConfig{
Theme: "auto",
Position: "bottom",
},
Legal: LegalConfig{
PrivacyPolicyURL: "/datenschutz",
ImprintURL: "/impressum",
},
}
c.JSON(http.StatusOK, config)
}
// ExportBannerConsent exportiert alle Consent-Daten eines Nutzers (DSGVO Art. 20)
// GET /api/v1/banner/consent/export?userId=xxx
func (h *Handler) ExportBannerConsent(c *gin.Context) {
userID := c.Query("userId")
if userID == "" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "missing_user_id",
"message": "userId parameter is required",
})
return
}
ctx := context.Background()
rows, err := h.db.Pool.Query(ctx, `
SELECT id, site_id, device_fingerprint, categories, vendors,
version, created_at, updated_at, revoked_at
FROM banner_consents
WHERE user_id = $1
ORDER BY created_at DESC
`, userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "export_failed",
"message": "Failed to export consent data",
})
return
}
defer rows.Close()
var consents []map[string]interface{}
for rows.Next() {
var id, siteID, deviceFingerprint, version string
var categoriesJSON, vendorsJSON []byte
var createdAt, updatedAt time.Time
var revokedAt *time.Time
rows.Scan(&id, &siteID, &deviceFingerprint, &categoriesJSON, &vendorsJSON,
&version, &createdAt, &updatedAt, &revokedAt)
var categories, vendors map[string]bool
json.Unmarshal(categoriesJSON, &categories)
json.Unmarshal(vendorsJSON, &vendors)
consent := map[string]interface{}{
"consentId": id,
"siteId": siteID,
"consent": map[string]interface{}{
"categories": categories,
"vendors": vendors,
},
"createdAt": createdAt.UTC().Format(time.RFC3339),
"revokedAt": nil,
}
if revokedAt != nil {
consent["revokedAt"] = revokedAt.UTC().Format(time.RFC3339)
}
consents = append(consents, consent)
}
c.JSON(http.StatusOK, gin.H{
"userId": userID,
"exportedAt": time.Now().UTC().Format(time.RFC3339),
"consents": consents,
})
}
// GetBannerStats gibt anonymisierte Statistiken zurück (Admin)
// GET /api/v1/banner/admin/stats/:siteId
func (h *Handler) GetBannerStats(c *gin.Context) {
siteID := c.Param("siteId")
ctx := context.Background()
// Gesamtanzahl Consents
var totalConsents int
h.db.Pool.QueryRow(ctx, `
SELECT COUNT(*) FROM banner_consents
WHERE site_id = $1 AND revoked_at IS NULL
`, siteID).Scan(&totalConsents)
// Consent-Rate pro Kategorie
categoryStats := make(map[string]map[string]interface{})
rows, _ := h.db.Pool.Query(ctx, `
SELECT
key as category,
COUNT(*) FILTER (WHERE value::text = 'true') as accepted,
COUNT(*) as total
FROM banner_consents,
jsonb_each(categories::jsonb)
WHERE site_id = $1 AND revoked_at IS NULL
GROUP BY key
`, siteID)
if rows != nil {
defer rows.Close()
for rows.Next() {
var category string
var accepted, total int
rows.Scan(&category, &accepted, &total)
rate := float64(0)
if total > 0 {
rate = float64(accepted) / float64(total)
}
categoryStats[category] = map[string]interface{}{
"accepted": accepted,
"rate": rate,
}
}
}
c.JSON(http.StatusOK, gin.H{
"siteId": siteID,
"period": gin.H{
"from": time.Now().AddDate(0, -1, 0).Format("2006-01-02"),
"to": time.Now().Format("2006-01-02"),
},
"totalConsents": totalConsents,
"consentByCategory": categoryStats,
})
}
// ========================================
// Helper Functions
// ========================================
// anonymizeIP anonymisiert eine IP-Adresse (DSGVO-konform)
func anonymizeIP(ip string) string {
// IPv4: Letztes Oktett auf 0
parts := strings.Split(ip, ".")
if len(parts) == 4 {
parts[3] = "0"
anonymized := strings.Join(parts, ".")
hash := sha256.Sum256([]byte(anonymized))
return hex.EncodeToString(hash[:])[:16]
}
// IPv6: Hash
hash := sha256.Sum256([]byte(ip))
return hex.EncodeToString(hash[:])[:16]
}
// logBannerConsentAudit schreibt einen Audit-Log-Eintrag
func (h *Handler) logBannerConsentAudit(ctx context.Context, consentID, action string, req interface{}, ipHash string) {
details, _ := json.Marshal(req)
h.db.Pool.Exec(ctx, `
INSERT INTO banner_consent_audit_log (
id, consent_id, action, details, ip_hash, created_at
) VALUES ($1, $2, $3, $4, $5, NOW())
`, uuid.New().String(), consentID, action, string(details), ipHash)
}
@@ -0,0 +1,511 @@
package handlers
import (
"net/http"
"time"
"github.com/breakpilot/consent-service/internal/services/jitsi"
"github.com/breakpilot/consent-service/internal/services/matrix"
"github.com/gin-gonic/gin"
)
// CommunicationHandlers handles Matrix and Jitsi API endpoints
type CommunicationHandlers struct {
matrixService *matrix.MatrixService
jitsiService *jitsi.JitsiService
}
// NewCommunicationHandlers creates new communication handlers
func NewCommunicationHandlers(matrixSvc *matrix.MatrixService, jitsiSvc *jitsi.JitsiService) *CommunicationHandlers {
return &CommunicationHandlers{
matrixService: matrixSvc,
jitsiService: jitsiSvc,
}
}
// ========================================
// Health & Status Endpoints
// ========================================
// GetCommunicationStatus returns status of Matrix and Jitsi services
func (h *CommunicationHandlers) GetCommunicationStatus(c *gin.Context) {
ctx := c.Request.Context()
status := gin.H{
"timestamp": time.Now().UTC().Format(time.RFC3339),
}
// Check Matrix
if h.matrixService != nil {
matrixErr := h.matrixService.HealthCheck(ctx)
status["matrix"] = gin.H{
"enabled": true,
"healthy": matrixErr == nil,
"server_name": h.matrixService.GetServerName(),
"error": errToString(matrixErr),
}
} else {
status["matrix"] = gin.H{
"enabled": false,
"healthy": false,
}
}
// Check Jitsi
if h.jitsiService != nil {
jitsiErr := h.jitsiService.HealthCheck(ctx)
serverInfo := h.jitsiService.GetServerInfo()
status["jitsi"] = gin.H{
"enabled": true,
"healthy": jitsiErr == nil,
"base_url": serverInfo["base_url"],
"auth_enabled": serverInfo["auth_enabled"],
"error": errToString(jitsiErr),
}
} else {
status["jitsi"] = gin.H{
"enabled": false,
"healthy": false,
}
}
c.JSON(http.StatusOK, status)
}
// ========================================
// Matrix Room Endpoints
// ========================================
// CreateRoomRequest for creating Matrix rooms
type CreateRoomRequest struct {
Type string `json:"type" binding:"required"` // "class_info", "student_dm", "parent_rep"
ClassName string `json:"class_name"`
SchoolName string `json:"school_name"`
StudentName string `json:"student_name,omitempty"`
TeacherIDs []string `json:"teacher_ids"`
ParentIDs []string `json:"parent_ids,omitempty"`
ParentRepIDs []string `json:"parent_rep_ids,omitempty"`
}
// CreateRoom creates a new Matrix room based on type
func (h *CommunicationHandlers) CreateRoom(c *gin.Context) {
if h.matrixService == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Matrix service not configured"})
return
}
var req CreateRoomRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx := c.Request.Context()
var resp *matrix.CreateRoomResponse
var err error
switch req.Type {
case "class_info":
resp, err = h.matrixService.CreateClassInfoRoom(ctx, req.ClassName, req.SchoolName, req.TeacherIDs)
case "student_dm":
resp, err = h.matrixService.CreateStudentDMRoom(ctx, req.StudentName, req.ClassName, req.TeacherIDs, req.ParentIDs)
case "parent_rep":
resp, err = h.matrixService.CreateParentRepRoom(ctx, req.ClassName, req.TeacherIDs, req.ParentRepIDs)
default:
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid room type. Use: class_info, student_dm, parent_rep"})
return
}
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, gin.H{
"room_id": resp.RoomID,
"type": req.Type,
})
}
// InviteUserRequest for inviting users to rooms
type InviteUserRequest struct {
RoomID string `json:"room_id" binding:"required"`
UserID string `json:"user_id" binding:"required"`
}
// InviteUser invites a user to a Matrix room
func (h *CommunicationHandlers) InviteUser(c *gin.Context) {
if h.matrixService == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Matrix service not configured"})
return
}
var req InviteUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx := c.Request.Context()
if err := h.matrixService.InviteUser(ctx, req.RoomID, req.UserID); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
// SendMessageRequest for sending messages
type SendMessageRequest struct {
RoomID string `json:"room_id" binding:"required"`
Message string `json:"message" binding:"required"`
HTML string `json:"html,omitempty"`
}
// SendMessage sends a message to a Matrix room
func (h *CommunicationHandlers) SendMessage(c *gin.Context) {
if h.matrixService == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Matrix service not configured"})
return
}
var req SendMessageRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx := c.Request.Context()
var err error
if req.HTML != "" {
err = h.matrixService.SendHTMLMessage(ctx, req.RoomID, req.Message, req.HTML)
} else {
err = h.matrixService.SendMessage(ctx, req.RoomID, req.Message)
}
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
// SendNotificationRequest for sending school notifications
type SendNotificationRequest struct {
RoomID string `json:"room_id" binding:"required"`
Type string `json:"type" binding:"required"` // "absence", "grade", "announcement"
StudentName string `json:"student_name,omitempty"`
Date string `json:"date,omitempty"`
Lesson int `json:"lesson,omitempty"`
Subject string `json:"subject,omitempty"`
GradeType string `json:"grade_type,omitempty"`
Grade float64 `json:"grade,omitempty"`
Title string `json:"title,omitempty"`
Content string `json:"content,omitempty"`
TeacherName string `json:"teacher_name,omitempty"`
}
// SendNotification sends a typed notification (absence, grade, announcement)
func (h *CommunicationHandlers) SendNotification(c *gin.Context) {
if h.matrixService == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Matrix service not configured"})
return
}
var req SendNotificationRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx := c.Request.Context()
var err error
switch req.Type {
case "absence":
err = h.matrixService.SendAbsenceNotification(ctx, req.RoomID, req.StudentName, req.Date, req.Lesson)
case "grade":
err = h.matrixService.SendGradeNotification(ctx, req.RoomID, req.StudentName, req.Subject, req.GradeType, req.Grade)
case "announcement":
err = h.matrixService.SendClassAnnouncement(ctx, req.RoomID, req.Title, req.Content, req.TeacherName)
default:
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid notification type. Use: absence, grade, announcement"})
return
}
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
// RegisterUserRequest for user registration
type RegisterUserRequest struct {
Username string `json:"username" binding:"required"`
DisplayName string `json:"display_name"`
}
// RegisterMatrixUser registers a new Matrix user
func (h *CommunicationHandlers) RegisterMatrixUser(c *gin.Context) {
if h.matrixService == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Matrix service not configured"})
return
}
var req RegisterUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx := c.Request.Context()
resp, err := h.matrixService.RegisterUser(ctx, req.Username, req.DisplayName)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, gin.H{
"user_id": resp.UserID,
})
}
// ========================================
// Jitsi Video Conference Endpoints
// ========================================
// CreateMeetingRequest for creating Jitsi meetings
type CreateMeetingRequest struct {
Type string `json:"type" binding:"required"` // "quick", "training", "parent_teacher", "class"
Title string `json:"title,omitempty"`
DisplayName string `json:"display_name"`
Email string `json:"email,omitempty"`
Duration int `json:"duration,omitempty"` // minutes
ClassName string `json:"class_name,omitempty"`
ParentName string `json:"parent_name,omitempty"`
StudentName string `json:"student_name,omitempty"`
Subject string `json:"subject,omitempty"`
StartTime time.Time `json:"start_time,omitempty"`
}
// CreateMeeting creates a new Jitsi meeting
func (h *CommunicationHandlers) CreateMeeting(c *gin.Context) {
if h.jitsiService == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Jitsi service not configured"})
return
}
var req CreateMeetingRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx := c.Request.Context()
var link *jitsi.MeetingLink
var err error
switch req.Type {
case "quick":
link, err = h.jitsiService.CreateQuickMeeting(ctx, req.DisplayName)
case "training":
link, err = h.jitsiService.CreateTrainingSession(ctx, req.Title, req.DisplayName, req.Email, req.Duration)
case "parent_teacher":
link, err = h.jitsiService.CreateParentTeacherMeeting(ctx, req.DisplayName, req.ParentName, req.StudentName, req.StartTime)
case "class":
link, err = h.jitsiService.CreateClassMeeting(ctx, req.ClassName, req.DisplayName, req.Subject)
default:
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid meeting type. Use: quick, training, parent_teacher, class"})
return
}
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, gin.H{
"room_name": link.RoomName,
"url": link.URL,
"join_url": link.JoinURL,
"moderator_url": link.ModeratorURL,
"password": link.Password,
"expires_at": link.ExpiresAt,
})
}
// GetEmbedURLRequest for embedding Jitsi
type GetEmbedURLRequest struct {
RoomName string `json:"room_name" binding:"required"`
DisplayName string `json:"display_name"`
AudioMuted bool `json:"audio_muted"`
VideoMuted bool `json:"video_muted"`
}
// GetEmbedURL returns an embeddable Jitsi URL
func (h *CommunicationHandlers) GetEmbedURL(c *gin.Context) {
if h.jitsiService == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Jitsi service not configured"})
return
}
var req GetEmbedURLRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
config := &jitsi.MeetingConfig{
StartWithAudioMuted: req.AudioMuted,
StartWithVideoMuted: req.VideoMuted,
DisableDeepLinking: true,
}
embedURL := h.jitsiService.BuildEmbedURL(req.RoomName, req.DisplayName, config)
iframeCode := h.jitsiService.BuildIFrameCode(req.RoomName, 800, 600)
c.JSON(http.StatusOK, gin.H{
"embed_url": embedURL,
"iframe_code": iframeCode,
})
}
// GetJitsiInfo returns Jitsi server information
func (h *CommunicationHandlers) GetJitsiInfo(c *gin.Context) {
if h.jitsiService == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Jitsi service not configured"})
return
}
info := h.jitsiService.GetServerInfo()
c.JSON(http.StatusOK, info)
}
// ========================================
// Admin Statistics Endpoints (for Admin Panel)
// ========================================
// CommunicationStats holds communication service statistics
type CommunicationStats struct {
Matrix MatrixStats `json:"matrix"`
Jitsi JitsiStats `json:"jitsi"`
}
// MatrixStats holds Matrix-specific statistics
type MatrixStats struct {
Enabled bool `json:"enabled"`
Healthy bool `json:"healthy"`
ServerName string `json:"server_name"`
// TODO: Add real stats from Matrix Synapse Admin API
TotalUsers int `json:"total_users"`
TotalRooms int `json:"total_rooms"`
ActiveToday int `json:"active_today"`
MessagesToday int `json:"messages_today"`
}
// JitsiStats holds Jitsi-specific statistics
type JitsiStats struct {
Enabled bool `json:"enabled"`
Healthy bool `json:"healthy"`
BaseURL string `json:"base_url"`
AuthEnabled bool `json:"auth_enabled"`
// TODO: Add real stats from Jitsi SRTP API or Jicofo
ActiveMeetings int `json:"active_meetings"`
TotalParticipants int `json:"total_participants"`
MeetingsToday int `json:"meetings_today"`
AvgDurationMin int `json:"avg_duration_min"`
}
// GetAdminStats returns admin statistics for Matrix and Jitsi
func (h *CommunicationHandlers) GetAdminStats(c *gin.Context) {
ctx := c.Request.Context()
stats := CommunicationStats{}
// Matrix Stats
if h.matrixService != nil {
matrixErr := h.matrixService.HealthCheck(ctx)
stats.Matrix = MatrixStats{
Enabled: true,
Healthy: matrixErr == nil,
ServerName: h.matrixService.GetServerName(),
// Placeholder stats - in production these would come from Synapse Admin API
TotalUsers: 0,
TotalRooms: 0,
ActiveToday: 0,
MessagesToday: 0,
}
} else {
stats.Matrix = MatrixStats{Enabled: false}
}
// Jitsi Stats
if h.jitsiService != nil {
jitsiErr := h.jitsiService.HealthCheck(ctx)
serverInfo := h.jitsiService.GetServerInfo()
stats.Jitsi = JitsiStats{
Enabled: true,
Healthy: jitsiErr == nil,
BaseURL: serverInfo["base_url"],
AuthEnabled: serverInfo["auth_enabled"] == "true",
// Placeholder stats - in production these would come from Jicofo/JVB stats
ActiveMeetings: 0,
TotalParticipants: 0,
MeetingsToday: 0,
AvgDurationMin: 0,
}
} else {
stats.Jitsi = JitsiStats{Enabled: false}
}
c.JSON(http.StatusOK, stats)
}
// ========================================
// Helper Functions
// ========================================
func errToString(err error) string {
if err == nil {
return ""
}
return err.Error()
}
// RegisterRoutes registers all communication routes
func (h *CommunicationHandlers) RegisterRoutes(router *gin.RouterGroup, jwtSecret string, authMiddleware gin.HandlerFunc) {
comm := router.Group("/communication")
{
// Public health check
comm.GET("/status", h.GetCommunicationStatus)
// Protected routes
protected := comm.Group("")
protected.Use(authMiddleware)
{
// Matrix
protected.POST("/rooms", h.CreateRoom)
protected.POST("/rooms/invite", h.InviteUser)
protected.POST("/messages", h.SendMessage)
protected.POST("/notifications", h.SendNotification)
// Jitsi
protected.POST("/meetings", h.CreateMeeting)
protected.POST("/meetings/embed", h.GetEmbedURL)
protected.GET("/jitsi/info", h.GetJitsiInfo)
}
// Admin routes (for Matrix user registration and stats)
admin := comm.Group("/admin")
admin.Use(authMiddleware)
// TODO: Add AdminOnly middleware
{
admin.POST("/matrix/users", h.RegisterMatrixUser)
admin.GET("/stats", h.GetAdminStats)
}
}
}
@@ -0,0 +1,407 @@
package handlers
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
// TestGetCommunicationStatus_NoServices tests status with no services configured
func TestGetCommunicationStatus_NoServices_ReturnsDisabled(t *testing.T) {
gin.SetMode(gin.TestMode)
// Create handler with no services
handler := NewCommunicationHandlers(nil, nil)
router := gin.New()
router.GET("/api/v1/communication/status", handler.GetCommunicationStatus)
req, _ := http.NewRequest("GET", "/api/v1/communication/status", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
var response map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
// Check matrix is disabled
matrix, ok := response["matrix"].(map[string]interface{})
if !ok {
t.Fatal("Expected matrix in response")
}
if matrix["enabled"] != false {
t.Error("Expected matrix.enabled to be false")
}
// Check jitsi is disabled
jitsi, ok := response["jitsi"].(map[string]interface{})
if !ok {
t.Fatal("Expected jitsi in response")
}
if jitsi["enabled"] != false {
t.Error("Expected jitsi.enabled to be false")
}
// Check timestamp exists
if _, ok := response["timestamp"]; !ok {
t.Error("Expected timestamp in response")
}
}
// TestCreateRoom_NoMatrixService tests room creation without Matrix
func TestCreateRoom_NoMatrixService_Returns503(t *testing.T) {
gin.SetMode(gin.TestMode)
handler := NewCommunicationHandlers(nil, nil)
router := gin.New()
router.POST("/api/v1/communication/rooms", handler.CreateRoom)
body := `{"type": "class_info", "class_name": "5b"}`
req, _ := http.NewRequest("POST", "/api/v1/communication/rooms", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("Expected status 503, got %d", w.Code)
}
var response map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["error"] != "Matrix service not configured" {
t.Errorf("Unexpected error message: %s", response["error"])
}
}
// TestCreateRoom_InvalidBody tests room creation with invalid body
func TestCreateRoom_InvalidBody_Returns400(t *testing.T) {
gin.SetMode(gin.TestMode)
handler := NewCommunicationHandlers(nil, nil)
router := gin.New()
router.POST("/api/v1/communication/rooms", handler.CreateRoom)
req, _ := http.NewRequest("POST", "/api/v1/communication/rooms", bytes.NewBufferString("{invalid"))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// Service unavailable check happens first, so we get 503
// This is expected behavior - service check before body validation
if w.Code != http.StatusServiceUnavailable {
t.Errorf("Expected status 503, got %d", w.Code)
}
}
// TestInviteUser_NoMatrixService tests invite without Matrix
func TestInviteUser_NoMatrixService_Returns503(t *testing.T) {
gin.SetMode(gin.TestMode)
handler := NewCommunicationHandlers(nil, nil)
router := gin.New()
router.POST("/api/v1/communication/rooms/invite", handler.InviteUser)
body := `{"room_id": "!abc:server", "user_id": "@user:server"}`
req, _ := http.NewRequest("POST", "/api/v1/communication/rooms/invite", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("Expected status 503, got %d", w.Code)
}
}
// TestSendMessage_NoMatrixService tests message sending without Matrix
func TestSendMessage_NoMatrixService_Returns503(t *testing.T) {
gin.SetMode(gin.TestMode)
handler := NewCommunicationHandlers(nil, nil)
router := gin.New()
router.POST("/api/v1/communication/messages", handler.SendMessage)
body := `{"room_id": "!abc:server", "message": "Hello"}`
req, _ := http.NewRequest("POST", "/api/v1/communication/messages", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("Expected status 503, got %d", w.Code)
}
}
// TestSendNotification_NoMatrixService tests notification without Matrix
func TestSendNotification_NoMatrixService_Returns503(t *testing.T) {
gin.SetMode(gin.TestMode)
handler := NewCommunicationHandlers(nil, nil)
router := gin.New()
router.POST("/api/v1/communication/notifications", handler.SendNotification)
body := `{"room_id": "!abc:server", "type": "absence", "student_name": "Max"}`
req, _ := http.NewRequest("POST", "/api/v1/communication/notifications", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("Expected status 503, got %d", w.Code)
}
}
// TestCreateMeeting_NoJitsiService tests meeting creation without Jitsi
func TestCreateMeeting_NoJitsiService_Returns503(t *testing.T) {
gin.SetMode(gin.TestMode)
handler := NewCommunicationHandlers(nil, nil)
router := gin.New()
router.POST("/api/v1/communication/meetings", handler.CreateMeeting)
body := `{"type": "quick", "display_name": "Teacher"}`
req, _ := http.NewRequest("POST", "/api/v1/communication/meetings", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("Expected status 503, got %d", w.Code)
}
var response map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["error"] != "Jitsi service not configured" {
t.Errorf("Unexpected error message: %s", response["error"])
}
}
// TestGetEmbedURL_NoJitsiService tests embed URL without Jitsi
func TestGetEmbedURL_NoJitsiService_Returns503(t *testing.T) {
gin.SetMode(gin.TestMode)
handler := NewCommunicationHandlers(nil, nil)
router := gin.New()
router.POST("/api/v1/communication/meetings/embed", handler.GetEmbedURL)
body := `{"room_name": "test-room", "display_name": "User"}`
req, _ := http.NewRequest("POST", "/api/v1/communication/meetings/embed", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("Expected status 503, got %d", w.Code)
}
}
// TestGetJitsiInfo_NoJitsiService tests Jitsi info without service
func TestGetJitsiInfo_NoJitsiService_Returns503(t *testing.T) {
gin.SetMode(gin.TestMode)
handler := NewCommunicationHandlers(nil, nil)
router := gin.New()
router.GET("/api/v1/communication/jitsi/info", handler.GetJitsiInfo)
req, _ := http.NewRequest("GET", "/api/v1/communication/jitsi/info", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("Expected status 503, got %d", w.Code)
}
}
// TestRegisterMatrixUser_NoMatrixService tests user registration without Matrix
func TestRegisterMatrixUser_NoMatrixService_Returns503(t *testing.T) {
gin.SetMode(gin.TestMode)
handler := NewCommunicationHandlers(nil, nil)
router := gin.New()
router.POST("/api/v1/communication/admin/matrix/users", handler.RegisterMatrixUser)
body := `{"username": "testuser", "display_name": "Test User"}`
req, _ := http.NewRequest("POST", "/api/v1/communication/admin/matrix/users", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("Expected status 503, got %d", w.Code)
}
}
// TestGetAdminStats_NoServices tests admin stats without services
func TestGetAdminStats_NoServices_ReturnsDisabledStats(t *testing.T) {
gin.SetMode(gin.TestMode)
handler := NewCommunicationHandlers(nil, nil)
router := gin.New()
router.GET("/api/v1/communication/admin/stats", handler.GetAdminStats)
req, _ := http.NewRequest("GET", "/api/v1/communication/admin/stats", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
var response CommunicationStats
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response.Matrix.Enabled {
t.Error("Expected matrix.enabled to be false")
}
if response.Jitsi.Enabled {
t.Error("Expected jitsi.enabled to be false")
}
}
// TestErrToString tests the helper function
func TestErrToString_NilError_ReturnsEmpty(t *testing.T) {
result := errToString(nil)
if result != "" {
t.Errorf("Expected empty string, got %s", result)
}
}
// TestErrToString_WithError_ReturnsMessage tests error string conversion
func TestErrToString_WithError_ReturnsMessage(t *testing.T) {
err := &testError{"test error message"}
result := errToString(err)
if result != "test error message" {
t.Errorf("Expected 'test error message', got %s", result)
}
}
// testError is a simple error implementation for testing
type testError struct {
message string
}
func (e *testError) Error() string {
return e.message
}
// TestCreateRoomRequest_Types tests different room types validation
func TestCreateRoom_InvalidType_Returns400(t *testing.T) {
gin.SetMode(gin.TestMode)
// Since we don't have Matrix service, we get 503 first
// This test documents expected behavior when Matrix IS available
handler := NewCommunicationHandlers(nil, nil)
router := gin.New()
router.POST("/api/v1/communication/rooms", handler.CreateRoom)
body := `{"type": "invalid_type", "class_name": "5b"}`
req, _ := http.NewRequest("POST", "/api/v1/communication/rooms", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// Without Matrix service, we get 503 before type validation
if w.Code != http.StatusServiceUnavailable {
t.Errorf("Expected status 503, got %d", w.Code)
}
}
// TestCreateMeeting_InvalidType tests invalid meeting type
func TestCreateMeeting_InvalidType_Returns503(t *testing.T) {
gin.SetMode(gin.TestMode)
handler := NewCommunicationHandlers(nil, nil)
router := gin.New()
router.POST("/api/v1/communication/meetings", handler.CreateMeeting)
body := `{"type": "invalid", "display_name": "User"}`
req, _ := http.NewRequest("POST", "/api/v1/communication/meetings", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// Without Jitsi service, we get 503 before type validation
if w.Code != http.StatusServiceUnavailable {
t.Errorf("Expected status 503, got %d", w.Code)
}
}
// TestSendNotification_InvalidType tests invalid notification type
func TestSendNotification_InvalidType_Returns503(t *testing.T) {
gin.SetMode(gin.TestMode)
handler := NewCommunicationHandlers(nil, nil)
router := gin.New()
router.POST("/api/v1/communication/notifications", handler.SendNotification)
body := `{"room_id": "!abc:server", "type": "invalid", "student_name": "Max"}`
req, _ := http.NewRequest("POST", "/api/v1/communication/notifications", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// Without Matrix service, we get 503 before type validation
if w.Code != http.StatusServiceUnavailable {
t.Errorf("Expected status 503, got %d", w.Code)
}
}
// TestNewCommunicationHandlers tests constructor
func TestNewCommunicationHandlers_WithNilServices_CreatesHandler(t *testing.T) {
handler := NewCommunicationHandlers(nil, nil)
if handler == nil {
t.Fatal("Expected handler to be created")
}
if handler.matrixService != nil {
t.Error("Expected matrixService to be nil")
}
if handler.jitsiService != nil {
t.Error("Expected jitsiService to be nil")
}
}
@@ -0,0 +1,92 @@
package handlers
import (
"net/http"
"github.com/breakpilot/consent-service/internal/middleware"
"github.com/breakpilot/consent-service/internal/services"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// DeadlineHandler handles deadline-related requests
type DeadlineHandler struct {
deadlineService *services.DeadlineService
}
// NewDeadlineHandler creates a new deadline handler
func NewDeadlineHandler(deadlineService *services.DeadlineService) *DeadlineHandler {
return &DeadlineHandler{
deadlineService: deadlineService,
}
}
// GetPendingDeadlines returns all pending consent deadlines for the current user
func (h *DeadlineHandler) GetPendingDeadlines(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID == uuid.Nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
return
}
deadlines, err := h.deadlineService.GetPendingDeadlines(c.Request.Context(), userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch deadlines"})
return
}
c.JSON(http.StatusOK, gin.H{
"deadlines": deadlines,
"count": len(deadlines),
})
}
// GetSuspensionStatus returns the current suspension status for a user
func (h *DeadlineHandler) GetSuspensionStatus(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID == uuid.Nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
return
}
suspended, err := h.deadlineService.IsUserSuspended(c.Request.Context(), userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check suspension status"})
return
}
response := gin.H{
"suspended": suspended,
}
if suspended {
suspension, err := h.deadlineService.GetAccountSuspension(c.Request.Context(), userID)
if err == nil && suspension != nil {
response["reason"] = suspension.Reason
response["suspended_at"] = suspension.SuspendedAt
response["details"] = suspension.Details
}
deadlines, err := h.deadlineService.GetPendingDeadlines(c.Request.Context(), userID)
if err == nil {
response["pending_deadlines"] = deadlines
}
}
c.JSON(http.StatusOK, response)
}
// TriggerDeadlineProcessing manually triggers deadline processing (admin only)
func (h *DeadlineHandler) TriggerDeadlineProcessing(c *gin.Context) {
if !middleware.IsAdmin(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin access required"})
return
}
if err := h.deadlineService.ProcessDailyDeadlines(c.Request.Context()); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to process deadlines"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Deadline processing completed"})
}
@@ -0,0 +1,948 @@
package handlers
import (
"context"
"net/http"
"strconv"
"time"
"github.com/breakpilot/consent-service/internal/middleware"
"github.com/breakpilot/consent-service/internal/models"
"github.com/breakpilot/consent-service/internal/services"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// DSRHandler handles Data Subject Request HTTP endpoints
type DSRHandler struct {
dsrService *services.DSRService
}
// NewDSRHandler creates a new DSR handler
func NewDSRHandler(dsrService *services.DSRService) *DSRHandler {
return &DSRHandler{
dsrService: dsrService,
}
}
// ========================================
// USER ENDPOINTS
// ========================================
// CreateDSR creates a new data subject request (user-facing)
func (h *DSRHandler) CreateDSR(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID == uuid.Nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
return
}
var req models.CreateDSRRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
return
}
// Get user email if not provided
if req.RequesterEmail == "" {
var email string
ctx := context.Background()
h.dsrService.GetPool().QueryRow(ctx, "SELECT email FROM users WHERE id = $1", userID).Scan(&email)
req.RequesterEmail = email
}
// Set source as API
req.Source = "api"
dsr, err := h.dsrService.CreateRequest(c.Request.Context(), req, &userID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, gin.H{
"message": "Ihre Anfrage wurde erfolgreich eingereicht",
"request_number": dsr.RequestNumber,
"dsr": dsr,
})
}
// GetMyDSRs returns DSRs for the current user
func (h *DSRHandler) GetMyDSRs(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID == uuid.Nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
return
}
dsrs, err := h.dsrService.ListByUser(c.Request.Context(), userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch requests"})
return
}
c.JSON(http.StatusOK, gin.H{"requests": dsrs})
}
// GetMyDSR returns a specific DSR for the current user
func (h *DSRHandler) GetMyDSR(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID == uuid.Nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
return
}
dsrID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
return
}
dsr, err := h.dsrService.GetByID(c.Request.Context(), dsrID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Request not found"})
return
}
// Verify ownership
if dsr.UserID == nil || *dsr.UserID != userID {
c.JSON(http.StatusForbidden, gin.H{"error": "Access denied"})
return
}
c.JSON(http.StatusOK, dsr)
}
// CancelMyDSR cancels a user's own DSR
func (h *DSRHandler) CancelMyDSR(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID == uuid.Nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
return
}
dsrID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
return
}
err = h.dsrService.CancelRequest(c.Request.Context(), dsrID, userID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Anfrage wurde storniert"})
}
// ========================================
// ADMIN ENDPOINTS
// ========================================
// AdminListDSR returns all DSRs with filters (admin only)
func (h *DSRHandler) AdminListDSR(c *gin.Context) {
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
return
}
// Parse pagination
limit := 20
offset := 0
if l := c.Query("limit"); l != "" {
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
limit = parsed
}
}
if o := c.Query("offset"); o != "" {
if parsed, err := strconv.Atoi(o); err == nil && parsed >= 0 {
offset = parsed
}
}
// Parse filters
filters := models.DSRListFilters{}
if status := c.Query("status"); status != "" {
filters.Status = &status
}
if reqType := c.Query("request_type"); reqType != "" {
filters.RequestType = &reqType
}
if assignedTo := c.Query("assigned_to"); assignedTo != "" {
filters.AssignedTo = &assignedTo
}
if priority := c.Query("priority"); priority != "" {
filters.Priority = &priority
}
if c.Query("overdue_only") == "true" {
filters.OverdueOnly = true
}
if search := c.Query("search"); search != "" {
filters.Search = &search
}
if from := c.Query("from_date"); from != "" {
if t, err := time.Parse("2006-01-02", from); err == nil {
filters.FromDate = &t
}
}
if to := c.Query("to_date"); to != "" {
if t, err := time.Parse("2006-01-02", to); err == nil {
filters.ToDate = &t
}
}
dsrs, total, err := h.dsrService.List(c.Request.Context(), filters, limit, offset)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch requests"})
return
}
c.JSON(http.StatusOK, gin.H{
"requests": dsrs,
"total": total,
"limit": limit,
"offset": offset,
})
}
// AdminGetDSR returns a specific DSR (admin only)
func (h *DSRHandler) AdminGetDSR(c *gin.Context) {
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
return
}
dsrID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
return
}
dsr, err := h.dsrService.GetByID(c.Request.Context(), dsrID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Request not found"})
return
}
c.JSON(http.StatusOK, dsr)
}
// AdminCreateDSR creates a DSR manually (admin only)
func (h *DSRHandler) AdminCreateDSR(c *gin.Context) {
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
return
}
userID, _ := middleware.GetUserID(c)
var req models.CreateDSRRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
return
}
// Set source as admin_panel
if req.Source == "" {
req.Source = "admin_panel"
}
dsr, err := h.dsrService.CreateRequest(c.Request.Context(), req, &userID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, gin.H{
"message": "Anfrage wurde erstellt",
"request_number": dsr.RequestNumber,
"dsr": dsr,
})
}
// AdminUpdateDSR updates a DSR (admin only)
func (h *DSRHandler) AdminUpdateDSR(c *gin.Context) {
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
return
}
dsrID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
return
}
userID, _ := middleware.GetUserID(c)
var req models.UpdateDSRRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
ctx := c.Request.Context()
// Update status if provided
if req.Status != nil {
err = h.dsrService.UpdateStatus(ctx, dsrID, models.DSRStatus(*req.Status), "", &userID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
}
// Update processing notes
if req.ProcessingNotes != nil {
h.dsrService.GetPool().Exec(ctx, `
UPDATE data_subject_requests SET processing_notes = $1, updated_at = NOW() WHERE id = $2
`, *req.ProcessingNotes, dsrID)
}
// Update priority
if req.Priority != nil {
h.dsrService.GetPool().Exec(ctx, `
UPDATE data_subject_requests SET priority = $1, updated_at = NOW() WHERE id = $2
`, *req.Priority, dsrID)
}
// Get updated DSR
dsr, _ := h.dsrService.GetByID(ctx, dsrID)
c.JSON(http.StatusOK, gin.H{
"message": "Anfrage wurde aktualisiert",
"dsr": dsr,
})
}
// AdminGetDSRStats returns dashboard statistics
func (h *DSRHandler) AdminGetDSRStats(c *gin.Context) {
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
return
}
stats, err := h.dsrService.GetDashboardStats(c.Request.Context())
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch statistics"})
return
}
c.JSON(http.StatusOK, stats)
}
// AdminVerifyIdentity verifies the identity of a requester
func (h *DSRHandler) AdminVerifyIdentity(c *gin.Context) {
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
return
}
dsrID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
return
}
userID, _ := middleware.GetUserID(c)
var req models.VerifyDSRIdentityRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
err = h.dsrService.VerifyIdentity(c.Request.Context(), dsrID, req.Method, userID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Identität wurde verifiziert"})
}
// AdminAssignDSR assigns a DSR to a user
func (h *DSRHandler) AdminAssignDSR(c *gin.Context) {
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
return
}
dsrID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
return
}
userID, _ := middleware.GetUserID(c)
var req struct {
AssigneeID string `json:"assignee_id" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
assigneeID, err := uuid.Parse(req.AssigneeID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid assignee ID"})
return
}
err = h.dsrService.AssignRequest(c.Request.Context(), dsrID, assigneeID, userID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Anfrage wurde zugewiesen"})
}
// AdminExtendDSRDeadline extends the deadline for a DSR
func (h *DSRHandler) AdminExtendDSRDeadline(c *gin.Context) {
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
return
}
dsrID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
return
}
userID, _ := middleware.GetUserID(c)
var req models.ExtendDSRDeadlineRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
err = h.dsrService.ExtendDeadline(c.Request.Context(), dsrID, req.Reason, req.Days, userID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Frist wurde verlängert"})
}
// AdminCompleteDSR marks a DSR as completed
func (h *DSRHandler) AdminCompleteDSR(c *gin.Context) {
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
return
}
dsrID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
return
}
userID, _ := middleware.GetUserID(c)
var req models.CompleteDSRRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
err = h.dsrService.CompleteRequest(c.Request.Context(), dsrID, req.ResultSummary, req.ResultData, userID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Anfrage wurde abgeschlossen"})
}
// AdminRejectDSR rejects a DSR
func (h *DSRHandler) AdminRejectDSR(c *gin.Context) {
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
return
}
dsrID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
return
}
userID, _ := middleware.GetUserID(c)
var req models.RejectDSRRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
err = h.dsrService.RejectRequest(c.Request.Context(), dsrID, req.Reason, req.LegalBasis, userID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Anfrage wurde abgelehnt"})
}
// AdminGetDSRHistory returns the status history for a DSR
func (h *DSRHandler) AdminGetDSRHistory(c *gin.Context) {
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
return
}
dsrID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
return
}
history, err := h.dsrService.GetStatusHistory(c.Request.Context(), dsrID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch history"})
return
}
c.JSON(http.StatusOK, gin.H{"history": history})
}
// AdminGetDSRCommunications returns communications for a DSR
func (h *DSRHandler) AdminGetDSRCommunications(c *gin.Context) {
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
return
}
dsrID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
return
}
comms, err := h.dsrService.GetCommunications(c.Request.Context(), dsrID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch communications"})
return
}
c.JSON(http.StatusOK, gin.H{"communications": comms})
}
// AdminSendDSRCommunication sends a communication for a DSR
func (h *DSRHandler) AdminSendDSRCommunication(c *gin.Context) {
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
return
}
dsrID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
return
}
userID, _ := middleware.GetUserID(c)
var req models.SendDSRCommunicationRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
err = h.dsrService.SendCommunication(c.Request.Context(), dsrID, req, userID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Kommunikation wurde gesendet"})
}
// AdminUpdateDSRStatus updates the status of a DSR
func (h *DSRHandler) AdminUpdateDSRStatus(c *gin.Context) {
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
return
}
dsrID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
return
}
userID, _ := middleware.GetUserID(c)
var req struct {
Status string `json:"status" binding:"required"`
Comment string `json:"comment"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
err = h.dsrService.UpdateStatus(c.Request.Context(), dsrID, models.DSRStatus(req.Status), req.Comment, &userID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Status wurde aktualisiert"})
}
// ========================================
// EXCEPTION CHECKS (Art. 17)
// ========================================
// AdminGetExceptionChecks returns exception checks for an erasure DSR
func (h *DSRHandler) AdminGetExceptionChecks(c *gin.Context) {
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
return
}
dsrID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
return
}
checks, err := h.dsrService.GetExceptionChecks(c.Request.Context(), dsrID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch exception checks"})
return
}
c.JSON(http.StatusOK, gin.H{"exception_checks": checks})
}
// AdminInitExceptionChecks initializes exception checks for an erasure DSR
func (h *DSRHandler) AdminInitExceptionChecks(c *gin.Context) {
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
return
}
dsrID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
return
}
err = h.dsrService.InitErasureExceptionChecks(c.Request.Context(), dsrID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to initialize exception checks"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Ausnahmeprüfungen wurden initialisiert"})
}
// AdminUpdateExceptionCheck updates a single exception check
func (h *DSRHandler) AdminUpdateExceptionCheck(c *gin.Context) {
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
return
}
checkID, err := uuid.Parse(c.Param("checkId"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid check ID"})
return
}
userID, _ := middleware.GetUserID(c)
var req struct {
Applies bool `json:"applies"`
Notes *string `json:"notes"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
err = h.dsrService.UpdateExceptionCheck(c.Request.Context(), checkID, req.Applies, req.Notes, userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update exception check"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Ausnahmeprüfung wurde aktualisiert"})
}
// ========================================
// TEMPLATE ENDPOINTS
// ========================================
// AdminGetDSRTemplates returns all DSR templates
func (h *DSRHandler) AdminGetDSRTemplates(c *gin.Context) {
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
return
}
ctx := c.Request.Context()
rows, err := h.dsrService.GetPool().Query(ctx, `
SELECT id, template_type, name, description, request_types, is_active, sort_order, created_at, updated_at
FROM dsr_templates ORDER BY sort_order, name
`)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch templates"})
return
}
defer rows.Close()
var templates []map[string]interface{}
for rows.Next() {
var id uuid.UUID
var templateType, name string
var description *string
var requestTypes []byte
var isActive bool
var sortOrder int
var createdAt, updatedAt time.Time
err := rows.Scan(&id, &templateType, &name, &description, &requestTypes, &isActive, &sortOrder, &createdAt, &updatedAt)
if err != nil {
continue
}
templates = append(templates, map[string]interface{}{
"id": id,
"template_type": templateType,
"name": name,
"description": description,
"request_types": string(requestTypes),
"is_active": isActive,
"sort_order": sortOrder,
"created_at": createdAt,
"updated_at": updatedAt,
})
}
c.JSON(http.StatusOK, gin.H{"templates": templates})
}
// AdminGetDSRTemplateVersions returns versions for a template
func (h *DSRHandler) AdminGetDSRTemplateVersions(c *gin.Context) {
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
return
}
templateID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid template ID"})
return
}
ctx := c.Request.Context()
rows, err := h.dsrService.GetPool().Query(ctx, `
SELECT id, template_id, version, language, subject, body_html, body_text,
status, published_at, created_by, approved_by, approved_at, created_at, updated_at
FROM dsr_template_versions WHERE template_id = $1 ORDER BY created_at DESC
`, templateID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch versions"})
return
}
defer rows.Close()
var versions []map[string]interface{}
for rows.Next() {
var id, tempID uuid.UUID
var version, language, subject, bodyHTML, bodyText, status string
var publishedAt, approvedAt *time.Time
var createdBy, approvedBy *uuid.UUID
var createdAt, updatedAt time.Time
err := rows.Scan(&id, &tempID, &version, &language, &subject, &bodyHTML, &bodyText,
&status, &publishedAt, &createdBy, &approvedBy, &approvedAt, &createdAt, &updatedAt)
if err != nil {
continue
}
versions = append(versions, map[string]interface{}{
"id": id,
"template_id": tempID,
"version": version,
"language": language,
"subject": subject,
"body_html": bodyHTML,
"body_text": bodyText,
"status": status,
"published_at": publishedAt,
"created_by": createdBy,
"approved_by": approvedBy,
"approved_at": approvedAt,
"created_at": createdAt,
"updated_at": updatedAt,
})
}
c.JSON(http.StatusOK, gin.H{"versions": versions})
}
// AdminCreateDSRTemplateVersion creates a new template version
func (h *DSRHandler) AdminCreateDSRTemplateVersion(c *gin.Context) {
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
return
}
templateID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid template ID"})
return
}
userID, _ := middleware.GetUserID(c)
var req struct {
Version string `json:"version" binding:"required"`
Language string `json:"language"`
Subject string `json:"subject" binding:"required"`
BodyHTML string `json:"body_html" binding:"required"`
BodyText string `json:"body_text"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
if req.Language == "" {
req.Language = "de"
}
ctx := c.Request.Context()
var versionID uuid.UUID
err = h.dsrService.GetPool().QueryRow(ctx, `
INSERT INTO dsr_template_versions (template_id, version, language, subject, body_html, body_text, created_by)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id
`, templateID, req.Version, req.Language, req.Subject, req.BodyHTML, req.BodyText, userID).Scan(&versionID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create version"})
return
}
c.JSON(http.StatusCreated, gin.H{
"message": "Version wurde erstellt",
"id": versionID,
})
}
// AdminPublishDSRTemplateVersion publishes a template version
func (h *DSRHandler) AdminPublishDSRTemplateVersion(c *gin.Context) {
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
return
}
versionID, err := uuid.Parse(c.Param("versionId"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid version ID"})
return
}
userID, _ := middleware.GetUserID(c)
ctx := c.Request.Context()
_, err = h.dsrService.GetPool().Exec(ctx, `
UPDATE dsr_template_versions
SET status = 'published', published_at = NOW(), approved_by = $1, approved_at = NOW(), updated_at = NOW()
WHERE id = $2
`, userID, versionID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to publish version"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Version wurde veröffentlicht"})
}
// AdminGetPublishedDSRTemplates returns all published templates for selection
func (h *DSRHandler) AdminGetPublishedDSRTemplates(c *gin.Context) {
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
return
}
requestType := c.Query("request_type")
language := c.DefaultQuery("language", "de")
ctx := c.Request.Context()
query := `
SELECT t.id, t.template_type, t.name, t.description,
v.id as version_id, v.version, v.subject, v.body_html, v.body_text
FROM dsr_templates t
JOIN dsr_template_versions v ON t.id = v.template_id
WHERE t.is_active = TRUE AND v.status = 'published' AND v.language = $1
`
args := []interface{}{language}
if requestType != "" {
query += ` AND t.request_types @> $2::jsonb`
args = append(args, `["`+requestType+`"]`)
}
query += " ORDER BY t.sort_order, t.name"
rows, err := h.dsrService.GetPool().Query(ctx, query, args...)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch templates"})
return
}
defer rows.Close()
var templates []map[string]interface{}
for rows.Next() {
var templateID, versionID uuid.UUID
var templateType, name, version, subject, bodyHTML, bodyText string
var description *string
err := rows.Scan(&templateID, &templateType, &name, &description, &versionID, &version, &subject, &bodyHTML, &bodyText)
if err != nil {
continue
}
templates = append(templates, map[string]interface{}{
"template_id": templateID,
"template_type": templateType,
"name": name,
"description": description,
"version_id": versionID,
"version": version,
"subject": subject,
"body_html": bodyHTML,
"body_text": bodyText,
})
}
c.JSON(http.StatusOK, gin.H{"templates": templates})
}
// ========================================
// DEADLINE PROCESSING
// ========================================
// ProcessDeadlines triggers deadline checking (called by scheduler)
func (h *DSRHandler) ProcessDeadlines(c *gin.Context) {
err := h.dsrService.ProcessDeadlines(c.Request.Context())
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to process deadlines"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Deadline processing completed"})
}
@@ -0,0 +1,448 @@
package handlers
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/breakpilot/consent-service/internal/models"
"github.com/gin-gonic/gin"
)
func init() {
gin.SetMode(gin.TestMode)
}
// TestCreateDSR_InvalidBody tests create DSR with invalid body
func TestCreateDSR_InvalidBody_Returns400(t *testing.T) {
router := gin.New()
// Mock handler that mimics the actual behavior for invalid body
router.POST("/api/v1/dsr", func(c *gin.Context) {
var req models.CreateDSRRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
return
}
})
// Invalid JSON
req, _ := http.NewRequest("POST", "/api/v1/dsr", bytes.NewBufferString("{invalid json"))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected status 400, got %d", w.Code)
}
}
// TestCreateDSR_MissingType tests create DSR with missing type
func TestCreateDSR_MissingType_Returns400(t *testing.T) {
router := gin.New()
router.POST("/api/v1/dsr", func(c *gin.Context) {
var req models.CreateDSRRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
if req.RequestType == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "request_type is required"})
return
}
})
body := `{"requester_email": "test@example.com"}`
req, _ := http.NewRequest("POST", "/api/v1/dsr", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected status 400, got %d", w.Code)
}
}
// TestCreateDSR_InvalidType tests create DSR with invalid type
func TestCreateDSR_InvalidType_Returns400(t *testing.T) {
router := gin.New()
router.POST("/api/v1/dsr", func(c *gin.Context) {
var req models.CreateDSRRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
if !models.IsValidDSRRequestType(req.RequestType) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request_type"})
return
}
})
body := `{"request_type": "invalid_type", "requester_email": "test@example.com"}`
req, _ := http.NewRequest("POST", "/api/v1/dsr", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected status 400, got %d", w.Code)
}
}
// TestAdminListDSR_Unauthorized_Returns401 tests admin list without auth
func TestAdminListDSR_Unauthorized_Returns401(t *testing.T) {
router := gin.New()
// Simplified auth check
router.GET("/api/v1/admin/dsr", func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization required"})
return
}
c.JSON(http.StatusOK, gin.H{"requests": []interface{}{}})
})
req, _ := http.NewRequest("GET", "/api/v1/admin/dsr", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status 401, got %d", w.Code)
}
}
// TestAdminListDSR_ValidRequest tests admin list with valid auth
func TestAdminListDSR_ValidRequest_Returns200(t *testing.T) {
router := gin.New()
router.GET("/api/v1/admin/dsr", func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization required"})
return
}
c.JSON(http.StatusOK, gin.H{
"requests": []interface{}{},
"total": 0,
"limit": 20,
"offset": 0,
})
})
req, _ := http.NewRequest("GET", "/api/v1/admin/dsr", nil)
req.Header.Set("Authorization", "Bearer test-token")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
var response map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &response)
if _, ok := response["requests"]; !ok {
t.Error("Response should contain 'requests' field")
}
if _, ok := response["total"]; !ok {
t.Error("Response should contain 'total' field")
}
}
// TestAdminGetDSRStats_ValidRequest tests admin stats endpoint
func TestAdminGetDSRStats_ValidRequest_Returns200(t *testing.T) {
router := gin.New()
router.GET("/api/v1/admin/dsr/stats", func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization required"})
return
}
c.JSON(http.StatusOK, gin.H{
"total_requests": 0,
"pending_requests": 0,
"overdue_requests": 0,
"completed_this_month": 0,
"average_processing_days": 0,
"by_type": map[string]int{},
"by_status": map[string]int{},
})
})
req, _ := http.NewRequest("GET", "/api/v1/admin/dsr/stats", nil)
req.Header.Set("Authorization", "Bearer test-token")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
var response map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &response)
expectedFields := []string{"total_requests", "pending_requests", "overdue_requests", "by_type", "by_status"}
for _, field := range expectedFields {
if _, ok := response[field]; !ok {
t.Errorf("Response should contain '%s' field", field)
}
}
}
// TestAdminUpdateDSR_InvalidStatus_Returns400 tests admin update with invalid status
func TestAdminUpdateDSR_InvalidStatus_Returns400(t *testing.T) {
router := gin.New()
router.PUT("/api/v1/admin/dsr/:id", func(c *gin.Context) {
var req models.UpdateDSRRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
if req.Status != nil && !models.IsValidDSRStatus(*req.Status) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid status"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Updated"})
})
body := `{"status": "invalid_status"}`
req, _ := http.NewRequest("PUT", "/api/v1/admin/dsr/123", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer test-token")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected status 400, got %d", w.Code)
}
}
// TestAdminVerifyIdentity_ValidRequest_Returns200 tests identity verification
func TestAdminVerifyIdentity_ValidRequest_Returns200(t *testing.T) {
router := gin.New()
router.POST("/api/v1/admin/dsr/:id/verify-identity", func(c *gin.Context) {
var req models.VerifyDSRIdentityRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
if req.Method == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "method is required"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Identität verifiziert"})
})
body := `{"method": "id_card"}`
req, _ := http.NewRequest("POST", "/api/v1/admin/dsr/123/verify-identity", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer test-token")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
}
// TestAdminExtendDeadline_MissingReason_Returns400 tests extend deadline without reason
func TestAdminExtendDeadline_MissingReason_Returns400(t *testing.T) {
router := gin.New()
router.POST("/api/v1/admin/dsr/:id/extend", func(c *gin.Context) {
var req models.ExtendDSRDeadlineRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
if req.Reason == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "reason is required"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Deadline extended"})
})
body := `{"days": 30}`
req, _ := http.NewRequest("POST", "/api/v1/admin/dsr/123/extend", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer test-token")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected status 400, got %d", w.Code)
}
}
// TestAdminCompleteDSR_ValidRequest_Returns200 tests complete DSR
func TestAdminCompleteDSR_ValidRequest_Returns200(t *testing.T) {
router := gin.New()
router.POST("/api/v1/admin/dsr/:id/complete", func(c *gin.Context) {
var req models.CompleteDSRRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Anfrage erfolgreich abgeschlossen"})
})
body := `{"result_summary": "Alle Daten wurden bereitgestellt"}`
req, _ := http.NewRequest("POST", "/api/v1/admin/dsr/123/complete", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer test-token")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
}
// TestAdminRejectDSR_MissingLegalBasis_Returns400 tests reject DSR without legal basis
func TestAdminRejectDSR_MissingLegalBasis_Returns400(t *testing.T) {
router := gin.New()
router.POST("/api/v1/admin/dsr/:id/reject", func(c *gin.Context) {
var req models.RejectDSRRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
if req.LegalBasis == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "legal_basis is required"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Rejected"})
})
body := `{"reason": "Some reason"}`
req, _ := http.NewRequest("POST", "/api/v1/admin/dsr/123/reject", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer test-token")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected status 400, got %d", w.Code)
}
}
// TestAdminRejectDSR_ValidRequest_Returns200 tests reject DSR with valid data
func TestAdminRejectDSR_ValidRequest_Returns200(t *testing.T) {
router := gin.New()
router.POST("/api/v1/admin/dsr/:id/reject", func(c *gin.Context) {
var req models.RejectDSRRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
if req.LegalBasis == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "legal_basis is required"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Anfrage abgelehnt"})
})
body := `{"reason": "Daten benötigt für Rechtsstreit", "legal_basis": "Art. 17(3)e"}`
req, _ := http.NewRequest("POST", "/api/v1/admin/dsr/123/reject", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer test-token")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
}
// TestGetDSRTemplates_Returns200 tests templates endpoint
func TestGetDSRTemplates_Returns200(t *testing.T) {
router := gin.New()
router.GET("/api/v1/admin/dsr-templates", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"templates": []map[string]interface{}{
{
"id": "uuid-1",
"template_type": "dsr_receipt_access",
"name": "Eingangsbestätigung (Art. 15)",
},
},
})
})
req, _ := http.NewRequest("GET", "/api/v1/admin/dsr-templates", nil)
req.Header.Set("Authorization", "Bearer test-token")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
var response map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &response)
if _, ok := response["templates"]; !ok {
t.Error("Response should contain 'templates' field")
}
}
// TestRequestTypeValidation tests all valid request types
func TestRequestTypeValidation(t *testing.T) {
validTypes := []string{"access", "rectification", "erasure", "restriction", "portability"}
for _, reqType := range validTypes {
if !models.IsValidDSRRequestType(reqType) {
t.Errorf("Expected %s to be a valid request type", reqType)
}
}
invalidTypes := []string{"invalid", "delete", "copy", ""}
for _, reqType := range invalidTypes {
if models.IsValidDSRRequestType(reqType) {
t.Errorf("Expected %s to be an invalid request type", reqType)
}
}
}
// TestStatusValidation tests all valid statuses
func TestStatusValidation(t *testing.T) {
validStatuses := []string{"intake", "identity_verification", "processing", "completed", "rejected", "cancelled"}
for _, status := range validStatuses {
if !models.IsValidDSRStatus(status) {
t.Errorf("Expected %s to be a valid status", status)
}
}
invalidStatuses := []string{"invalid", "pending", "done", ""}
for _, status := range invalidStatuses {
if models.IsValidDSRStatus(status) {
t.Errorf("Expected %s to be an invalid status", status)
}
}
}
@@ -0,0 +1,528 @@
package handlers
import (
"net/http"
"strconv"
"time"
"github.com/breakpilot/consent-service/internal/models"
"github.com/breakpilot/consent-service/internal/services"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// EmailTemplateHandler handles email template operations
type EmailTemplateHandler struct {
service *services.EmailTemplateService
}
// NewEmailTemplateHandler creates a new email template handler
func NewEmailTemplateHandler(service *services.EmailTemplateService) *EmailTemplateHandler {
return &EmailTemplateHandler{service: service}
}
// GetAllTemplateTypes returns all available email template types with their variables
// GET /api/v1/admin/email-templates/types
func (h *EmailTemplateHandler) GetAllTemplateTypes(c *gin.Context) {
types := h.service.GetAllTemplateTypes()
c.JSON(http.StatusOK, gin.H{"types": types})
}
// GetAllTemplates returns all email templates with their latest published versions
// GET /api/v1/admin/email-templates
func (h *EmailTemplateHandler) GetAllTemplates(c *gin.Context) {
templates, err := h.service.GetAllTemplates(c.Request.Context())
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"templates": templates})
}
// GetTemplate returns a single template by ID
// GET /api/v1/admin/email-templates/:id
func (h *EmailTemplateHandler) GetTemplate(c *gin.Context) {
idStr := c.Param("id")
id, err := uuid.Parse(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid template ID"})
return
}
template, err := h.service.GetTemplateByID(c.Request.Context(), id)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "template not found"})
return
}
c.JSON(http.StatusOK, template)
}
// CreateTemplate creates a new email template type
// POST /api/v1/admin/email-templates
func (h *EmailTemplateHandler) CreateTemplate(c *gin.Context) {
var req models.CreateEmailTemplateRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
template, err := h.service.CreateEmailTemplate(c.Request.Context(), &req)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, template)
}
// GetTemplateVersions returns all versions for a template
// GET /api/v1/admin/email-templates/:id/versions
func (h *EmailTemplateHandler) GetTemplateVersions(c *gin.Context) {
idStr := c.Param("id")
id, err := uuid.Parse(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid template ID"})
return
}
versions, err := h.service.GetVersionsByTemplateID(c.Request.Context(), id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"versions": versions})
}
// GetVersion returns a single version by ID
// GET /api/v1/admin/email-template-versions/:id
func (h *EmailTemplateHandler) GetVersion(c *gin.Context) {
idStr := c.Param("id")
id, err := uuid.Parse(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"})
return
}
version, err := h.service.GetVersionByID(c.Request.Context(), id)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "version not found"})
return
}
c.JSON(http.StatusOK, version)
}
// CreateVersion creates a new version of an email template
// POST /api/v1/admin/email-template-versions
func (h *EmailTemplateHandler) CreateVersion(c *gin.Context) {
var req models.CreateEmailTemplateVersionRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Get user ID from context
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "user not authenticated"})
return
}
uid, _ := uuid.Parse(userID.(string))
version, err := h.service.CreateTemplateVersion(c.Request.Context(), &req, uid)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, version)
}
// UpdateVersion updates a version
// PUT /api/v1/admin/email-template-versions/:id
func (h *EmailTemplateHandler) UpdateVersion(c *gin.Context) {
idStr := c.Param("id")
id, err := uuid.Parse(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"})
return
}
var req models.UpdateEmailTemplateVersionRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.service.UpdateVersion(c.Request.Context(), id, &req); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "version updated"})
}
// SubmitForReview submits a version for review
// POST /api/v1/admin/email-template-versions/:id/submit
func (h *EmailTemplateHandler) SubmitForReview(c *gin.Context) {
idStr := c.Param("id")
id, err := uuid.Parse(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"})
return
}
var req struct {
Comment *string `json:"comment"`
}
c.ShouldBindJSON(&req)
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "user not authenticated"})
return
}
uid, _ := uuid.Parse(userID.(string))
if err := h.service.SubmitForReview(c.Request.Context(), id, uid, req.Comment); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "version submitted for review"})
}
// ApproveVersion approves a version (DSB only)
// POST /api/v1/admin/email-template-versions/:id/approve
func (h *EmailTemplateHandler) ApproveVersion(c *gin.Context) {
idStr := c.Param("id")
id, err := uuid.Parse(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"})
return
}
// Check role
role, exists := c.Get("user_role")
if !exists || (role != "data_protection_officer" && role != "admin" && role != "super_admin") {
c.JSON(http.StatusForbidden, gin.H{"error": "insufficient permissions"})
return
}
var req struct {
Comment *string `json:"comment"`
ScheduledPublishAt *string `json:"scheduled_publish_at"`
}
c.ShouldBindJSON(&req)
userID, _ := c.Get("user_id")
uid, _ := uuid.Parse(userID.(string))
var scheduledAt *time.Time
if req.ScheduledPublishAt != nil {
t, err := time.Parse(time.RFC3339, *req.ScheduledPublishAt)
if err == nil {
scheduledAt = &t
}
}
if err := h.service.ApproveVersion(c.Request.Context(), id, uid, req.Comment, scheduledAt); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "version approved"})
}
// RejectVersion rejects a version
// POST /api/v1/admin/email-template-versions/:id/reject
func (h *EmailTemplateHandler) RejectVersion(c *gin.Context) {
idStr := c.Param("id")
id, err := uuid.Parse(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"})
return
}
role, exists := c.Get("user_role")
if !exists || (role != "data_protection_officer" && role != "admin" && role != "super_admin") {
c.JSON(http.StatusForbidden, gin.H{"error": "insufficient permissions"})
return
}
var req struct {
Comment string `json:"comment" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "comment is required"})
return
}
userID, _ := c.Get("user_id")
uid, _ := uuid.Parse(userID.(string))
if err := h.service.RejectVersion(c.Request.Context(), id, uid, req.Comment); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "version rejected"})
}
// PublishVersion publishes an approved version
// POST /api/v1/admin/email-template-versions/:id/publish
func (h *EmailTemplateHandler) PublishVersion(c *gin.Context) {
idStr := c.Param("id")
id, err := uuid.Parse(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"})
return
}
role, exists := c.Get("user_role")
if !exists || (role != "data_protection_officer" && role != "admin" && role != "super_admin") {
c.JSON(http.StatusForbidden, gin.H{"error": "insufficient permissions"})
return
}
userID, _ := c.Get("user_id")
uid, _ := uuid.Parse(userID.(string))
if err := h.service.PublishVersion(c.Request.Context(), id, uid); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "version published"})
}
// GetApprovals returns approval history for a version
// GET /api/v1/admin/email-template-versions/:id/approvals
func (h *EmailTemplateHandler) GetApprovals(c *gin.Context) {
idStr := c.Param("id")
id, err := uuid.Parse(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"})
return
}
approvals, err := h.service.GetApprovals(c.Request.Context(), id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"approvals": approvals})
}
// PreviewVersion renders a preview of an email template version
// POST /api/v1/admin/email-template-versions/:id/preview
func (h *EmailTemplateHandler) PreviewVersion(c *gin.Context) {
idStr := c.Param("id")
id, err := uuid.Parse(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"})
return
}
var req struct {
Variables map[string]string `json:"variables"`
}
c.ShouldBindJSON(&req)
version, err := h.service.GetVersionByID(c.Request.Context(), id)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "version not found"})
return
}
// Use default test values if not provided
if req.Variables == nil {
req.Variables = map[string]string{
"user_name": "Max Mustermann",
"user_email": "max@example.com",
"login_url": "https://breakpilot.app/login",
"support_email": "support@breakpilot.app",
"verification_url": "https://breakpilot.app/verify?token=abc123",
"verification_code": "123456",
"expires_in": "24 Stunden",
"reset_url": "https://breakpilot.app/reset?token=xyz789",
"reset_code": "RESET123",
"ip_address": "192.168.1.1",
"device_info": "Chrome auf Windows 11",
"changed_at": time.Now().Format("02.01.2006 15:04"),
"enabled_at": time.Now().Format("02.01.2006 15:04"),
"disabled_at": time.Now().Format("02.01.2006 15:04"),
"support_url": "https://breakpilot.app/support",
"security_url": "https://breakpilot.app/account/security",
"login_time": time.Now().Format("02.01.2006 15:04"),
"location": "Berlin, Deutschland",
"activity_type": "Mehrere fehlgeschlagene Login-Versuche",
"activity_time": time.Now().Format("02.01.2006 15:04"),
"locked_at": time.Now().Format("02.01.2006 15:04"),
"reason": "Zu viele fehlgeschlagene Login-Versuche",
"unlock_time": time.Now().Add(30 * time.Minute).Format("02.01.2006 15:04"),
"unlocked_at": time.Now().Format("02.01.2006 15:04"),
"requested_at": time.Now().Format("02.01.2006"),
"deletion_date": time.Now().AddDate(0, 0, 30).Format("02.01.2006"),
"cancel_url": "https://breakpilot.app/cancel-deletion?token=cancel123",
"data_info": "Benutzerdaten, Zustimmungshistorie, Audit-Logs",
"deleted_at": time.Now().Format("02.01.2006"),
"feedback_url": "https://breakpilot.app/feedback",
"download_url": "https://breakpilot.app/export/download?token=export123",
"file_size": "2.3 MB",
"old_email": "alt@example.com",
"new_email": "neu@example.com",
"document_name": "Datenschutzerklärung",
"document_type": "privacy",
"version": "2.0.0",
"consent_url": "https://breakpilot.app/consent",
"deadline": time.Now().AddDate(0, 0, 14).Format("02.01.2006"),
"days_left": "7",
"hours_left": "24 Stunden",
"consequences": "Ohne Ihre Zustimmung wird Ihr Konto suspendiert.",
"suspended_at": time.Now().Format("02.01.2006 15:04"),
"documents": "- Datenschutzerklärung v2.0.0\n- AGB v1.5.0",
}
}
preview, err := h.service.RenderTemplate(version, req.Variables)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, preview)
}
// SendTestEmail sends a test email
// POST /api/v1/admin/email-template-versions/:id/send-test
func (h *EmailTemplateHandler) SendTestEmail(c *gin.Context) {
idStr := c.Param("id")
id, err := uuid.Parse(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"})
return
}
var req models.SendTestEmailRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.VersionID = idStr
version, err := h.service.GetVersionByID(c.Request.Context(), id)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "version not found"})
return
}
// Get template to find type
template, err := h.service.GetTemplateByID(c.Request.Context(), version.TemplateID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "template not found"})
return
}
userID, _ := c.Get("user_id")
uid, _ := uuid.Parse(userID.(string))
// Send test email
if err := h.service.SendEmail(c.Request.Context(), template.Type, version.Language, req.Recipient, req.Variables, &uid); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "test email sent"})
}
// GetSettings returns global email settings
// GET /api/v1/admin/email-templates/settings
func (h *EmailTemplateHandler) GetSettings(c *gin.Context) {
settings, err := h.service.GetSettings(c.Request.Context())
if err != nil {
// Return default settings if none exist
c.JSON(http.StatusOK, gin.H{
"company_name": "BreakPilot",
"sender_name": "BreakPilot",
"sender_email": "noreply@breakpilot.app",
"primary_color": "#2563eb",
"secondary_color": "#64748b",
})
return
}
c.JSON(http.StatusOK, settings)
}
// UpdateSettings updates global email settings
// PUT /api/v1/admin/email-templates/settings
func (h *EmailTemplateHandler) UpdateSettings(c *gin.Context) {
var req models.UpdateEmailTemplateSettingsRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
userID, _ := c.Get("user_id")
uid, _ := uuid.Parse(userID.(string))
if err := h.service.UpdateSettings(c.Request.Context(), &req, uid); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "settings updated"})
}
// GetEmailStats returns email statistics
// GET /api/v1/admin/email-templates/stats
func (h *EmailTemplateHandler) GetEmailStats(c *gin.Context) {
stats, err := h.service.GetEmailStats(c.Request.Context())
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, stats)
}
// GetSendLogs returns email send logs
// GET /api/v1/admin/email-templates/logs
func (h *EmailTemplateHandler) GetSendLogs(c *gin.Context) {
limitStr := c.DefaultQuery("limit", "50")
offsetStr := c.DefaultQuery("offset", "0")
limit, _ := strconv.Atoi(limitStr)
offset, _ := strconv.Atoi(offsetStr)
if limit > 100 {
limit = 100
}
logs, total, err := h.service.GetSendLogs(c.Request.Context(), limit, offset)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"logs": logs, "total": total})
}
// GetDefaultContent returns default template content for a type
// GET /api/v1/admin/email-templates/default/:type
func (h *EmailTemplateHandler) GetDefaultContent(c *gin.Context) {
templateType := c.Param("type")
language := c.DefaultQuery("language", "de")
subject, bodyHTML, bodyText := h.service.GetDefaultTemplateContent(templateType, language)
c.JSON(http.StatusOK, gin.H{
"subject": subject,
"body_html": bodyHTML,
"body_text": bodyText,
})
}
// InitializeTemplates initializes default email templates
// POST /api/v1/admin/email-templates/initialize
func (h *EmailTemplateHandler) InitializeTemplates(c *gin.Context) {
role, exists := c.Get("user_role")
if !exists || (role != "admin" && role != "super_admin") {
c.JSON(http.StatusForbidden, gin.H{"error": "insufficient permissions"})
return
}
if err := h.service.InitDefaultTemplates(c.Request.Context()); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "default templates initialized"})
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,805 @@
package handlers
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func init() {
gin.SetMode(gin.TestMode)
}
// setupTestRouter creates a test router with handlers
// Note: For full integration tests, use a test database
func setupTestRouter() *gin.Engine {
router := gin.New()
return router
}
// TestHealthEndpoint tests the health check endpoint
func TestHealthEndpoint(t *testing.T) {
router := setupTestRouter()
// Add health endpoint
router.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"status": "healthy",
"service": "consent-service",
"version": "1.0.0",
})
})
req, _ := http.NewRequest("GET", "/health", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
}
var response map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &response)
if response["status"] != "healthy" {
t.Errorf("Expected status 'healthy', got %v", response["status"])
}
}
// TestUnauthorizedAccess tests that protected endpoints require auth
func TestUnauthorizedAccess(t *testing.T) {
router := setupTestRouter()
// Add a protected endpoint
router.GET("/api/v1/consent/my", func(c *gin.Context) {
auth := c.GetHeader("Authorization")
if auth == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization required"})
return
}
c.JSON(http.StatusOK, gin.H{"consents": []interface{}{}})
})
tests := []struct {
name string
authorization string
expectedStatus int
}{
{"no auth header", "", http.StatusUnauthorized},
{"empty bearer", "Bearer ", http.StatusOK}, // Would be invalid in real middleware
{"valid format", "Bearer test-token", http.StatusOK},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, _ := http.NewRequest("GET", "/api/v1/consent/my", nil)
if tt.authorization != "" {
req.Header.Set("Authorization", tt.authorization)
}
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
}
})
}
}
// TestCreateConsentRequest tests consent creation request validation
func TestCreateConsentRequest(t *testing.T) {
type ConsentRequest struct {
DocumentType string `json:"document_type"`
VersionID string `json:"version_id"`
Consented bool `json:"consented"`
}
tests := []struct {
name string
request ConsentRequest
expectValid bool
}{
{
name: "valid consent",
request: ConsentRequest{
DocumentType: "terms",
VersionID: "123e4567-e89b-12d3-a456-426614174000",
Consented: true,
},
expectValid: true,
},
{
name: "missing document type",
request: ConsentRequest{
VersionID: "123e4567-e89b-12d3-a456-426614174000",
Consented: true,
},
expectValid: false,
},
{
name: "missing version ID",
request: ConsentRequest{
DocumentType: "terms",
Consented: true,
},
expectValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tt.request.DocumentType != "" && tt.request.VersionID != ""
if isValid != tt.expectValid {
t.Errorf("Expected valid=%v, got %v", tt.expectValid, isValid)
}
})
}
}
// TestDocumentTypeValidation tests valid document types
func TestDocumentTypeValidation(t *testing.T) {
validTypes := map[string]bool{
"terms": true,
"privacy": true,
"cookies": true,
"community_guidelines": true,
"imprint": true,
}
tests := []struct {
docType string
expected bool
}{
{"terms", true},
{"privacy", true},
{"cookies", true},
{"community_guidelines", true},
{"imprint", true},
{"invalid", false},
{"", false},
{"Terms", false}, // case sensitive
}
for _, tt := range tests {
t.Run(tt.docType, func(t *testing.T) {
_, isValid := validTypes[tt.docType]
if isValid != tt.expected {
t.Errorf("Expected %s valid=%v, got %v", tt.docType, tt.expected, isValid)
}
})
}
}
// TestVersionStatusTransitions tests valid status transitions
func TestVersionStatusTransitions(t *testing.T) {
validTransitions := map[string][]string{
"draft": {"review"},
"review": {"approved", "rejected"},
"approved": {"scheduled", "published"},
"scheduled": {"published"},
"published": {"archived"},
"rejected": {"draft"},
"archived": {}, // terminal state
}
tests := []struct {
fromStatus string
toStatus string
expected bool
}{
{"draft", "review", true},
{"draft", "published", false},
{"review", "approved", true},
{"review", "rejected", true},
{"review", "published", false},
{"approved", "published", true},
{"approved", "scheduled", true},
{"published", "archived", true},
{"published", "draft", false},
{"archived", "draft", false},
}
for _, tt := range tests {
t.Run(tt.fromStatus+"->"+tt.toStatus, func(t *testing.T) {
allowed := false
if transitions, ok := validTransitions[tt.fromStatus]; ok {
for _, t := range transitions {
if t == tt.toStatus {
allowed = true
break
}
}
}
if allowed != tt.expected {
t.Errorf("Transition %s->%s: expected %v, got %v",
tt.fromStatus, tt.toStatus, tt.expected, allowed)
}
})
}
}
// TestRolePermissions tests role-based access control
func TestRolePermissions(t *testing.T) {
permissions := map[string]map[string]bool{
"user": {
"view_documents": true,
"give_consent": true,
"view_own_data": true,
"request_deletion": true,
"create_document": false,
"publish_version": false,
"approve_version": false,
},
"admin": {
"view_documents": true,
"give_consent": true,
"view_own_data": true,
"create_document": true,
"edit_version": true,
"publish_version": true,
"approve_version": false, // Only DSB
},
"data_protection_officer": {
"view_documents": true,
"create_document": true,
"edit_version": true,
"approve_version": true,
"publish_version": true,
"view_audit_log": true,
},
}
tests := []struct {
role string
action string
shouldHave bool
}{
{"user", "view_documents", true},
{"user", "create_document", false},
{"admin", "create_document", true},
{"admin", "approve_version", false},
{"data_protection_officer", "approve_version", true},
}
for _, tt := range tests {
t.Run(tt.role+":"+tt.action, func(t *testing.T) {
rolePerms, ok := permissions[tt.role]
if !ok {
t.Fatalf("Unknown role: %s", tt.role)
}
hasPermission := rolePerms[tt.action]
if hasPermission != tt.shouldHave {
t.Errorf("Role %s action %s: expected %v, got %v",
tt.role, tt.action, tt.shouldHave, hasPermission)
}
})
}
}
// TestJSONResponseFormat tests that responses have correct format
func TestJSONResponseFormat(t *testing.T) {
router := setupTestRouter()
router.GET("/api/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": gin.H{
"id": "123",
"name": "Test",
},
})
})
req, _ := http.NewRequest("GET", "/api/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
contentType := w.Header().Get("Content-Type")
if contentType != "application/json; charset=utf-8" {
t.Errorf("Expected Content-Type 'application/json; charset=utf-8', got %s", contentType)
}
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
if err != nil {
t.Fatalf("Response should be valid JSON: %v", err)
}
}
// TestErrorResponseFormat tests error response format
func TestErrorResponseFormat(t *testing.T) {
router := setupTestRouter()
router.GET("/api/error", func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{
"error": "Bad Request",
"message": "Invalid input",
})
})
req, _ := http.NewRequest("GET", "/api/error", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, w.Code)
}
var response map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &response)
if response["error"] == nil {
t.Error("Error response should contain 'error' field")
}
}
// TestCookieCategoryValidation tests cookie category validation
func TestCookieCategoryValidation(t *testing.T) {
mandatoryCategories := []string{"necessary"}
optionalCategories := []string{"functional", "analytics", "marketing"}
// Necessary should always be consented
for _, cat := range mandatoryCategories {
t.Run("mandatory_"+cat, func(t *testing.T) {
// Business rule: mandatory categories cannot be declined
isMandatory := true
if !isMandatory {
t.Errorf("Category %s should be mandatory", cat)
}
})
}
// Optional categories can be toggled
for _, cat := range optionalCategories {
t.Run("optional_"+cat, func(t *testing.T) {
isMandatory := false
if isMandatory {
t.Errorf("Category %s should not be mandatory", cat)
}
})
}
}
// TestPaginationParams tests pagination parameter handling
func TestPaginationParams(t *testing.T) {
tests := []struct {
name string
page int
perPage int
expPage int
expLimit int
}{
{"defaults", 0, 0, 1, 50},
{"page 1", 1, 10, 1, 10},
{"page 5", 5, 20, 5, 20},
{"negative page", -1, 10, 1, 10}, // should default
{"too large per_page", 1, 500, 1, 100}, // should cap
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
page := tt.page
perPage := tt.perPage
// Apply defaults and limits
if page < 1 {
page = 1
}
if perPage < 1 {
perPage = 50
}
if perPage > 100 {
perPage = 100
}
if page != tt.expPage {
t.Errorf("Expected page %d, got %d", tt.expPage, page)
}
if perPage != tt.expLimit {
t.Errorf("Expected perPage %d, got %d", tt.expLimit, perPage)
}
})
}
}
// TestIPAddressExtraction tests IP address extraction from requests
func TestIPAddressExtraction(t *testing.T) {
tests := []struct {
name string
xForwarded string
remoteAddr string
expected string
}{
{"direct connection", "", "192.168.1.1:1234", "192.168.1.1"},
{"behind proxy", "10.0.0.1", "192.168.1.1:1234", "10.0.0.1"},
{"multiple proxies", "10.0.0.1, 10.0.0.2", "192.168.1.1:1234", "10.0.0.1"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := setupTestRouter()
var extractedIP string
router.GET("/test", func(c *gin.Context) {
if xf := c.GetHeader("X-Forwarded-For"); xf != "" {
// Take first IP from list
for i, ch := range xf {
if ch == ',' {
extractedIP = xf[:i]
break
}
}
if extractedIP == "" {
extractedIP = xf
}
} else {
// Extract IP from RemoteAddr
addr := c.Request.RemoteAddr
for i := len(addr) - 1; i >= 0; i-- {
if addr[i] == ':' {
extractedIP = addr[:i]
break
}
}
}
c.JSON(http.StatusOK, gin.H{"ip": extractedIP})
})
req, _ := http.NewRequest("GET", "/test", nil)
req.RemoteAddr = tt.remoteAddr
if tt.xForwarded != "" {
req.Header.Set("X-Forwarded-For", tt.xForwarded)
}
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if extractedIP != tt.expected {
t.Errorf("Expected IP %s, got %s", tt.expected, extractedIP)
}
})
}
}
// TestRequestBodySizeLimit tests that large requests are rejected
func TestRequestBodySizeLimit(t *testing.T) {
router := setupTestRouter()
// Simulate a body size limit check
maxBodySize := int64(1024 * 1024) // 1MB
router.POST("/api/upload", func(c *gin.Context) {
if c.Request.ContentLength > maxBodySize {
c.JSON(http.StatusRequestEntityTooLarge, gin.H{
"error": "Request body too large",
})
return
}
c.JSON(http.StatusOK, gin.H{"success": true})
})
tests := []struct {
name string
contentLength int64
expectedStatus int
}{
{"small body", 1000, http.StatusOK},
{"medium body", 500000, http.StatusOK},
{"exactly at limit", maxBodySize, http.StatusOK},
{"over limit", maxBodySize + 1, http.StatusRequestEntityTooLarge},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
body := bytes.NewReader(make([]byte, 0))
req, _ := http.NewRequest("POST", "/api/upload", body)
req.ContentLength = tt.contentLength
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
}
})
}
}
// ========================================
// EXTENDED HANDLER TESTS
// ========================================
// TestAuthHandlers tests authentication endpoints
func TestAuthHandlers(t *testing.T) {
router := setupTestRouter()
// Register endpoint
router.POST("/api/v1/auth/register", func(c *gin.Context) {
var req struct {
Email string `json:"email"`
Password string `json:"password"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
return
}
c.JSON(http.StatusCreated, gin.H{"message": "User registered"})
})
// Login endpoint
router.POST("/api/v1/auth/login", func(c *gin.Context) {
var req struct {
Email string `json:"email"`
Password string `json:"password"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
return
}
c.JSON(http.StatusOK, gin.H{"access_token": "token123"})
})
tests := []struct {
name string
endpoint string
method string
body interface{}
expectedStatus int
}{
{
name: "register - valid",
endpoint: "/api/v1/auth/register",
method: "POST",
body: map[string]string{"email": "test@example.com", "password": "password123"},
expectedStatus: http.StatusCreated,
},
{
name: "login - valid",
endpoint: "/api/v1/auth/login",
method: "POST",
body: map[string]string{"email": "test@example.com", "password": "password123"},
expectedStatus: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
jsonBody, _ := json.Marshal(tt.body)
req, _ := http.NewRequest(tt.method, tt.endpoint, bytes.NewBuffer(jsonBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
}
})
}
}
// TestDocumentHandlers tests document endpoints
func TestDocumentHandlers(t *testing.T) {
router := setupTestRouter()
// GET documents
router.GET("/api/v1/documents", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"documents": []interface{}{}})
})
// GET document by type
router.GET("/api/v1/documents/:type", func(c *gin.Context) {
docType := c.Param("type")
if docType == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid type"})
return
}
c.JSON(http.StatusOK, gin.H{"id": "123", "type": docType})
})
tests := []struct {
name string
endpoint string
expectedStatus int
}{
{"get all documents", "/api/v1/documents", http.StatusOK},
{"get terms", "/api/v1/documents/terms", http.StatusOK},
{"get privacy", "/api/v1/documents/privacy", http.StatusOK},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, _ := http.NewRequest("GET", tt.endpoint, nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
}
})
}
}
// TestConsentHandlers tests consent endpoints
func TestConsentHandlers(t *testing.T) {
router := setupTestRouter()
// Create consent
router.POST("/api/v1/consent", func(c *gin.Context) {
var req struct {
VersionID string `json:"version_id"`
Consented bool `json:"consented"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
return
}
c.JSON(http.StatusCreated, gin.H{"message": "Consent saved"})
})
// Check consent
router.GET("/api/v1/consent/check/:type", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"has_consent": true, "needs_update": false})
})
tests := []struct {
name string
endpoint string
method string
body interface{}
expectedStatus int
}{
{
name: "create consent",
endpoint: "/api/v1/consent",
method: "POST",
body: map[string]interface{}{"version_id": "123", "consented": true},
expectedStatus: http.StatusCreated,
},
{
name: "check consent",
endpoint: "/api/v1/consent/check/terms",
method: "GET",
expectedStatus: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var req *http.Request
if tt.body != nil {
jsonBody, _ := json.Marshal(tt.body)
req, _ = http.NewRequest(tt.method, tt.endpoint, bytes.NewBuffer(jsonBody))
req.Header.Set("Content-Type", "application/json")
} else {
req, _ = http.NewRequest(tt.method, tt.endpoint, nil)
}
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
}
})
}
}
// TestAdminHandlers tests admin endpoints
func TestAdminHandlers(t *testing.T) {
router := setupTestRouter()
// Create document (admin only)
router.POST("/api/v1/admin/documents", func(c *gin.Context) {
auth := c.GetHeader("Authorization")
if auth != "Bearer admin-token" {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin only"})
return
}
c.JSON(http.StatusCreated, gin.H{"message": "Document created"})
})
tests := []struct {
name string
token string
expectedStatus int
}{
{"admin token", "Bearer admin-token", http.StatusCreated},
{"user token", "Bearer user-token", http.StatusForbidden},
{"no token", "", http.StatusForbidden},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
body := map[string]string{"type": "terms", "name": "Test"}
jsonBody, _ := json.Marshal(body)
req, _ := http.NewRequest("POST", "/api/v1/admin/documents", bytes.NewBuffer(jsonBody))
req.Header.Set("Content-Type", "application/json")
if tt.token != "" {
req.Header.Set("Authorization", tt.token)
}
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
}
})
}
}
// TestCORSHeaders tests CORS headers
func TestCORSHeaders(t *testing.T) {
router := setupTestRouter()
router.Use(func(c *gin.Context) {
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
c.Next()
})
router.GET("/api/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "test"})
})
req, _ := http.NewRequest("GET", "/api/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Header().Get("Access-Control-Allow-Origin") != "*" {
t.Error("CORS headers not set correctly")
}
}
// TestRateLimiting tests rate limiting logic
func TestRateLimiting(t *testing.T) {
requests := 0
limit := 5
for i := 0; i < 10; i++ {
requests++
if requests > limit {
// Would return 429 Too Many Requests
if requests <= limit {
t.Error("Rate limit not enforced")
}
}
}
}
// TestEmailTemplateHandlers tests email template endpoints
func TestEmailTemplateHandlers(t *testing.T) {
router := setupTestRouter()
router.GET("/api/v1/admin/email-templates", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"templates": []interface{}{}})
})
router.POST("/api/v1/admin/email-templates/test", func(c *gin.Context) {
var req struct {
Recipient string `json:"recipient"`
VersionID string `json:"version_id"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Test email sent"})
})
req, _ := http.NewRequest("GET", "/api/v1/admin/email-templates", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
}
}
@@ -0,0 +1,203 @@
package handlers
import (
"net/http"
"strconv"
"github.com/breakpilot/consent-service/internal/middleware"
"github.com/breakpilot/consent-service/internal/services"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// NotificationHandler handles notification-related requests
type NotificationHandler struct {
notificationService *services.NotificationService
}
// NewNotificationHandler creates a new notification handler
func NewNotificationHandler(notificationService *services.NotificationService) *NotificationHandler {
return &NotificationHandler{
notificationService: notificationService,
}
}
// GetNotifications returns notifications for the current user
func (h *NotificationHandler) GetNotifications(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID == uuid.Nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
return
}
// Parse query parameters
limit := 20
offset := 0
unreadOnly := false
if l := c.Query("limit"); l != "" {
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
limit = parsed
}
}
if o := c.Query("offset"); o != "" {
if parsed, err := strconv.Atoi(o); err == nil && parsed >= 0 {
offset = parsed
}
}
if u := c.Query("unread_only"); u == "true" {
unreadOnly = true
}
notifications, total, err := h.notificationService.GetUserNotifications(c.Request.Context(), userID, limit, offset, unreadOnly)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch notifications"})
return
}
c.JSON(http.StatusOK, gin.H{
"notifications": notifications,
"total": total,
"limit": limit,
"offset": offset,
})
}
// GetUnreadCount returns the count of unread notifications
func (h *NotificationHandler) GetUnreadCount(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID == uuid.Nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
return
}
count, err := h.notificationService.GetUnreadCount(c.Request.Context(), userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get unread count"})
return
}
c.JSON(http.StatusOK, gin.H{"unread_count": count})
}
// MarkAsRead marks a notification as read
func (h *NotificationHandler) MarkAsRead(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID == uuid.Nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
return
}
notificationID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid notification ID"})
return
}
if err := h.notificationService.MarkAsRead(c.Request.Context(), userID, notificationID); err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Notification not found or already read"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Notification marked as read"})
}
// MarkAllAsRead marks all notifications as read
func (h *NotificationHandler) MarkAllAsRead(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID == uuid.Nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
return
}
if err := h.notificationService.MarkAllAsRead(c.Request.Context(), userID); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to mark notifications as read"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "All notifications marked as read"})
}
// DeleteNotification deletes a notification
func (h *NotificationHandler) DeleteNotification(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID == uuid.Nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
return
}
notificationID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid notification ID"})
return
}
if err := h.notificationService.DeleteNotification(c.Request.Context(), userID, notificationID); err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Notification not found"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Notification deleted"})
}
// GetPreferences returns notification preferences for the user
func (h *NotificationHandler) GetPreferences(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID == uuid.Nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
return
}
prefs, err := h.notificationService.GetPreferences(c.Request.Context(), userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get preferences"})
return
}
c.JSON(http.StatusOK, prefs)
}
// UpdatePreferences updates notification preferences for the user
func (h *NotificationHandler) UpdatePreferences(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID == uuid.Nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
return
}
var req struct {
EmailEnabled *bool `json:"email_enabled"`
PushEnabled *bool `json:"push_enabled"`
InAppEnabled *bool `json:"in_app_enabled"`
ReminderFrequency *string `json:"reminder_frequency"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
// Get current preferences
prefs, _ := h.notificationService.GetPreferences(c.Request.Context(), userID)
// Update only provided fields
if req.EmailEnabled != nil {
prefs.EmailEnabled = *req.EmailEnabled
}
if req.PushEnabled != nil {
prefs.PushEnabled = *req.PushEnabled
}
if req.InAppEnabled != nil {
prefs.InAppEnabled = *req.InAppEnabled
}
if req.ReminderFrequency != nil {
prefs.ReminderFrequency = *req.ReminderFrequency
}
if err := h.notificationService.UpdatePreferences(c.Request.Context(), userID, prefs); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update preferences"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Preferences updated", "preferences": prefs})
}
@@ -0,0 +1,743 @@
package handlers
import (
"context"
"net/http"
"strings"
"github.com/breakpilot/consent-service/internal/middleware"
"github.com/breakpilot/consent-service/internal/models"
"github.com/breakpilot/consent-service/internal/services"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// OAuthHandler handles OAuth 2.0 endpoints
type OAuthHandler struct {
oauthService *services.OAuthService
totpService *services.TOTPService
authService *services.AuthService
}
// NewOAuthHandler creates a new OAuthHandler
func NewOAuthHandler(oauthService *services.OAuthService, totpService *services.TOTPService, authService *services.AuthService) *OAuthHandler {
return &OAuthHandler{
oauthService: oauthService,
totpService: totpService,
authService: authService,
}
}
// ========================================
// OAuth 2.0 Authorization Code Flow
// ========================================
// Authorize handles the OAuth 2.0 authorization request
// GET /oauth/authorize
func (h *OAuthHandler) Authorize(c *gin.Context) {
responseType := c.Query("response_type")
clientID := c.Query("client_id")
redirectURI := c.Query("redirect_uri")
scope := c.Query("scope")
state := c.Query("state")
codeChallenge := c.Query("code_challenge")
codeChallengeMethod := c.Query("code_challenge_method")
// Validate response_type
if responseType != "code" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "unsupported_response_type",
"error_description": "Only 'code' response_type is supported",
})
return
}
// Validate client
ctx := context.Background()
client, err := h.oauthService.ValidateClient(ctx, clientID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_client",
"error_description": "Unknown or invalid client_id",
})
return
}
// Validate redirect_uri
if err := h.oauthService.ValidateRedirectURI(client, redirectURI); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_request",
"error_description": "Invalid redirect_uri",
})
return
}
// Validate scopes
scopes, err := h.oauthService.ValidateScopes(client, scope)
if err != nil {
redirectWithError(c, redirectURI, "invalid_scope", "One or more requested scopes are invalid", state)
return
}
// For public clients, PKCE is required
if client.IsPublic && codeChallenge == "" {
redirectWithError(c, redirectURI, "invalid_request", "PKCE code_challenge is required for public clients", state)
return
}
// Get authenticated user
userID, err := middleware.GetUserID(c)
if err != nil || userID == uuid.Nil {
// User not authenticated - redirect to login
// Store authorization request in session and redirect to login
c.JSON(http.StatusUnauthorized, gin.H{
"error": "login_required",
"error_description": "User must be authenticated to authorize",
"login_url": "/auth/login",
})
return
}
// Generate authorization code
code, err := h.oauthService.GenerateAuthorizationCode(
ctx, client, userID, redirectURI, scopes, codeChallenge, codeChallengeMethod,
)
if err != nil {
redirectWithError(c, redirectURI, "server_error", "Failed to generate authorization code", state)
return
}
// Redirect with code
redirectURL := redirectURI + "?code=" + code
if state != "" {
redirectURL += "&state=" + state
}
c.Redirect(http.StatusFound, redirectURL)
}
// Token handles the OAuth 2.0 token request
// POST /oauth/token
func (h *OAuthHandler) Token(c *gin.Context) {
grantType := c.PostForm("grant_type")
switch grantType {
case "authorization_code":
h.tokenAuthorizationCode(c)
case "refresh_token":
h.tokenRefreshToken(c)
default:
c.JSON(http.StatusBadRequest, gin.H{
"error": "unsupported_grant_type",
"error_description": "Only 'authorization_code' and 'refresh_token' grant types are supported",
})
}
}
// tokenAuthorizationCode handles the authorization_code grant
func (h *OAuthHandler) tokenAuthorizationCode(c *gin.Context) {
code := c.PostForm("code")
clientID := c.PostForm("client_id")
redirectURI := c.PostForm("redirect_uri")
codeVerifier := c.PostForm("code_verifier")
if code == "" || clientID == "" || redirectURI == "" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_request",
"error_description": "Missing required parameters: code, client_id, redirect_uri",
})
return
}
// Validate client
ctx := context.Background()
client, err := h.oauthService.ValidateClient(ctx, clientID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_client",
"error_description": "Unknown or invalid client_id",
})
return
}
// For confidential clients, validate client_secret
if !client.IsPublic {
clientSecret := c.PostForm("client_secret")
if err := h.oauthService.ValidateClientSecret(client, clientSecret); err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "invalid_client",
"error_description": "Invalid client credentials",
})
return
}
}
// Exchange authorization code for tokens
tokenResponse, err := h.oauthService.ExchangeAuthorizationCode(ctx, code, clientID, redirectURI, codeVerifier)
if err != nil {
switch err {
case services.ErrCodeExpired:
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_grant",
"error_description": "Authorization code has expired",
})
case services.ErrCodeUsed:
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_grant",
"error_description": "Authorization code has already been used",
})
case services.ErrPKCEVerifyFailed:
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_grant",
"error_description": "PKCE verification failed",
})
default:
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_grant",
"error_description": "Invalid authorization code",
})
}
return
}
c.JSON(http.StatusOK, tokenResponse)
}
// tokenRefreshToken handles the refresh_token grant
func (h *OAuthHandler) tokenRefreshToken(c *gin.Context) {
refreshToken := c.PostForm("refresh_token")
clientID := c.PostForm("client_id")
scope := c.PostForm("scope")
if refreshToken == "" || clientID == "" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_request",
"error_description": "Missing required parameters: refresh_token, client_id",
})
return
}
ctx := context.Background()
// Refresh access token
tokenResponse, err := h.oauthService.RefreshAccessToken(ctx, refreshToken, clientID, scope)
if err != nil {
switch err {
case services.ErrInvalidScope:
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_scope",
"error_description": "Requested scope exceeds original grant",
})
default:
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_grant",
"error_description": "Invalid or expired refresh token",
})
}
return
}
c.JSON(http.StatusOK, tokenResponse)
}
// Revoke handles token revocation
// POST /oauth/revoke
func (h *OAuthHandler) Revoke(c *gin.Context) {
token := c.PostForm("token")
tokenTypeHint := c.PostForm("token_type_hint")
if token == "" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_request",
"error_description": "Missing token parameter",
})
return
}
ctx := context.Background()
_ = h.oauthService.RevokeToken(ctx, token, tokenTypeHint)
// RFC 7009: Always return 200 OK
c.Status(http.StatusOK)
}
// Introspect handles token introspection (for resource servers)
// POST /oauth/introspect
func (h *OAuthHandler) Introspect(c *gin.Context) {
token := c.PostForm("token")
if token == "" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_request",
"error_description": "Missing token parameter",
})
return
}
ctx := context.Background()
claims, err := h.oauthService.ValidateAccessToken(ctx, token)
if err != nil {
c.JSON(http.StatusOK, gin.H{"active": false})
return
}
c.JSON(http.StatusOK, gin.H{
"active": true,
"sub": (*claims)["sub"],
"client_id": (*claims)["client_id"],
"scope": (*claims)["scope"],
"exp": (*claims)["exp"],
"iat": (*claims)["iat"],
"iss": (*claims)["iss"],
})
}
// ========================================
// 2FA (TOTP) Endpoints
// ========================================
// Setup2FA initiates 2FA setup
// POST /auth/2fa/setup
func (h *OAuthHandler) Setup2FA(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID == uuid.Nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"})
return
}
// Get user email
ctx := context.Background()
user, err := h.authService.GetUserByID(ctx, userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"})
return
}
// Setup 2FA
response, err := h.totpService.Setup2FA(ctx, userID, user.Email)
if err != nil {
switch err {
case services.ErrTOTPAlreadyEnabled:
c.JSON(http.StatusConflict, gin.H{"error": "2FA is already enabled for this account"})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to setup 2FA"})
}
return
}
c.JSON(http.StatusOK, response)
}
// Verify2FASetup verifies the 2FA setup with a code
// POST /auth/2fa/verify-setup
func (h *OAuthHandler) Verify2FASetup(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID == uuid.Nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"})
return
}
var req models.Verify2FARequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
ctx := context.Background()
err = h.totpService.Verify2FASetup(ctx, userID, req.Code)
if err != nil {
switch err {
case services.ErrTOTPAlreadyEnabled:
c.JSON(http.StatusConflict, gin.H{"error": "2FA is already enabled"})
case services.ErrTOTPInvalidCode:
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid 2FA code"})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to verify 2FA setup"})
}
return
}
c.JSON(http.StatusOK, gin.H{"message": "2FA enabled successfully"})
}
// Verify2FAChallenge verifies a 2FA challenge during login
// POST /auth/2fa/verify
func (h *OAuthHandler) Verify2FAChallenge(c *gin.Context) {
var req models.Verify2FAChallengeRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
ctx := context.Background()
var userID *uuid.UUID
var err error
if req.RecoveryCode != "" {
// Verify with recovery code
userID, err = h.totpService.VerifyChallengeWithRecoveryCode(ctx, req.ChallengeID, req.RecoveryCode)
} else {
// Verify with TOTP code
userID, err = h.totpService.VerifyChallenge(ctx, req.ChallengeID, req.Code)
}
if err != nil {
switch err {
case services.ErrTOTPChallengeExpired:
c.JSON(http.StatusGone, gin.H{"error": "2FA challenge has expired"})
case services.ErrTOTPInvalidCode:
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid 2FA code"})
case services.ErrRecoveryCodeInvalid:
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid recovery code"})
default:
c.JSON(http.StatusBadRequest, gin.H{"error": "2FA verification failed"})
}
return
}
// Get user and generate tokens
user, err := h.authService.GetUserByID(ctx, *userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"})
return
}
// Generate access token
accessToken, err := h.authService.GenerateAccessToken(user)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
return
}
// Generate refresh token
refreshToken, refreshTokenHash, err := h.authService.GenerateRefreshToken()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate refresh token"})
return
}
// Store session
ipAddress := middleware.GetClientIP(c)
userAgent := middleware.GetUserAgent(c)
// We need direct DB access for this, or we need to add a method to AuthService
// For now, we'll return the tokens and let the caller handle session storage
c.JSON(http.StatusOK, gin.H{
"access_token": accessToken,
"refresh_token": refreshToken,
"token_type": "Bearer",
"expires_in": 3600,
"user": map[string]interface{}{
"id": user.ID,
"email": user.Email,
"name": user.Name,
"role": user.Role,
},
"_session_hash": refreshTokenHash,
"_ip": ipAddress,
"_user_agent": userAgent,
})
}
// Disable2FA disables 2FA for the current user
// POST /auth/2fa/disable
func (h *OAuthHandler) Disable2FA(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID == uuid.Nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"})
return
}
var req models.Verify2FARequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
ctx := context.Background()
err = h.totpService.Disable2FA(ctx, userID, req.Code)
if err != nil {
switch err {
case services.ErrTOTPNotEnabled:
c.JSON(http.StatusNotFound, gin.H{"error": "2FA is not enabled"})
case services.ErrTOTPInvalidCode:
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid 2FA code"})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to disable 2FA"})
}
return
}
c.JSON(http.StatusOK, gin.H{"message": "2FA disabled successfully"})
}
// Get2FAStatus returns the 2FA status for the current user
// GET /auth/2fa/status
func (h *OAuthHandler) Get2FAStatus(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID == uuid.Nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"})
return
}
ctx := context.Background()
status, err := h.totpService.GetStatus(ctx, userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get 2FA status"})
return
}
c.JSON(http.StatusOK, status)
}
// RegenerateRecoveryCodes generates new recovery codes
// POST /auth/2fa/recovery-codes
func (h *OAuthHandler) RegenerateRecoveryCodes(c *gin.Context) {
userID, err := middleware.GetUserID(c)
if err != nil || userID == uuid.Nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"})
return
}
var req models.Verify2FARequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
ctx := context.Background()
codes, err := h.totpService.RegenerateRecoveryCodes(ctx, userID, req.Code)
if err != nil {
switch err {
case services.ErrTOTPNotEnabled:
c.JSON(http.StatusNotFound, gin.H{"error": "2FA is not enabled"})
case services.ErrTOTPInvalidCode:
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid 2FA code"})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to regenerate recovery codes"})
}
return
}
c.JSON(http.StatusOK, gin.H{"recovery_codes": codes})
}
// ========================================
// Enhanced Login with 2FA
// ========================================
// LoginWith2FA handles login with optional 2FA
// POST /auth/login
func (h *OAuthHandler) LoginWith2FA(c *gin.Context) {
var req models.LoginRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
ctx := context.Background()
ipAddress := middleware.GetClientIP(c)
userAgent := middleware.GetUserAgent(c)
// Attempt login
response, err := h.authService.Login(ctx, &req, ipAddress, userAgent)
if err != nil {
switch err {
case services.ErrInvalidCredentials:
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid email or password"})
case services.ErrAccountLocked:
c.JSON(http.StatusForbidden, gin.H{"error": "Account is temporarily locked"})
case services.ErrAccountSuspended:
c.JSON(http.StatusForbidden, gin.H{"error": "Account is suspended"})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "Login failed"})
}
return
}
// Check if 2FA is enabled
twoFactorEnabled, _ := h.totpService.IsTwoFactorEnabled(ctx, response.User.ID)
if twoFactorEnabled {
// Create 2FA challenge
challengeID, err := h.totpService.CreateChallenge(ctx, response.User.ID, ipAddress, userAgent)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create 2FA challenge"})
return
}
// Return 2FA required response
c.JSON(http.StatusOK, gin.H{
"requires_2fa": true,
"challenge_id": challengeID,
"message": "2FA verification required",
})
return
}
// No 2FA required, return tokens
c.JSON(http.StatusOK, gin.H{
"requires_2fa": false,
"access_token": response.AccessToken,
"refresh_token": response.RefreshToken,
"token_type": "Bearer",
"expires_in": response.ExpiresIn,
"user": map[string]interface{}{
"id": response.User.ID,
"email": response.User.Email,
"name": response.User.Name,
"role": response.User.Role,
},
})
}
// ========================================
// Registration with mandatory 2FA setup
// ========================================
// RegisterWith2FA handles registration with mandatory 2FA setup
// POST /auth/register
func (h *OAuthHandler) RegisterWith2FA(c *gin.Context) {
var req models.RegisterRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
ctx := context.Background()
// Validate password strength
if len(req.Password) < 8 {
c.JSON(http.StatusBadRequest, gin.H{"error": "Password must be at least 8 characters"})
return
}
// Register user
user, verificationToken, err := h.authService.Register(ctx, &req)
if err != nil {
switch err {
case services.ErrUserExists:
c.JSON(http.StatusConflict, gin.H{"error": "A user with this email already exists"})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "Registration failed"})
}
return
}
// Setup 2FA immediately
twoFAResponse, err := h.totpService.Setup2FA(ctx, user.ID, user.Email)
if err != nil {
// Non-fatal - user can set up 2FA later, but log it
c.JSON(http.StatusCreated, gin.H{
"message": "Registration successful. Please verify your email.",
"user_id": user.ID,
"verification_token": verificationToken, // In production, this would be sent via email
"two_factor_setup": nil,
"two_factor_error": "Failed to initialize 2FA. Please set it up in your account settings.",
})
return
}
c.JSON(http.StatusCreated, gin.H{
"message": "Registration successful. Please verify your email and complete 2FA setup.",
"user_id": user.ID,
"verification_token": verificationToken, // In production, this would be sent via email
"two_factor_setup": map[string]interface{}{
"secret": twoFAResponse.Secret,
"qr_code": twoFAResponse.QRCodeDataURL,
"recovery_codes": twoFAResponse.RecoveryCodes,
"setup_required": true,
"setup_endpoint": "/auth/2fa/verify-setup",
},
})
}
// ========================================
// OAuth Client Management (Admin)
// ========================================
// AdminCreateClient creates a new OAuth client
// POST /admin/oauth/clients
func (h *OAuthHandler) AdminCreateClient(c *gin.Context) {
var req struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
RedirectURIs []string `json:"redirect_uris" binding:"required"`
Scopes []string `json:"scopes"`
GrantTypes []string `json:"grant_types"`
IsPublic bool `json:"is_public"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
return
}
userID, _ := middleware.GetUserID(c)
// Default scopes
if len(req.Scopes) == 0 {
req.Scopes = []string{"openid", "profile", "email"}
}
// Default grant types
if len(req.GrantTypes) == 0 {
req.GrantTypes = []string{"authorization_code", "refresh_token"}
}
ctx := context.Background()
client, clientSecret, err := h.oauthService.CreateClient(
ctx, req.Name, req.Description, req.RedirectURIs, req.Scopes, req.GrantTypes, req.IsPublic, &userID,
)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create client"})
return
}
response := gin.H{
"client_id": client.ClientID,
"name": client.Name,
"redirect_uris": client.RedirectURIs,
"scopes": client.Scopes,
"grant_types": client.GrantTypes,
"is_public": client.IsPublic,
}
// Only show client_secret once for confidential clients
if !client.IsPublic && clientSecret != "" {
response["client_secret"] = clientSecret
response["client_secret_warning"] = "Store this secret securely. It will not be shown again."
}
c.JSON(http.StatusCreated, response)
}
// AdminGetClients lists all OAuth clients
// GET /admin/oauth/clients
func (h *OAuthHandler) AdminGetClients(c *gin.Context) {
// This would need a new method in OAuthService
// For now, return a placeholder
c.JSON(http.StatusOK, gin.H{
"clients": []interface{}{},
"message": "Client listing not yet implemented",
})
}
// ========================================
// Helper Functions
// ========================================
func redirectWithError(c *gin.Context, redirectURI, errorCode, errorDescription, state string) {
separator := "?"
if strings.Contains(redirectURI, "?") {
separator = "&"
}
redirectURL := redirectURI + separator + "error=" + errorCode + "&error_description=" + errorDescription
if state != "" {
redirectURL += "&state=" + state
}
c.Redirect(http.StatusFound, redirectURL)
}
@@ -0,0 +1,933 @@
package handlers
import (
"net/http"
"time"
"github.com/breakpilot/consent-service/internal/models"
"github.com/breakpilot/consent-service/internal/services"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// SchoolHandlers contains all school-related HTTP handlers
type SchoolHandlers struct {
schoolService *services.SchoolService
attendanceService *services.AttendanceService
gradeService *services.GradeService
}
// NewSchoolHandlers creates new school handlers
func NewSchoolHandlers(schoolService *services.SchoolService, attendanceService *services.AttendanceService, gradeService *services.GradeService) *SchoolHandlers {
return &SchoolHandlers{
schoolService: schoolService,
attendanceService: attendanceService,
gradeService: gradeService,
}
}
// ========================================
// School Handlers
// ========================================
// CreateSchool creates a new school
// POST /api/v1/schools
func (h *SchoolHandlers) CreateSchool(c *gin.Context) {
var req models.CreateSchoolRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
school, err := h.schoolService.CreateSchool(c.Request.Context(), req)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, school)
}
// GetSchool retrieves a school by ID
// GET /api/v1/schools/:id
func (h *SchoolHandlers) GetSchool(c *gin.Context) {
idStr := c.Param("id")
id, err := uuid.Parse(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school ID"})
return
}
school, err := h.schoolService.GetSchool(c.Request.Context(), id)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, school)
}
// ListSchools lists all schools
// GET /api/v1/schools
func (h *SchoolHandlers) ListSchools(c *gin.Context) {
schools, err := h.schoolService.ListSchools(c.Request.Context())
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, schools)
}
// ========================================
// School Year Handlers
// ========================================
// CreateSchoolYear creates a new school year
// POST /api/v1/schools/:id/years
func (h *SchoolHandlers) CreateSchoolYear(c *gin.Context) {
schoolIDStr := c.Param("id")
schoolID, err := uuid.Parse(schoolIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school ID"})
return
}
var req struct {
Name string `json:"name" binding:"required"`
StartDate string `json:"start_date" binding:"required"`
EndDate string `json:"end_date" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
startDate, err := time.Parse("2006-01-02", req.StartDate)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid start date format"})
return
}
endDate, err := time.Parse("2006-01-02", req.EndDate)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid end date format"})
return
}
schoolYear, err := h.schoolService.CreateSchoolYear(c.Request.Context(), schoolID, req.Name, startDate, endDate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, schoolYear)
}
// SetCurrentSchoolYear sets a school year as current
// PUT /api/v1/schools/:id/years/:yearId/current
func (h *SchoolHandlers) SetCurrentSchoolYear(c *gin.Context) {
schoolIDStr := c.Param("id")
schoolID, err := uuid.Parse(schoolIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school ID"})
return
}
yearIDStr := c.Param("yearId")
yearID, err := uuid.Parse(yearIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school year ID"})
return
}
if err := h.schoolService.SetCurrentSchoolYear(c.Request.Context(), schoolID, yearID); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "school year set as current"})
}
// ========================================
// Class Handlers
// ========================================
// CreateClass creates a new class
// POST /api/v1/schools/:id/classes
func (h *SchoolHandlers) CreateClass(c *gin.Context) {
schoolIDStr := c.Param("id")
schoolID, err := uuid.Parse(schoolIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school ID"})
return
}
var req models.CreateClassRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
class, err := h.schoolService.CreateClass(c.Request.Context(), schoolID, req)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, class)
}
// GetClass retrieves a class by ID
// GET /api/v1/classes/:id
func (h *SchoolHandlers) GetClass(c *gin.Context) {
idStr := c.Param("id")
id, err := uuid.Parse(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid class ID"})
return
}
class, err := h.schoolService.GetClass(c.Request.Context(), id)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, class)
}
// ListClasses lists all classes for a school in a school year
// GET /api/v1/schools/:id/classes?school_year_id=...
func (h *SchoolHandlers) ListClasses(c *gin.Context) {
schoolIDStr := c.Param("id")
schoolID, err := uuid.Parse(schoolIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school ID"})
return
}
schoolYearIDStr := c.Query("school_year_id")
if schoolYearIDStr == "" {
// Get current school year
schoolYear, err := h.schoolService.GetCurrentSchoolYear(c.Request.Context(), schoolID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "no current school year set"})
return
}
schoolYearIDStr = schoolYear.ID.String()
}
schoolYearID, err := uuid.Parse(schoolYearIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school year ID"})
return
}
classes, err := h.schoolService.ListClasses(c.Request.Context(), schoolID, schoolYearID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, classes)
}
// ========================================
// Student Handlers
// ========================================
// CreateStudent creates a new student
// POST /api/v1/schools/:id/students
func (h *SchoolHandlers) CreateStudent(c *gin.Context) {
schoolIDStr := c.Param("id")
schoolID, err := uuid.Parse(schoolIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school ID"})
return
}
var req models.CreateStudentRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
student, err := h.schoolService.CreateStudent(c.Request.Context(), schoolID, req)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, student)
}
// GetStudent retrieves a student by ID
// GET /api/v1/students/:id
func (h *SchoolHandlers) GetStudent(c *gin.Context) {
idStr := c.Param("id")
id, err := uuid.Parse(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid student ID"})
return
}
student, err := h.schoolService.GetStudent(c.Request.Context(), id)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, student)
}
// ListStudentsByClass lists all students in a class
// GET /api/v1/classes/:id/students
func (h *SchoolHandlers) ListStudentsByClass(c *gin.Context) {
classIDStr := c.Param("id")
classID, err := uuid.Parse(classIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid class ID"})
return
}
students, err := h.schoolService.ListStudentsByClass(c.Request.Context(), classID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, students)
}
// ========================================
// Subject Handlers
// ========================================
// CreateSubject creates a new subject
// POST /api/v1/schools/:id/subjects
func (h *SchoolHandlers) CreateSubject(c *gin.Context) {
schoolIDStr := c.Param("id")
schoolID, err := uuid.Parse(schoolIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school ID"})
return
}
var req struct {
Name string `json:"name" binding:"required"`
ShortName string `json:"short_name" binding:"required"`
Color *string `json:"color"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
subject, err := h.schoolService.CreateSubject(c.Request.Context(), schoolID, req.Name, req.ShortName, req.Color)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, subject)
}
// ListSubjects lists all subjects for a school
// GET /api/v1/schools/:id/subjects
func (h *SchoolHandlers) ListSubjects(c *gin.Context) {
schoolIDStr := c.Param("id")
schoolID, err := uuid.Parse(schoolIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school ID"})
return
}
subjects, err := h.schoolService.ListSubjects(c.Request.Context(), schoolID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, subjects)
}
// ========================================
// Attendance Handlers
// ========================================
// RecordAttendance records attendance for a student
// POST /api/v1/attendance
func (h *SchoolHandlers) RecordAttendance(c *gin.Context) {
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "user not authenticated"})
return
}
var req models.RecordAttendanceRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
record, err := h.attendanceService.RecordAttendance(c.Request.Context(), req, userID.(uuid.UUID))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, record)
}
// RecordBulkAttendance records attendance for multiple students
// POST /api/v1/classes/:id/attendance
func (h *SchoolHandlers) RecordBulkAttendance(c *gin.Context) {
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "user not authenticated"})
return
}
classIDStr := c.Param("id")
classID, err := uuid.Parse(classIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid class ID"})
return
}
var req struct {
Date string `json:"date" binding:"required"`
SlotID string `json:"slot_id" binding:"required"`
Records []struct {
StudentID string `json:"student_id"`
Status string `json:"status"`
Note *string `json:"note"`
} `json:"records" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
slotID, err := uuid.Parse(req.SlotID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid slot ID"})
return
}
// Convert to the expected type (without JSON tags)
records := make([]struct {
StudentID string
Status string
Note *string
}, len(req.Records))
for i, r := range req.Records {
records[i] = struct {
StudentID string
Status string
Note *string
}{
StudentID: r.StudentID,
Status: r.Status,
Note: r.Note,
}
}
err = h.attendanceService.RecordBulkAttendance(c.Request.Context(), classID, req.Date, slotID, records, userID.(uuid.UUID))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "attendance recorded"})
}
// GetClassAttendance gets attendance for a class on a specific date
// GET /api/v1/classes/:id/attendance?date=...
func (h *SchoolHandlers) GetClassAttendance(c *gin.Context) {
classIDStr := c.Param("id")
classID, err := uuid.Parse(classIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid class ID"})
return
}
date := c.Query("date")
if date == "" {
date = time.Now().Format("2006-01-02")
}
overview, err := h.attendanceService.GetAttendanceByClass(c.Request.Context(), classID, date)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, overview)
}
// GetStudentAttendance gets attendance history for a student
// GET /api/v1/students/:id/attendance?start_date=...&end_date=...
func (h *SchoolHandlers) GetStudentAttendance(c *gin.Context) {
studentIDStr := c.Param("id")
studentID, err := uuid.Parse(studentIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid student ID"})
return
}
startDateStr := c.Query("start_date")
endDateStr := c.Query("end_date")
var startDate, endDate time.Time
if startDateStr == "" {
startDate = time.Now().AddDate(0, -1, 0) // Last month
} else {
startDate, _ = time.Parse("2006-01-02", startDateStr)
}
if endDateStr == "" {
endDate = time.Now()
} else {
endDate, _ = time.Parse("2006-01-02", endDateStr)
}
records, err := h.attendanceService.GetStudentAttendance(c.Request.Context(), studentID, startDate, endDate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, records)
}
// ========================================
// Absence Report Handlers
// ========================================
// ReportAbsence allows parents to report absence
// POST /api/v1/absence/report
func (h *SchoolHandlers) ReportAbsence(c *gin.Context) {
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "user not authenticated"})
return
}
var req models.ReportAbsenceRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
report, err := h.attendanceService.ReportAbsence(c.Request.Context(), req, userID.(uuid.UUID))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, report)
}
// ConfirmAbsence allows teachers to confirm absence
// PUT /api/v1/absence/:id/confirm
func (h *SchoolHandlers) ConfirmAbsence(c *gin.Context) {
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "user not authenticated"})
return
}
reportIDStr := c.Param("id")
reportID, err := uuid.Parse(reportIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid report ID"})
return
}
var req struct {
Status string `json:"status" binding:"required"` // "excused" or "unexcused"
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
err = h.attendanceService.ConfirmAbsence(c.Request.Context(), reportID, userID.(uuid.UUID), req.Status)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "absence confirmed"})
}
// GetPendingAbsenceReports gets pending absence reports for a class
// GET /api/v1/classes/:id/absence/pending
func (h *SchoolHandlers) GetPendingAbsenceReports(c *gin.Context) {
classIDStr := c.Param("id")
classID, err := uuid.Parse(classIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid class ID"})
return
}
reports, err := h.attendanceService.GetPendingAbsenceReports(c.Request.Context(), classID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, reports)
}
// ========================================
// Grade Handlers
// ========================================
// CreateGrade creates a new grade
// POST /api/v1/grades
func (h *SchoolHandlers) CreateGrade(c *gin.Context) {
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "user not authenticated"})
return
}
var req models.CreateGradeRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Get teacher ID from user ID
teacher, err := h.schoolService.GetTeacherByUserID(c.Request.Context(), userID.(uuid.UUID))
if err != nil {
c.JSON(http.StatusForbidden, gin.H{"error": "user is not a teacher"})
return
}
grade, err := h.gradeService.CreateGrade(c.Request.Context(), req, teacher.ID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, grade)
}
// GetStudentGrades gets all grades for a student
// GET /api/v1/students/:id/grades?school_year_id=...
func (h *SchoolHandlers) GetStudentGrades(c *gin.Context) {
studentIDStr := c.Param("id")
studentID, err := uuid.Parse(studentIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid student ID"})
return
}
schoolYearIDStr := c.Query("school_year_id")
if schoolYearIDStr == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "school_year_id is required"})
return
}
schoolYearID, err := uuid.Parse(schoolYearIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school year ID"})
return
}
grades, err := h.gradeService.GetStudentGrades(c.Request.Context(), studentID, schoolYearID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, grades)
}
// GetClassGrades gets grades for all students in a class for a subject (Notenspiegel)
// GET /api/v1/classes/:id/grades/:subjectId?school_year_id=...&semester=...
func (h *SchoolHandlers) GetClassGrades(c *gin.Context) {
classIDStr := c.Param("id")
classID, err := uuid.Parse(classIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid class ID"})
return
}
subjectIDStr := c.Param("subjectId")
subjectID, err := uuid.Parse(subjectIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid subject ID"})
return
}
schoolYearIDStr := c.Query("school_year_id")
if schoolYearIDStr == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "school_year_id is required"})
return
}
schoolYearID, err := uuid.Parse(schoolYearIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school year ID"})
return
}
semesterStr := c.DefaultQuery("semester", "1")
var semester int
if semesterStr == "1" {
semester = 1
} else {
semester = 2
}
overviews, err := h.gradeService.GetClassGradesBySubject(c.Request.Context(), classID, subjectID, schoolYearID, semester)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, overviews)
}
// GetGradeStatistics gets grade statistics for a class/subject
// GET /api/v1/classes/:id/grades/:subjectId/stats?school_year_id=...&semester=...
func (h *SchoolHandlers) GetGradeStatistics(c *gin.Context) {
classIDStr := c.Param("id")
classID, err := uuid.Parse(classIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid class ID"})
return
}
subjectIDStr := c.Param("subjectId")
subjectID, err := uuid.Parse(subjectIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid subject ID"})
return
}
schoolYearIDStr := c.Query("school_year_id")
if schoolYearIDStr == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "school_year_id is required"})
return
}
schoolYearID, err := uuid.Parse(schoolYearIDStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school year ID"})
return
}
semesterStr := c.DefaultQuery("semester", "1")
var semester int
if semesterStr == "1" {
semester = 1
} else {
semester = 2
}
stats, err := h.gradeService.GetSubjectGradeStatistics(c.Request.Context(), classID, subjectID, schoolYearID, semester)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, stats)
}
// ========================================
// Parent Onboarding Handlers
// ========================================
// GenerateOnboardingToken generates a QR code token for parent onboarding
// POST /api/v1/onboarding/tokens
func (h *SchoolHandlers) GenerateOnboardingToken(c *gin.Context) {
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "user not authenticated"})
return
}
var req struct {
SchoolID string `json:"school_id" binding:"required"`
ClassID string `json:"class_id" binding:"required"`
StudentID string `json:"student_id" binding:"required"`
Role string `json:"role"` // "parent" or "parent_representative"
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
schoolID, err := uuid.Parse(req.SchoolID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school ID"})
return
}
classID, err := uuid.Parse(req.ClassID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid class ID"})
return
}
studentID, err := uuid.Parse(req.StudentID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid student ID"})
return
}
role := req.Role
if role == "" {
role = "parent"
}
token, err := h.schoolService.GenerateParentOnboardingToken(c.Request.Context(), schoolID, classID, studentID, userID.(uuid.UUID), role)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Generate QR code URL
qrURL := "/onboard-parent?token=" + token.Token
c.JSON(http.StatusCreated, gin.H{
"token": token.Token,
"qr_url": qrURL,
"expires_at": token.ExpiresAt,
})
}
// ValidateOnboardingToken validates an onboarding token
// GET /api/v1/onboarding/validate?token=...
func (h *SchoolHandlers) ValidateOnboardingToken(c *gin.Context) {
token := c.Query("token")
if token == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "token is required"})
return
}
onboardingToken, err := h.schoolService.ValidateOnboardingToken(c.Request.Context(), token)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "invalid or expired token"})
return
}
// Get student and school info
student, err := h.schoolService.GetStudent(c.Request.Context(), onboardingToken.StudentID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
class, err := h.schoolService.GetClass(c.Request.Context(), onboardingToken.ClassID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
school, err := h.schoolService.GetSchool(c.Request.Context(), onboardingToken.SchoolID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"valid": true,
"role": onboardingToken.Role,
"student_name": student.FirstName + " " + student.LastName,
"class_name": class.Name,
"school_name": school.Name,
"expires_at": onboardingToken.ExpiresAt,
})
}
// RedeemOnboardingToken redeems a token and creates parent account
// POST /api/v1/onboarding/redeem
func (h *SchoolHandlers) RedeemOnboardingToken(c *gin.Context) {
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "user not authenticated"})
return
}
var req struct {
Token string `json:"token" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
err := h.schoolService.RedeemOnboardingToken(c.Request.Context(), req.Token, userID.(uuid.UUID))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "token redeemed successfully"})
}
// ========================================
// Register Routes
// ========================================
// RegisterRoutes registers all school-related routes
func (h *SchoolHandlers) RegisterRoutes(r *gin.RouterGroup, authMiddleware gin.HandlerFunc) {
// Public routes (for onboarding)
r.GET("/onboarding/validate", h.ValidateOnboardingToken)
// Protected routes
protected := r.Group("")
protected.Use(authMiddleware)
// Schools
protected.POST("/schools", h.CreateSchool)
protected.GET("/schools", h.ListSchools)
protected.GET("/schools/:id", h.GetSchool)
protected.POST("/schools/:id/years", h.CreateSchoolYear)
protected.PUT("/schools/:id/years/:yearId/current", h.SetCurrentSchoolYear)
protected.POST("/schools/:id/classes", h.CreateClass)
protected.GET("/schools/:id/classes", h.ListClasses)
protected.POST("/schools/:id/students", h.CreateStudent)
protected.POST("/schools/:id/subjects", h.CreateSubject)
protected.GET("/schools/:id/subjects", h.ListSubjects)
// Classes
protected.GET("/classes/:id", h.GetClass)
protected.GET("/classes/:id/students", h.ListStudentsByClass)
protected.GET("/classes/:id/attendance", h.GetClassAttendance)
protected.POST("/classes/:id/attendance", h.RecordBulkAttendance)
protected.GET("/classes/:id/absence/pending", h.GetPendingAbsenceReports)
protected.GET("/classes/:id/grades/:subjectId", h.GetClassGrades)
protected.GET("/classes/:id/grades/:subjectId/stats", h.GetGradeStatistics)
// Students
protected.GET("/students/:id", h.GetStudent)
protected.GET("/students/:id/attendance", h.GetStudentAttendance)
protected.GET("/students/:id/grades", h.GetStudentGrades)
// Attendance & Absence
protected.POST("/attendance", h.RecordAttendance)
protected.POST("/absence/report", h.ReportAbsence)
protected.PUT("/absence/:id/confirm", h.ConfirmAbsence)
// Grades
protected.POST("/grades", h.CreateGrade)
// Onboarding
protected.POST("/onboarding/tokens", h.GenerateOnboardingToken)
protected.POST("/onboarding/redeem", h.RedeemOnboardingToken)
}
@@ -0,0 +1,247 @@
package middleware
import (
"net/http"
"os"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
// InputGateConfig holds configuration for input validation.
type InputGateConfig struct {
// Maximum request body size (default: 10MB)
MaxBodySize int64
// Maximum file upload size (default: 50MB)
MaxFileSize int64
// Allowed content types
AllowedContentTypes map[string]bool
// Allowed file types for uploads
AllowedFileTypes map[string]bool
// Blocked file extensions
BlockedExtensions map[string]bool
// Paths that allow larger uploads
LargeUploadPaths []string
// Paths excluded from validation
ExcludedPaths []string
// Enable strict content type checking
StrictContentType bool
}
// DefaultInputGateConfig returns sensible default configuration.
func DefaultInputGateConfig() InputGateConfig {
maxSize := int64(10 * 1024 * 1024) // 10MB
if envSize := os.Getenv("MAX_REQUEST_BODY_SIZE"); envSize != "" {
if size, err := strconv.ParseInt(envSize, 10, 64); err == nil {
maxSize = size
}
}
return InputGateConfig{
MaxBodySize: maxSize,
MaxFileSize: 50 * 1024 * 1024, // 50MB
AllowedContentTypes: map[string]bool{
"application/json": true,
"application/x-www-form-urlencoded": true,
"multipart/form-data": true,
"text/plain": true,
},
AllowedFileTypes: map[string]bool{
"image/jpeg": true,
"image/png": true,
"image/gif": true,
"image/webp": true,
"application/pdf": true,
"text/csv": true,
"application/msword": true,
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": true,
"application/vnd.ms-excel": true,
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": true,
},
BlockedExtensions: map[string]bool{
".exe": true, ".bat": true, ".cmd": true, ".com": true, ".msi": true,
".dll": true, ".scr": true, ".pif": true, ".vbs": true, ".js": true,
".jar": true, ".sh": true, ".ps1": true, ".app": true,
},
LargeUploadPaths: []string{
"/api/v1/files/upload",
"/api/v1/documents/upload",
"/api/v1/attachments",
},
ExcludedPaths: []string{
"/health",
"/metrics",
"/api/v1/health",
},
StrictContentType: true,
}
}
// isExcludedPath checks if path is excluded from validation.
func (c *InputGateConfig) isExcludedPath(path string) bool {
for _, excluded := range c.ExcludedPaths {
if path == excluded {
return true
}
}
return false
}
// isLargeUploadPath checks if path allows larger uploads.
func (c *InputGateConfig) isLargeUploadPath(path string) bool {
for _, uploadPath := range c.LargeUploadPaths {
if strings.HasPrefix(path, uploadPath) {
return true
}
}
return false
}
// getMaxSize returns the maximum allowed body size for the path.
func (c *InputGateConfig) getMaxSize(path string) int64 {
if c.isLargeUploadPath(path) {
return c.MaxFileSize
}
return c.MaxBodySize
}
// validateContentType validates the content type.
func (c *InputGateConfig) validateContentType(contentType string) (bool, string) {
if contentType == "" {
return true, ""
}
// Extract base content type (remove charset, boundary, etc.)
baseType := strings.Split(contentType, ";")[0]
baseType = strings.TrimSpace(strings.ToLower(baseType))
if !c.AllowedContentTypes[baseType] {
return false, "Content-Type '" + baseType + "' is not allowed"
}
return true, ""
}
// hasBlockedExtension checks if filename has a blocked extension.
func (c *InputGateConfig) hasBlockedExtension(filename string) bool {
if filename == "" {
return false
}
lowerFilename := strings.ToLower(filename)
for ext := range c.BlockedExtensions {
if strings.HasSuffix(lowerFilename, ext) {
return true
}
}
return false
}
// InputGate returns a middleware that validates incoming request bodies.
//
// Usage:
//
// r.Use(middleware.InputGate())
//
// // Or with custom config:
// config := middleware.DefaultInputGateConfig()
// config.MaxBodySize = 5 * 1024 * 1024 // 5MB
// r.Use(middleware.InputGateWithConfig(config))
func InputGate() gin.HandlerFunc {
return InputGateWithConfig(DefaultInputGateConfig())
}
// InputGateWithConfig returns an input gate middleware with custom configuration.
func InputGateWithConfig(config InputGateConfig) gin.HandlerFunc {
return func(c *gin.Context) {
// Skip excluded paths
if config.isExcludedPath(c.Request.URL.Path) {
c.Next()
return
}
// Skip validation for GET, HEAD, OPTIONS requests
method := c.Request.Method
if method == "GET" || method == "HEAD" || method == "OPTIONS" {
c.Next()
return
}
// Validate content type for requests with body
contentType := c.GetHeader("Content-Type")
if config.StrictContentType {
valid, errMsg := config.validateContentType(contentType)
if !valid {
c.AbortWithStatusJSON(http.StatusUnsupportedMediaType, gin.H{
"error": "unsupported_media_type",
"message": errMsg,
})
return
}
}
// Check Content-Length header
contentLength := c.GetHeader("Content-Length")
if contentLength != "" {
length, err := strconv.ParseInt(contentLength, 10, 64)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
"error": "invalid_content_length",
"message": "Invalid Content-Length header",
})
return
}
maxSize := config.getMaxSize(c.Request.URL.Path)
if length > maxSize {
c.AbortWithStatusJSON(http.StatusRequestEntityTooLarge, gin.H{
"error": "payload_too_large",
"message": "Request body exceeds maximum size",
"max_size": maxSize,
})
return
}
}
// Set max multipart memory for file uploads
if strings.Contains(contentType, "multipart/form-data") {
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, config.MaxFileSize)
}
c.Next()
}
}
// ValidateFileUpload validates a file upload.
// Use this in upload handlers for detailed validation.
func ValidateFileUpload(filename, contentType string, size int64, config *InputGateConfig) (bool, string) {
if config == nil {
defaultConfig := DefaultInputGateConfig()
config = &defaultConfig
}
// Check size
if size > config.MaxFileSize {
return false, "File size exceeds maximum allowed"
}
// Check extension
if config.hasBlockedExtension(filename) {
return false, "File extension is not allowed"
}
// Check content type
if contentType != "" && !config.AllowedFileTypes[contentType] {
return false, "File type '" + contentType + "' is not allowed"
}
return true, ""
}
@@ -0,0 +1,421 @@
package middleware
import (
"bytes"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
)
func TestInputGate_AllowsGETRequest(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(InputGate())
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200 for GET request, got %d", w.Code)
}
}
func TestInputGate_AllowsHEADRequest(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(InputGate())
router.HEAD("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req := httptest.NewRequest(http.MethodHead, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200 for HEAD request, got %d", w.Code)
}
}
func TestInputGate_AllowsOPTIONSRequest(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(InputGate())
router.OPTIONS("/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200 for OPTIONS request, got %d", w.Code)
}
}
func TestInputGate_AllowsValidJSONRequest(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(InputGate())
router.POST("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
body := bytes.NewBufferString(`{"key": "value"}`)
req := httptest.NewRequest(http.MethodPost, "/test", body)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-Length", "16")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200 for valid JSON, got %d", w.Code)
}
}
func TestInputGate_RejectsInvalidContentType(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultInputGateConfig()
config.StrictContentType = true
router.Use(InputGateWithConfig(config))
router.POST("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
body := bytes.NewBufferString(`data`)
req := httptest.NewRequest(http.MethodPost, "/test", body)
req.Header.Set("Content-Type", "application/xml")
req.Header.Set("Content-Length", "4")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusUnsupportedMediaType {
t.Errorf("Expected status 415 for invalid content type, got %d", w.Code)
}
}
func TestInputGate_AllowsEmptyContentType(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(InputGate())
router.POST("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
body := bytes.NewBufferString(`data`)
req := httptest.NewRequest(http.MethodPost, "/test", body)
// No Content-Type header
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200 for empty content type, got %d", w.Code)
}
}
func TestInputGate_RejectsOversizedRequest(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultInputGateConfig()
config.MaxBodySize = 100 // 100 bytes
router.Use(InputGateWithConfig(config))
router.POST("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
// Create a body larger than 100 bytes
largeBody := strings.Repeat("x", 200)
body := bytes.NewBufferString(largeBody)
req := httptest.NewRequest(http.MethodPost, "/test", body)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-Length", "200")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusRequestEntityTooLarge {
t.Errorf("Expected status 413 for oversized request, got %d", w.Code)
}
}
func TestInputGate_AllowsLargeUploadPath(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultInputGateConfig()
config.MaxBodySize = 100 // 100 bytes
config.MaxFileSize = 1000 // 1000 bytes
config.LargeUploadPaths = []string{"/api/v1/files/upload"}
router.Use(InputGateWithConfig(config))
router.POST("/api/v1/files/upload", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
// Create a body larger than MaxBodySize but smaller than MaxFileSize
largeBody := strings.Repeat("x", 500)
body := bytes.NewBufferString(largeBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1/files/upload", body)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-Length", "500")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200 for large upload path, got %d", w.Code)
}
}
func TestInputGate_ExcludedPaths(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultInputGateConfig()
config.MaxBodySize = 10 // Very small
config.ExcludedPaths = []string{"/health"}
router.Use(InputGateWithConfig(config))
router.POST("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "healthy"})
})
// Send oversized body to excluded path
largeBody := strings.Repeat("x", 100)
body := bytes.NewBufferString(largeBody)
req := httptest.NewRequest(http.MethodPost, "/health", body)
req.Header.Set("Content-Length", "100")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// Should pass because path is excluded
if w.Code != http.StatusOK {
t.Errorf("Expected status 200 for excluded path, got %d", w.Code)
}
}
func TestInputGate_RejectsInvalidContentLength(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(InputGate())
router.POST("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
body := bytes.NewBufferString(`data`)
req := httptest.NewRequest(http.MethodPost, "/test", body)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-Length", "invalid")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid content length, got %d", w.Code)
}
}
func TestValidateFileUpload_BlockedExtension(t *testing.T) {
tests := []struct {
filename string
contentType string
blocked bool
}{
{"malware.exe", "application/octet-stream", true},
{"script.bat", "application/octet-stream", true},
{"hack.cmd", "application/octet-stream", true},
{"shell.sh", "application/octet-stream", true},
{"powershell.ps1", "application/octet-stream", true},
{"document.pdf", "application/pdf", false},
{"image.jpg", "image/jpeg", false},
{"data.csv", "text/csv", false},
}
for _, tt := range tests {
valid, errMsg := ValidateFileUpload(tt.filename, tt.contentType, 100, nil)
if tt.blocked && valid {
t.Errorf("File %s should be blocked", tt.filename)
}
if !tt.blocked && !valid {
t.Errorf("File %s should not be blocked, error: %s", tt.filename, errMsg)
}
}
}
func TestValidateFileUpload_OversizedFile(t *testing.T) {
config := DefaultInputGateConfig()
config.MaxFileSize = 1000 // 1KB
valid, errMsg := ValidateFileUpload("test.pdf", "application/pdf", 2000, &config)
if valid {
t.Error("Should reject oversized file")
}
if !strings.Contains(errMsg, "size") {
t.Errorf("Error message should mention size, got: %s", errMsg)
}
}
func TestValidateFileUpload_ValidFile(t *testing.T) {
config := DefaultInputGateConfig()
valid, errMsg := ValidateFileUpload("document.pdf", "application/pdf", 1000, &config)
if !valid {
t.Errorf("Should accept valid file, got error: %s", errMsg)
}
}
func TestValidateFileUpload_InvalidContentType(t *testing.T) {
config := DefaultInputGateConfig()
valid, errMsg := ValidateFileUpload("file.xyz", "application/x-unknown", 100, &config)
if valid {
t.Error("Should reject unknown file type")
}
if !strings.Contains(errMsg, "not allowed") {
t.Errorf("Error message should mention not allowed, got: %s", errMsg)
}
}
func TestValidateFileUpload_NilConfig(t *testing.T) {
// Should use default config when nil is passed
valid, _ := ValidateFileUpload("document.pdf", "application/pdf", 1000, nil)
if !valid {
t.Error("Should accept valid file with nil config (uses defaults)")
}
}
func TestHasBlockedExtension(t *testing.T) {
config := DefaultInputGateConfig()
tests := []struct {
filename string
blocked bool
}{
{"test.exe", true},
{"TEST.EXE", true}, // Case insensitive
{"script.BAT", true},
{"app.APP", true},
{"document.pdf", false},
{"image.png", false},
{"", false},
}
for _, tt := range tests {
result := config.hasBlockedExtension(tt.filename)
if result != tt.blocked {
t.Errorf("File %s: expected blocked=%v, got %v", tt.filename, tt.blocked, result)
}
}
}
func TestValidateContentType(t *testing.T) {
config := DefaultInputGateConfig()
tests := []struct {
contentType string
valid bool
}{
{"application/json", true},
{"application/json; charset=utf-8", true},
{"APPLICATION/JSON", true}, // Case insensitive
{"multipart/form-data; boundary=----WebKitFormBoundary", true},
{"text/plain", true},
{"application/xml", false},
{"text/html", false},
{"", true}, // Empty is allowed
}
for _, tt := range tests {
valid, _ := config.validateContentType(tt.contentType)
if valid != tt.valid {
t.Errorf("Content-Type %q: expected valid=%v, got %v", tt.contentType, tt.valid, valid)
}
}
}
func TestIsLargeUploadPath(t *testing.T) {
config := DefaultInputGateConfig()
config.LargeUploadPaths = []string{"/api/v1/files/upload", "/api/v1/documents"}
tests := []struct {
path string
isLarge bool
}{
{"/api/v1/files/upload", true},
{"/api/v1/files/upload/batch", true}, // Prefix match
{"/api/v1/documents", true},
{"/api/v1/documents/1/attachments", true},
{"/api/v1/users", false},
{"/health", false},
}
for _, tt := range tests {
result := config.isLargeUploadPath(tt.path)
if result != tt.isLarge {
t.Errorf("Path %s: expected isLarge=%v, got %v", tt.path, tt.isLarge, result)
}
}
}
func TestGetMaxSize(t *testing.T) {
config := DefaultInputGateConfig()
config.MaxBodySize = 100
config.MaxFileSize = 1000
config.LargeUploadPaths = []string{"/api/v1/files/upload"}
tests := []struct {
path string
expected int64
}{
{"/api/test", 100},
{"/api/v1/files/upload", 1000},
{"/health", 100},
}
for _, tt := range tests {
result := config.getMaxSize(tt.path)
if result != tt.expected {
t.Errorf("Path %s: expected maxSize=%d, got %d", tt.path, tt.expected, result)
}
}
}
func TestInputGate_DefaultMiddleware(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(InputGate())
router.POST("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
body := bytes.NewBufferString(`{"key": "value"}`)
req := httptest.NewRequest(http.MethodPost, "/test", body)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
}
@@ -0,0 +1,379 @@
package middleware
import (
"net/http"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
// UserClaims represents the JWT claims for a user
type UserClaims struct {
UserID string `json:"user_id"`
Email string `json:"email"`
Role string `json:"role"`
jwt.RegisteredClaims
}
// CORS returns a CORS middleware
func CORS() gin.HandlerFunc {
return func(c *gin.Context) {
origin := c.Request.Header.Get("Origin")
// Allow localhost for development
allowedOrigins := []string{
"http://localhost:3000",
"http://localhost:8000",
"http://localhost:8080",
"https://breakpilot.app",
}
allowed := false
for _, o := range allowedOrigins {
if origin == o {
allowed = true
break
}
}
if allowed {
c.Header("Access-Control-Allow-Origin", origin)
}
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Authorization, X-Requested-With")
c.Header("Access-Control-Allow-Credentials", "true")
c.Header("Access-Control-Max-Age", "86400")
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
}
}
// RequestLogger logs each request
func RequestLogger() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
method := c.Request.Method
c.Next()
latency := time.Since(start)
status := c.Writer.Status()
// Log only in development or for errors
if status >= 400 {
gin.DefaultWriter.Write([]byte(
method + " " + path + " " +
string(rune(status)) + " " +
latency.String() + "\n",
))
}
}
}
// RateLimiter implements a simple in-memory rate limiter
// Configurable via RATE_LIMIT_PER_MINUTE env var (default: 500)
func RateLimiter() gin.HandlerFunc {
type client struct {
count int
lastSeen time.Time
}
var (
mu sync.Mutex
clients = make(map[string]*client)
)
// Clean up old entries periodically
go func() {
for {
time.Sleep(time.Minute)
mu.Lock()
for ip, c := range clients {
if time.Since(c.lastSeen) > time.Minute {
delete(clients, ip)
}
}
mu.Unlock()
}
}()
return func(c *gin.Context) {
ip := c.ClientIP()
// Skip rate limiting for Docker internal network (172.x.x.x) and localhost
// This prevents issues when multiple services share the same internal IP
if strings.HasPrefix(ip, "172.") || ip == "127.0.0.1" || ip == "::1" {
c.Next()
return
}
mu.Lock()
defer mu.Unlock()
if _, exists := clients[ip]; !exists {
clients[ip] = &client{}
}
cli := clients[ip]
// Reset count if more than a minute has passed
if time.Since(cli.lastSeen) > time.Minute {
cli.count = 0
}
cli.count++
cli.lastSeen = time.Now()
// Allow 500 requests per minute (increased for admin panels with many API calls)
if cli.count > 500 {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
"error": "rate_limit_exceeded",
"message": "Too many requests. Please try again later.",
})
return
}
c.Next()
}
}
// AuthMiddleware validates JWT tokens
func AuthMiddleware(jwtSecret string) gin.HandlerFunc {
return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "missing_authorization",
"message": "Authorization header is required",
})
return
}
// Extract token from "Bearer <token>"
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "invalid_authorization",
"message": "Authorization header must be in format: Bearer <token>",
})
return
}
tokenString := parts[1]
// Parse and validate token
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
return []byte(jwtSecret), nil
})
if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "invalid_token",
"message": "Invalid or expired token",
})
return
}
if claims, ok := token.Claims.(*UserClaims); ok && token.Valid {
// Set user info in context
c.Set("user_id", claims.UserID)
c.Set("email", claims.Email)
c.Set("role", claims.Role)
c.Next()
} else {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "invalid_claims",
"message": "Invalid token claims",
})
return
}
}
}
// AdminOnly ensures only admin users can access the route
func AdminOnly() gin.HandlerFunc {
return func(c *gin.Context) {
role, exists := c.Get("role")
if !exists {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User role not found",
})
return
}
roleStr, ok := role.(string)
if !ok || (roleStr != "admin" && roleStr != "super_admin" && roleStr != "data_protection_officer") {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"error": "forbidden",
"message": "Admin access required",
})
return
}
c.Next()
}
}
// DSBOnly ensures only Data Protection Officers can access the route
// Used for critical operations like publishing legal documents (four-eyes principle)
func DSBOnly() gin.HandlerFunc {
return func(c *gin.Context) {
role, exists := c.Get("role")
if !exists {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "User role not found",
})
return
}
roleStr, ok := role.(string)
if !ok || (roleStr != "data_protection_officer" && roleStr != "super_admin") {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"error": "forbidden",
"message": "Only Data Protection Officers can perform this action",
})
return
}
c.Next()
}
}
// IsAdmin checks if the user has admin role
func IsAdmin(c *gin.Context) bool {
role, exists := c.Get("role")
if !exists {
return false
}
roleStr, ok := role.(string)
return ok && (roleStr == "admin" || roleStr == "super_admin" || roleStr == "data_protection_officer")
}
// IsDSB checks if the user has DSB role
func IsDSB(c *gin.Context) bool {
role, exists := c.Get("role")
if !exists {
return false
}
roleStr, ok := role.(string)
return ok && (roleStr == "data_protection_officer" || roleStr == "super_admin")
}
// GetUserID extracts the user ID from the context
func GetUserID(c *gin.Context) (uuid.UUID, error) {
userIDStr, exists := c.Get("user_id")
if !exists {
return uuid.Nil, nil
}
userID, err := uuid.Parse(userIDStr.(string))
if err != nil {
return uuid.Nil, err
}
return userID, nil
}
// GetClientIP returns the client's IP address
func GetClientIP(c *gin.Context) string {
// Check X-Forwarded-For header first (for proxied requests)
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
ips := strings.Split(xff, ",")
return strings.TrimSpace(ips[0])
}
// Check X-Real-IP header
if xri := c.GetHeader("X-Real-IP"); xri != "" {
return xri
}
return c.ClientIP()
}
// GetUserAgent returns the client's User-Agent
func GetUserAgent(c *gin.Context) string {
return c.GetHeader("User-Agent")
}
// SuspensionCheckMiddleware checks if a user is suspended and restricts access
// Suspended users can only access consent-related endpoints
func SuspensionCheckMiddleware(pool interface{ QueryRow(ctx interface{}, sql string, args ...interface{}) interface{ Scan(dest ...interface{}) error } }) gin.HandlerFunc {
return func(c *gin.Context) {
userIDStr, exists := c.Get("user_id")
if !exists {
c.Next()
return
}
userID, err := uuid.Parse(userIDStr.(string))
if err != nil {
c.Next()
return
}
// Check user account status
var accountStatus string
err = pool.QueryRow(c.Request.Context(), `SELECT account_status FROM users WHERE id = $1`, userID).Scan(&accountStatus)
if err != nil {
c.Next()
return
}
if accountStatus == "suspended" {
// Check if current path is allowed for suspended users
path := c.Request.URL.Path
allowedPaths := []string{
"/api/v1/consent",
"/api/v1/documents",
"/api/v1/notifications",
"/api/v1/profile",
"/api/v1/privacy/my-data",
"/api/v1/auth/logout",
}
allowed := false
for _, p := range allowedPaths {
if strings.HasPrefix(path, p) {
allowed = true
break
}
}
if !allowed {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"error": "account_suspended",
"message": "Your account is suspended due to pending consent requirements",
"redirect": "/consent/pending",
})
return
}
// Set suspended flag in context for handlers to use
c.Set("account_suspended", true)
}
c.Next()
}
}
// IsSuspended checks if the current user's account is suspended
func IsSuspended(c *gin.Context) bool {
suspended, exists := c.Get("account_suspended")
if !exists {
return false
}
return suspended.(bool)
}
@@ -0,0 +1,546 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
func init() {
gin.SetMode(gin.TestMode)
}
// Helper to create a valid JWT token for testing
func createTestToken(secret string, userID, email, role string, exp time.Time) string {
claims := UserClaims{
UserID: userID,
Email: email,
Role: role,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(exp),
IssuedAt: jwt.NewNumericDate(time.Now()),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, _ := token.SignedString([]byte(secret))
return tokenString
}
// TestCORS tests the CORS middleware
func TestCORS(t *testing.T) {
router := gin.New()
router.Use(CORS())
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"success": true})
})
tests := []struct {
name string
origin string
method string
expectedStatus int
expectAllowedOrigin bool
}{
{"localhost:3000", "http://localhost:3000", "GET", http.StatusOK, true},
{"localhost:8000", "http://localhost:8000", "GET", http.StatusOK, true},
{"production", "https://breakpilot.app", "GET", http.StatusOK, true},
{"unknown origin", "https://unknown.com", "GET", http.StatusOK, false},
{"preflight", "http://localhost:3000", "OPTIONS", http.StatusNoContent, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, _ := http.NewRequest(tt.method, "/test", nil)
req.Header.Set("Origin", tt.origin)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
}
allowedOrigin := w.Header().Get("Access-Control-Allow-Origin")
if tt.expectAllowedOrigin && allowedOrigin != tt.origin {
t.Errorf("Expected Access-Control-Allow-Origin to be %s, got %s", tt.origin, allowedOrigin)
}
if !tt.expectAllowedOrigin && allowedOrigin != "" {
t.Errorf("Expected no Access-Control-Allow-Origin header, got %s", allowedOrigin)
}
})
}
}
// TestCORSHeaders tests that CORS headers are set correctly
func TestCORSHeaders(t *testing.T) {
router := gin.New()
router.Use(CORS())
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req, _ := http.NewRequest("GET", "/test", nil)
req.Header.Set("Origin", "http://localhost:3000")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
expectedHeaders := map[string]string{
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Origin, Content-Type, Authorization, X-Requested-With",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "86400",
}
for header, expected := range expectedHeaders {
actual := w.Header().Get(header)
if actual != expected {
t.Errorf("Expected %s to be %s, got %s", header, expected, actual)
}
}
}
// TestAuthMiddleware_ValidToken tests authentication with valid token
func TestAuthMiddleware_ValidToken(t *testing.T) {
secret := "test-secret-key"
userID := uuid.New().String()
email := "test@example.com"
role := "user"
router := gin.New()
router.Use(AuthMiddleware(secret))
router.GET("/protected", func(c *gin.Context) {
uid, _ := c.Get("user_id")
em, _ := c.Get("email")
r, _ := c.Get("role")
c.JSON(http.StatusOK, gin.H{
"user_id": uid,
"email": em,
"role": r,
})
})
token := createTestToken(secret, userID, email, role, time.Now().Add(time.Hour))
req, _ := http.NewRequest("GET", "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
}
}
// TestAuthMiddleware_MissingHeader tests authentication without header
func TestAuthMiddleware_MissingHeader(t *testing.T) {
router := gin.New()
router.Use(AuthMiddleware("test-secret"))
router.GET("/protected", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req, _ := http.NewRequest("GET", "/protected", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
}
}
// TestAuthMiddleware_InvalidFormat tests authentication with invalid header format
func TestAuthMiddleware_InvalidFormat(t *testing.T) {
router := gin.New()
router.Use(AuthMiddleware("test-secret"))
router.GET("/protected", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
tests := []struct {
name string
header string
}{
{"no Bearer prefix", "some-token"},
{"Basic auth", "Basic dXNlcjpwYXNz"},
{"empty Bearer", "Bearer "},
{"multiple spaces", "Bearer token"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, _ := http.NewRequest("GET", "/protected", nil)
req.Header.Set("Authorization", tt.header)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
}
})
}
}
// TestAuthMiddleware_ExpiredToken tests authentication with expired token
func TestAuthMiddleware_ExpiredToken(t *testing.T) {
secret := "test-secret"
router := gin.New()
router.Use(AuthMiddleware(secret))
router.GET("/protected", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
// Create expired token
token := createTestToken(secret, "user-123", "test@example.com", "user", time.Now().Add(-time.Hour))
req, _ := http.NewRequest("GET", "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
}
}
// TestAuthMiddleware_WrongSecret tests authentication with wrong secret
func TestAuthMiddleware_WrongSecret(t *testing.T) {
router := gin.New()
router.Use(AuthMiddleware("correct-secret"))
router.GET("/protected", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
// Create token with different secret
token := createTestToken("wrong-secret", "user-123", "test@example.com", "user", time.Now().Add(time.Hour))
req, _ := http.NewRequest("GET", "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
}
}
// TestAdminOnly tests the AdminOnly middleware
func TestAdminOnly(t *testing.T) {
tests := []struct {
name string
role string
expectedStatus int
}{
{"admin allowed", "admin", http.StatusOK},
{"super_admin allowed", "super_admin", http.StatusOK},
{"dpo allowed", "data_protection_officer", http.StatusOK},
{"user forbidden", "user", http.StatusForbidden},
{"empty role forbidden", "", http.StatusForbidden},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := gin.New()
router.Use(func(c *gin.Context) {
c.Set("role", tt.role)
c.Next()
})
router.Use(AdminOnly())
router.GET("/admin", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req, _ := http.NewRequest("GET", "/admin", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
}
})
}
}
// TestAdminOnly_NoRole tests AdminOnly when role is not set
func TestAdminOnly_NoRole(t *testing.T) {
router := gin.New()
router.Use(AdminOnly())
router.GET("/admin", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req, _ := http.NewRequest("GET", "/admin", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
}
}
// TestDSBOnly tests the DSBOnly middleware
func TestDSBOnly(t *testing.T) {
tests := []struct {
name string
role string
expectedStatus int
}{
{"dpo allowed", "data_protection_officer", http.StatusOK},
{"super_admin allowed", "super_admin", http.StatusOK},
{"admin forbidden", "admin", http.StatusForbidden},
{"user forbidden", "user", http.StatusForbidden},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := gin.New()
router.Use(func(c *gin.Context) {
c.Set("role", tt.role)
c.Next()
})
router.Use(DSBOnly())
router.GET("/dsb", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req, _ := http.NewRequest("GET", "/dsb", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
}
})
}
}
// TestIsAdmin tests the IsAdmin helper function
func TestIsAdmin(t *testing.T) {
tests := []struct {
name string
role string
expected bool
}{
{"admin", "admin", true},
{"super_admin", "super_admin", true},
{"dpo", "data_protection_officer", true},
{"user", "user", false},
{"empty", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
if tt.role != "" {
c.Set("role", tt.role)
}
result := IsAdmin(c)
if result != tt.expected {
t.Errorf("Expected IsAdmin to be %v, got %v", tt.expected, result)
}
})
}
}
// TestIsDSB tests the IsDSB helper function
func TestIsDSB(t *testing.T) {
tests := []struct {
name string
role string
expected bool
}{
{"dpo", "data_protection_officer", true},
{"super_admin", "super_admin", true},
{"admin", "admin", false},
{"user", "user", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Set("role", tt.role)
result := IsDSB(c)
if result != tt.expected {
t.Errorf("Expected IsDSB to be %v, got %v", tt.expected, result)
}
})
}
}
// TestGetUserID tests the GetUserID helper function
func TestGetUserID(t *testing.T) {
validUUID := uuid.New()
tests := []struct {
name string
userID string
setUserID bool
expectError bool
expectedID uuid.UUID
}{
{"valid UUID", validUUID.String(), true, false, validUUID},
{"invalid UUID", "not-a-uuid", true, true, uuid.Nil},
{"missing user_id", "", false, false, uuid.Nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
if tt.setUserID {
c.Set("user_id", tt.userID)
}
result, err := GetUserID(c)
if tt.expectError && err == nil {
t.Error("Expected error but got none")
}
if !tt.expectError && result != tt.expectedID {
t.Errorf("Expected %v, got %v", tt.expectedID, result)
}
})
}
}
// TestGetClientIP tests the GetClientIP helper function
func TestGetClientIP(t *testing.T) {
tests := []struct {
name string
xff string
xri string
clientIP string
expectedIP string
}{
{"X-Forwarded-For", "10.0.0.1", "", "192.168.1.1", "10.0.0.1"},
{"X-Forwarded-For multiple", "10.0.0.1, 10.0.0.2", "", "192.168.1.1", "10.0.0.1"},
{"X-Real-IP", "", "10.0.0.1", "192.168.1.1", "10.0.0.1"},
{"direct", "", "", "192.168.1.1", "192.168.1.1"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request, _ = http.NewRequest("GET", "/", nil)
if tt.xff != "" {
c.Request.Header.Set("X-Forwarded-For", tt.xff)
}
if tt.xri != "" {
c.Request.Header.Set("X-Real-IP", tt.xri)
}
c.Request.RemoteAddr = tt.clientIP + ":12345"
result := GetClientIP(c)
// Note: gin.ClientIP() might return different values
// depending on trusted proxies config
if result != tt.expectedIP && result != tt.clientIP {
t.Logf("Note: GetClientIP returned %s (expected %s or %s)", result, tt.expectedIP, tt.clientIP)
}
})
}
}
// TestGetUserAgent tests the GetUserAgent helper function
func TestGetUserAgent(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request, _ = http.NewRequest("GET", "/", nil)
expectedUA := "Mozilla/5.0 (Test)"
c.Request.Header.Set("User-Agent", expectedUA)
result := GetUserAgent(c)
if result != expectedUA {
t.Errorf("Expected %s, got %s", expectedUA, result)
}
}
// TestIsSuspended tests the IsSuspended helper function
func TestIsSuspended(t *testing.T) {
tests := []struct {
name string
suspended interface{}
setSuspended bool
expected bool
}{
{"suspended true", true, true, true},
{"suspended false", false, true, false},
{"not set", nil, false, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
if tt.setSuspended {
c.Set("account_suspended", tt.suspended)
}
result := IsSuspended(c)
if result != tt.expected {
t.Errorf("Expected %v, got %v", tt.expected, result)
}
})
}
}
// BenchmarkCORS benchmarks the CORS middleware
func BenchmarkCORS(b *testing.B) {
router := gin.New()
router.Use(CORS())
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req, _ := http.NewRequest("GET", "/test", nil)
req.Header.Set("Origin", "http://localhost:3000")
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
}
}
// BenchmarkAuthMiddleware benchmarks the auth middleware
func BenchmarkAuthMiddleware(b *testing.B) {
secret := "test-secret-key"
token := createTestToken(secret, uuid.New().String(), "test@example.com", "user", time.Now().Add(time.Hour))
router := gin.New()
router.Use(AuthMiddleware(secret))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req, _ := http.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Bearer "+token)
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
}
}
@@ -0,0 +1,197 @@
package middleware
import (
"regexp"
"strings"
)
// PIIPattern defines a pattern for identifying PII.
type PIIPattern struct {
Name string
Pattern *regexp.Regexp
Replacement string
}
// PIIRedactor redacts personally identifiable information from strings.
type PIIRedactor struct {
patterns []*PIIPattern
}
// Pre-compiled patterns for common PII types
var (
emailPattern = regexp.MustCompile(`\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b`)
ipv4Pattern = regexp.MustCompile(`\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b`)
ipv6Pattern = regexp.MustCompile(`\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b`)
phonePattern = regexp.MustCompile(`(?:\+49|0049)[\s.-]?\d{2,4}[\s.-]?\d{3,8}|\b0\d{2,4}[\s.-]?\d{3,8}\b`)
ibanPattern = regexp.MustCompile(`(?i)\b[A-Z]{2}\d{2}[\s]?(?:\d{4}[\s]?){3,5}\d{1,4}\b`)
uuidPattern = regexp.MustCompile(`(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b`)
namePattern = regexp.MustCompile(`\b(?:Herr|Frau|Hr\.|Fr\.)\s+[A-ZÄÖÜ][a-zäöüß]+(?:\s+[A-ZÄÖÜ][a-zäöüß]+)?\b`)
)
// DefaultPIIPatterns returns the default set of PII patterns.
func DefaultPIIPatterns() []*PIIPattern {
return []*PIIPattern{
{Name: "email", Pattern: emailPattern, Replacement: "[EMAIL_REDACTED]"},
{Name: "ip_v4", Pattern: ipv4Pattern, Replacement: "[IP_REDACTED]"},
{Name: "ip_v6", Pattern: ipv6Pattern, Replacement: "[IP_REDACTED]"},
{Name: "phone", Pattern: phonePattern, Replacement: "[PHONE_REDACTED]"},
}
}
// AllPIIPatterns returns all available PII patterns.
func AllPIIPatterns() []*PIIPattern {
return []*PIIPattern{
{Name: "email", Pattern: emailPattern, Replacement: "[EMAIL_REDACTED]"},
{Name: "ip_v4", Pattern: ipv4Pattern, Replacement: "[IP_REDACTED]"},
{Name: "ip_v6", Pattern: ipv6Pattern, Replacement: "[IP_REDACTED]"},
{Name: "phone", Pattern: phonePattern, Replacement: "[PHONE_REDACTED]"},
{Name: "iban", Pattern: ibanPattern, Replacement: "[IBAN_REDACTED]"},
{Name: "uuid", Pattern: uuidPattern, Replacement: "[UUID_REDACTED]"},
{Name: "name", Pattern: namePattern, Replacement: "[NAME_REDACTED]"},
}
}
// NewPIIRedactor creates a new PII redactor with the given patterns.
func NewPIIRedactor(patterns []*PIIPattern) *PIIRedactor {
if patterns == nil {
patterns = DefaultPIIPatterns()
}
return &PIIRedactor{patterns: patterns}
}
// NewDefaultPIIRedactor creates a PII redactor with default patterns.
func NewDefaultPIIRedactor() *PIIRedactor {
return NewPIIRedactor(DefaultPIIPatterns())
}
// Redact removes PII from the given text.
func (r *PIIRedactor) Redact(text string) string {
if text == "" {
return text
}
result := text
for _, pattern := range r.patterns {
result = pattern.Pattern.ReplaceAllString(result, pattern.Replacement)
}
return result
}
// ContainsPII checks if the text contains any PII.
func (r *PIIRedactor) ContainsPII(text string) bool {
if text == "" {
return false
}
for _, pattern := range r.patterns {
if pattern.Pattern.MatchString(text) {
return true
}
}
return false
}
// PIIFinding represents a found PII instance.
type PIIFinding struct {
Type string
Match string
Start int
End int
}
// FindPII finds all PII in the text.
func (r *PIIRedactor) FindPII(text string) []PIIFinding {
if text == "" {
return nil
}
var findings []PIIFinding
for _, pattern := range r.patterns {
matches := pattern.Pattern.FindAllStringIndex(text, -1)
for _, match := range matches {
findings = append(findings, PIIFinding{
Type: pattern.Name,
Match: text[match[0]:match[1]],
Start: match[0],
End: match[1],
})
}
}
return findings
}
// Default module-level redactor
var defaultRedactor = NewDefaultPIIRedactor()
// RedactPII is a convenience function that uses the default redactor.
func RedactPII(text string) string {
return defaultRedactor.Redact(text)
}
// ContainsPIIDefault checks if text contains PII using default patterns.
func ContainsPIIDefault(text string) bool {
return defaultRedactor.ContainsPII(text)
}
// RedactMap redacts PII from all string values in a map.
func RedactMap(data map[string]interface{}) map[string]interface{} {
result := make(map[string]interface{})
for key, value := range data {
switch v := value.(type) {
case string:
result[key] = RedactPII(v)
case map[string]interface{}:
result[key] = RedactMap(v)
case []interface{}:
result[key] = redactSlice(v)
default:
result[key] = v
}
}
return result
}
func redactSlice(data []interface{}) []interface{} {
result := make([]interface{}, len(data))
for i, value := range data {
switch v := value.(type) {
case string:
result[i] = RedactPII(v)
case map[string]interface{}:
result[i] = RedactMap(v)
case []interface{}:
result[i] = redactSlice(v)
default:
result[i] = v
}
}
return result
}
// SafeLogString creates a safe-to-log version of sensitive data.
// Use this for logging user-related information.
func SafeLogString(format string, args ...interface{}) string {
// Convert args to strings and redact
safeArgs := make([]interface{}, len(args))
for i, arg := range args {
switch v := arg.(type) {
case string:
safeArgs[i] = RedactPII(v)
case error:
safeArgs[i] = RedactPII(v.Error())
default:
safeArgs[i] = arg
}
}
// Note: We can't use fmt.Sprintf here due to the variadic nature
// Instead, we redact the result
result := format
for _, arg := range safeArgs {
if s, ok := arg.(string); ok {
result = strings.Replace(result, "%s", s, 1)
result = strings.Replace(result, "%v", s, 1)
}
}
return RedactPII(result)
}
@@ -0,0 +1,228 @@
package middleware
import (
"testing"
)
func TestPIIRedactor_RedactsEmail(t *testing.T) {
redactor := NewDefaultPIIRedactor()
text := "User test@example.com logged in"
result := redactor.Redact(text)
if result == text {
t.Error("Email should have been redacted")
}
if result != "User [EMAIL_REDACTED] logged in" {
t.Errorf("Unexpected result: %s", result)
}
}
func TestPIIRedactor_RedactsIPv4(t *testing.T) {
redactor := NewDefaultPIIRedactor()
text := "Request from 192.168.1.100"
result := redactor.Redact(text)
if result == text {
t.Error("IP should have been redacted")
}
if result != "Request from [IP_REDACTED]" {
t.Errorf("Unexpected result: %s", result)
}
}
func TestPIIRedactor_RedactsGermanPhone(t *testing.T) {
redactor := NewDefaultPIIRedactor()
tests := []struct {
input string
expected string
}{
{"+49 30 12345678", "[PHONE_REDACTED]"},
{"0049 30 12345678", "[PHONE_REDACTED]"},
{"030 12345678", "[PHONE_REDACTED]"},
}
for _, tt := range tests {
result := redactor.Redact(tt.input)
if result != tt.expected {
t.Errorf("For input %q: expected %q, got %q", tt.input, tt.expected, result)
}
}
}
func TestPIIRedactor_RedactsMultiplePII(t *testing.T) {
redactor := NewDefaultPIIRedactor()
text := "User test@example.com from 10.0.0.1"
result := redactor.Redact(text)
if result != "User [EMAIL_REDACTED] from [IP_REDACTED]" {
t.Errorf("Unexpected result: %s", result)
}
}
func TestPIIRedactor_PreservesNonPIIText(t *testing.T) {
redactor := NewDefaultPIIRedactor()
text := "User logged in successfully"
result := redactor.Redact(text)
if result != text {
t.Errorf("Text should be unchanged: got %s", result)
}
}
func TestPIIRedactor_EmptyString(t *testing.T) {
redactor := NewDefaultPIIRedactor()
result := redactor.Redact("")
if result != "" {
t.Error("Empty string should remain empty")
}
}
func TestContainsPII(t *testing.T) {
redactor := NewDefaultPIIRedactor()
tests := []struct {
input string
expected bool
}{
{"test@example.com", true},
{"192.168.1.1", true},
{"+49 30 12345678", true},
{"Hello World", false},
{"", false},
}
for _, tt := range tests {
result := redactor.ContainsPII(tt.input)
if result != tt.expected {
t.Errorf("For input %q: expected %v, got %v", tt.input, tt.expected, result)
}
}
}
func TestFindPII(t *testing.T) {
redactor := NewDefaultPIIRedactor()
text := "Email: test@example.com, IP: 10.0.0.1"
findings := redactor.FindPII(text)
if len(findings) != 2 {
t.Errorf("Expected 2 findings, got %d", len(findings))
}
hasEmail := false
hasIP := false
for _, f := range findings {
if f.Type == "email" {
hasEmail = true
}
if f.Type == "ip_v4" {
hasIP = true
}
}
if !hasEmail {
t.Error("Should have found email")
}
if !hasIP {
t.Error("Should have found IP")
}
}
func TestRedactPII_GlobalFunction(t *testing.T) {
text := "User test@example.com logged in"
result := RedactPII(text)
if result == text {
t.Error("Email should have been redacted")
}
}
func TestContainsPIIDefault(t *testing.T) {
if !ContainsPIIDefault("test@example.com") {
t.Error("Should detect email as PII")
}
if ContainsPIIDefault("Hello World") {
t.Error("Should not detect non-PII text")
}
}
func TestRedactMap(t *testing.T) {
data := map[string]interface{}{
"email": "test@example.com",
"message": "Hello World",
"nested": map[string]interface{}{
"ip": "192.168.1.1",
},
}
result := RedactMap(data)
if result["email"] != "[EMAIL_REDACTED]" {
t.Errorf("Email should be redacted: %v", result["email"])
}
if result["message"] != "Hello World" {
t.Errorf("Non-PII should be unchanged: %v", result["message"])
}
nested := result["nested"].(map[string]interface{})
if nested["ip"] != "[IP_REDACTED]" {
t.Errorf("Nested IP should be redacted: %v", nested["ip"])
}
}
func TestAllPIIPatterns(t *testing.T) {
patterns := AllPIIPatterns()
if len(patterns) == 0 {
t.Error("Should have PII patterns")
}
// Check that we have the expected patterns
expectedNames := []string{"email", "ip_v4", "ip_v6", "phone", "iban", "uuid", "name"}
nameMap := make(map[string]bool)
for _, p := range patterns {
nameMap[p.Name] = true
}
for _, name := range expectedNames {
if !nameMap[name] {
t.Errorf("Missing expected pattern: %s", name)
}
}
}
func TestDefaultPIIPatterns(t *testing.T) {
patterns := DefaultPIIPatterns()
if len(patterns) != 4 {
t.Errorf("Expected 4 default patterns, got %d", len(patterns))
}
}
func TestIBANRedaction(t *testing.T) {
redactor := NewPIIRedactor(AllPIIPatterns())
text := "IBAN: DE89 3704 0044 0532 0130 00"
result := redactor.Redact(text)
if result == text {
t.Error("IBAN should have been redacted")
}
}
func TestUUIDRedaction(t *testing.T) {
redactor := NewPIIRedactor(AllPIIPatterns())
text := "User ID: a0000000-0000-0000-0000-000000000001"
result := redactor.Redact(text)
if result == text {
t.Error("UUID should have been redacted")
}
}
@@ -0,0 +1,75 @@
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
const (
// RequestIDHeader is the primary header for request IDs
RequestIDHeader = "X-Request-ID"
// CorrelationIDHeader is an alternative header for distributed tracing
CorrelationIDHeader = "X-Correlation-ID"
// RequestIDKey is the context key for storing the request ID
RequestIDKey = "request_id"
)
// RequestID returns a middleware that generates and propagates request IDs.
//
// For each incoming request:
// 1. Check for existing X-Request-ID or X-Correlation-ID header
// 2. If not present, generate a new UUID
// 3. Store in Gin context for use by handlers and logging
// 4. Add to response headers
//
// Usage:
//
// r.Use(middleware.RequestID())
//
// func handler(c *gin.Context) {
// requestID := middleware.GetRequestID(c)
// log.Printf("[%s] Processing request", requestID)
// }
func RequestID() gin.HandlerFunc {
return func(c *gin.Context) {
// Try to get existing request ID from headers
requestID := c.GetHeader(RequestIDHeader)
if requestID == "" {
requestID = c.GetHeader(CorrelationIDHeader)
}
// Generate new ID if not provided
if requestID == "" {
requestID = uuid.New().String()
}
// Store in context for handlers and logging
c.Set(RequestIDKey, requestID)
// Add to response headers
c.Header(RequestIDHeader, requestID)
c.Header(CorrelationIDHeader, requestID)
c.Next()
}
}
// GetRequestID retrieves the request ID from the Gin context.
// Returns empty string if no request ID is set.
//
// Usage:
//
// requestID := middleware.GetRequestID(c)
func GetRequestID(c *gin.Context) string {
if id, exists := c.Get(RequestIDKey); exists {
if idStr, ok := id.(string); ok {
return idStr
}
}
return ""
}
// RequestIDFromContext is an alias for GetRequestID for API compatibility.
func RequestIDFromContext(c *gin.Context) string {
return GetRequestID(c)
}
@@ -0,0 +1,152 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func TestRequestID_GeneratesNewID(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(RequestID())
router.GET("/test", func(c *gin.Context) {
requestID := GetRequestID(c)
if requestID == "" {
t.Error("Expected request ID to be set")
}
c.JSON(http.StatusOK, gin.H{"request_id": requestID})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
// Check response header
requestID := w.Header().Get(RequestIDHeader)
if requestID == "" {
t.Error("Expected X-Request-ID header in response")
}
// Check correlation ID header
correlationID := w.Header().Get(CorrelationIDHeader)
if correlationID == "" {
t.Error("Expected X-Correlation-ID header in response")
}
if requestID != correlationID {
t.Error("X-Request-ID and X-Correlation-ID should match")
}
}
func TestRequestID_PropagatesExistingID(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(RequestID())
customID := "custom-request-id-12345"
router.GET("/test", func(c *gin.Context) {
requestID := GetRequestID(c)
if requestID != customID {
t.Errorf("Expected request ID %s, got %s", customID, requestID)
}
c.JSON(http.StatusOK, gin.H{"request_id": requestID})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set(RequestIDHeader, customID)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
responseID := w.Header().Get(RequestIDHeader)
if responseID != customID {
t.Errorf("Expected response header %s, got %s", customID, responseID)
}
}
func TestRequestID_PropagatesCorrelationID(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(RequestID())
correlationID := "correlation-id-67890"
router.GET("/test", func(c *gin.Context) {
requestID := GetRequestID(c)
if requestID != correlationID {
t.Errorf("Expected request ID %s, got %s", correlationID, requestID)
}
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set(CorrelationIDHeader, correlationID)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
// Both headers should be set with the correlation ID
if w.Header().Get(RequestIDHeader) != correlationID {
t.Error("X-Request-ID should match X-Correlation-ID")
}
}
func TestGetRequestID_ReturnsEmptyWhenNotSet(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
// No RequestID middleware
router.GET("/test", func(c *gin.Context) {
requestID := GetRequestID(c)
if requestID != "" {
t.Errorf("Expected empty request ID, got %s", requestID)
}
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
}
func TestRequestIDFromContext_IsAliasForGetRequestID(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(RequestID())
router.GET("/test", func(c *gin.Context) {
id1 := GetRequestID(c)
id2 := RequestIDFromContext(c)
if id1 != id2 {
t.Errorf("GetRequestID and RequestIDFromContext should return same value")
}
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
}
@@ -0,0 +1,167 @@
package middleware
import (
"os"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
// SecurityHeadersConfig holds configuration for security headers.
type SecurityHeadersConfig struct {
// X-Content-Type-Options
ContentTypeOptions string
// X-Frame-Options
FrameOptions string
// X-XSS-Protection (legacy but useful for older browsers)
XSSProtection string
// Strict-Transport-Security
HSTSEnabled bool
HSTSMaxAge int
HSTSIncludeSubdomains bool
HSTSPreload bool
// Content-Security-Policy
CSPEnabled bool
CSPPolicy string
// Referrer-Policy
ReferrerPolicy string
// Permissions-Policy
PermissionsPolicy string
// Cross-Origin headers
CrossOriginOpenerPolicy string
CrossOriginResourcePolicy string
// Development mode (relaxes some restrictions)
DevelopmentMode bool
// Excluded paths (e.g., health checks)
ExcludedPaths []string
}
// DefaultSecurityHeadersConfig returns sensible default configuration.
func DefaultSecurityHeadersConfig() SecurityHeadersConfig {
env := os.Getenv("ENVIRONMENT")
isDev := env == "" || strings.ToLower(env) == "development" || strings.ToLower(env) == "dev"
return SecurityHeadersConfig{
ContentTypeOptions: "nosniff",
FrameOptions: "DENY",
XSSProtection: "1; mode=block",
HSTSEnabled: true,
HSTSMaxAge: 31536000, // 1 year
HSTSIncludeSubdomains: true,
HSTSPreload: false,
CSPEnabled: true,
CSPPolicy: getDefaultCSP(isDev),
ReferrerPolicy: "strict-origin-when-cross-origin",
PermissionsPolicy: "geolocation=(), microphone=(), camera=()",
CrossOriginOpenerPolicy: "same-origin",
CrossOriginResourcePolicy: "same-origin",
DevelopmentMode: isDev,
ExcludedPaths: []string{"/health", "/metrics", "/api/v1/health"},
}
}
// getDefaultCSP returns a sensible default CSP for the environment.
func getDefaultCSP(isDevelopment bool) string {
if isDevelopment {
return "default-src 'self' localhost:* ws://localhost:*; " +
"script-src 'self' 'unsafe-inline' 'unsafe-eval'; " +
"style-src 'self' 'unsafe-inline'; " +
"img-src 'self' data: https: blob:; " +
"font-src 'self' data:; " +
"connect-src 'self' localhost:* ws://localhost:* https:; " +
"frame-ancestors 'self'"
}
return "default-src 'self'; " +
"script-src 'self' 'unsafe-inline'; " +
"style-src 'self' 'unsafe-inline'; " +
"img-src 'self' data: https:; " +
"font-src 'self' data:; " +
"connect-src 'self' https://breakpilot.app https://*.breakpilot.app; " +
"frame-ancestors 'none'"
}
// buildHSTSHeader builds the Strict-Transport-Security header value.
func (c *SecurityHeadersConfig) buildHSTSHeader() string {
parts := []string{"max-age=" + strconv.Itoa(c.HSTSMaxAge)}
if c.HSTSIncludeSubdomains {
parts = append(parts, "includeSubDomains")
}
if c.HSTSPreload {
parts = append(parts, "preload")
}
return strings.Join(parts, "; ")
}
// isExcludedPath checks if the path should be excluded from security headers.
func (c *SecurityHeadersConfig) isExcludedPath(path string) bool {
for _, excluded := range c.ExcludedPaths {
if path == excluded {
return true
}
}
return false
}
// SecurityHeaders returns a middleware that adds security headers to all responses.
//
// Usage:
//
// r.Use(middleware.SecurityHeaders())
//
// // Or with custom config:
// config := middleware.DefaultSecurityHeadersConfig()
// config.CSPPolicy = "default-src 'self'"
// r.Use(middleware.SecurityHeadersWithConfig(config))
func SecurityHeaders() gin.HandlerFunc {
return SecurityHeadersWithConfig(DefaultSecurityHeadersConfig())
}
// SecurityHeadersWithConfig returns a security headers middleware with custom configuration.
func SecurityHeadersWithConfig(config SecurityHeadersConfig) gin.HandlerFunc {
return func(c *gin.Context) {
// Skip for excluded paths
if config.isExcludedPath(c.Request.URL.Path) {
c.Next()
return
}
// Always add these headers
c.Header("X-Content-Type-Options", config.ContentTypeOptions)
c.Header("X-Frame-Options", config.FrameOptions)
c.Header("X-XSS-Protection", config.XSSProtection)
c.Header("Referrer-Policy", config.ReferrerPolicy)
// HSTS (only in production or if explicitly enabled)
if config.HSTSEnabled && !config.DevelopmentMode {
c.Header("Strict-Transport-Security", config.buildHSTSHeader())
}
// Content-Security-Policy
if config.CSPEnabled && config.CSPPolicy != "" {
c.Header("Content-Security-Policy", config.CSPPolicy)
}
// Permissions-Policy
if config.PermissionsPolicy != "" {
c.Header("Permissions-Policy", config.PermissionsPolicy)
}
// Cross-Origin headers (only in production)
if !config.DevelopmentMode {
c.Header("Cross-Origin-Opener-Policy", config.CrossOriginOpenerPolicy)
c.Header("Cross-Origin-Resource-Policy", config.CrossOriginResourcePolicy)
}
c.Next()
}
}
@@ -0,0 +1,377 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func TestSecurityHeaders_AddsBasicHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultSecurityHeadersConfig()
config.DevelopmentMode = true // Skip HSTS and cross-origin headers
router.Use(SecurityHeadersWithConfig(config))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
// Check basic security headers
tests := []struct {
header string
expected string
}{
{"X-Content-Type-Options", "nosniff"},
{"X-Frame-Options", "DENY"},
{"X-XSS-Protection", "1; mode=block"},
{"Referrer-Policy", "strict-origin-when-cross-origin"},
}
for _, tt := range tests {
value := w.Header().Get(tt.header)
if value != tt.expected {
t.Errorf("Header %s: expected %q, got %q", tt.header, tt.expected, value)
}
}
}
func TestSecurityHeaders_HSTSNotAddedInDevelopment(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultSecurityHeadersConfig()
config.DevelopmentMode = true
config.HSTSEnabled = true
router.Use(SecurityHeadersWithConfig(config))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
hstsHeader := w.Header().Get("Strict-Transport-Security")
if hstsHeader != "" {
t.Errorf("HSTS should not be set in development mode, got: %s", hstsHeader)
}
}
func TestSecurityHeaders_HSTSAddedInProduction(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultSecurityHeadersConfig()
config.DevelopmentMode = false
config.HSTSEnabled = true
config.HSTSMaxAge = 31536000
config.HSTSIncludeSubdomains = true
router.Use(SecurityHeadersWithConfig(config))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
hstsHeader := w.Header().Get("Strict-Transport-Security")
if hstsHeader == "" {
t.Error("HSTS should be set in production mode")
}
// Check that it contains max-age
if hstsHeader != "max-age=31536000; includeSubDomains" {
t.Errorf("Unexpected HSTS value: %s", hstsHeader)
}
}
func TestSecurityHeaders_HSTSWithPreload(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultSecurityHeadersConfig()
config.DevelopmentMode = false
config.HSTSEnabled = true
config.HSTSMaxAge = 31536000
config.HSTSIncludeSubdomains = true
config.HSTSPreload = true
router.Use(SecurityHeadersWithConfig(config))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
hstsHeader := w.Header().Get("Strict-Transport-Security")
expected := "max-age=31536000; includeSubDomains; preload"
if hstsHeader != expected {
t.Errorf("Expected HSTS %q, got %q", expected, hstsHeader)
}
}
func TestSecurityHeaders_CSPHeader(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultSecurityHeadersConfig()
config.CSPEnabled = true
config.CSPPolicy = "default-src 'self'"
router.Use(SecurityHeadersWithConfig(config))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
cspHeader := w.Header().Get("Content-Security-Policy")
if cspHeader != "default-src 'self'" {
t.Errorf("Expected CSP %q, got %q", "default-src 'self'", cspHeader)
}
}
func TestSecurityHeaders_NoCSPWhenDisabled(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultSecurityHeadersConfig()
config.CSPEnabled = false
router.Use(SecurityHeadersWithConfig(config))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
cspHeader := w.Header().Get("Content-Security-Policy")
if cspHeader != "" {
t.Errorf("CSP should not be set when disabled, got: %s", cspHeader)
}
}
func TestSecurityHeaders_ExcludedPaths(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultSecurityHeadersConfig()
config.ExcludedPaths = []string{"/health", "/metrics"}
router.Use(SecurityHeadersWithConfig(config))
router.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "healthy"})
})
router.GET("/api", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
// Test excluded path
req := httptest.NewRequest(http.MethodGet, "/health", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Header().Get("X-Content-Type-Options") != "" {
t.Error("Security headers should not be set for excluded paths")
}
// Test non-excluded path
req = httptest.NewRequest(http.MethodGet, "/api", nil)
w = httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Header().Get("X-Content-Type-Options") != "nosniff" {
t.Error("Security headers should be set for non-excluded paths")
}
}
func TestSecurityHeaders_CrossOriginInProduction(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultSecurityHeadersConfig()
config.DevelopmentMode = false
config.CrossOriginOpenerPolicy = "same-origin"
config.CrossOriginResourcePolicy = "same-origin"
router.Use(SecurityHeadersWithConfig(config))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
coopHeader := w.Header().Get("Cross-Origin-Opener-Policy")
if coopHeader != "same-origin" {
t.Errorf("Expected COOP %q, got %q", "same-origin", coopHeader)
}
corpHeader := w.Header().Get("Cross-Origin-Resource-Policy")
if corpHeader != "same-origin" {
t.Errorf("Expected CORP %q, got %q", "same-origin", corpHeader)
}
}
func TestSecurityHeaders_NoCrossOriginInDevelopment(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultSecurityHeadersConfig()
config.DevelopmentMode = true
config.CrossOriginOpenerPolicy = "same-origin"
config.CrossOriginResourcePolicy = "same-origin"
router.Use(SecurityHeadersWithConfig(config))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Header().Get("Cross-Origin-Opener-Policy") != "" {
t.Error("COOP should not be set in development mode")
}
if w.Header().Get("Cross-Origin-Resource-Policy") != "" {
t.Error("CORP should not be set in development mode")
}
}
func TestSecurityHeaders_PermissionsPolicy(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
config := DefaultSecurityHeadersConfig()
config.PermissionsPolicy = "geolocation=(), microphone=()"
router.Use(SecurityHeadersWithConfig(config))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
ppHeader := w.Header().Get("Permissions-Policy")
if ppHeader != "geolocation=(), microphone=()" {
t.Errorf("Expected Permissions-Policy %q, got %q", "geolocation=(), microphone=()", ppHeader)
}
}
func TestSecurityHeaders_DefaultMiddleware(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
// Use the default middleware function
router.Use(SecurityHeaders())
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// Should at least have the basic headers
if w.Header().Get("X-Content-Type-Options") != "nosniff" {
t.Error("Default middleware should set X-Content-Type-Options")
}
}
func TestBuildHSTSHeader(t *testing.T) {
tests := []struct {
name string
config SecurityHeadersConfig
expected string
}{
{
name: "basic HSTS",
config: SecurityHeadersConfig{
HSTSMaxAge: 31536000,
HSTSIncludeSubdomains: false,
HSTSPreload: false,
},
expected: "max-age=31536000",
},
{
name: "HSTS with subdomains",
config: SecurityHeadersConfig{
HSTSMaxAge: 31536000,
HSTSIncludeSubdomains: true,
HSTSPreload: false,
},
expected: "max-age=31536000; includeSubDomains",
},
{
name: "HSTS with preload",
config: SecurityHeadersConfig{
HSTSMaxAge: 31536000,
HSTSIncludeSubdomains: true,
HSTSPreload: true,
},
expected: "max-age=31536000; includeSubDomains; preload",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.config.buildHSTSHeader()
if result != tt.expected {
t.Errorf("Expected %q, got %q", tt.expected, result)
}
})
}
}
func TestIsExcludedPath(t *testing.T) {
config := SecurityHeadersConfig{
ExcludedPaths: []string{"/health", "/metrics", "/api/v1/health"},
}
tests := []struct {
path string
excluded bool
}{
{"/health", true},
{"/metrics", true},
{"/api/v1/health", true},
{"/api", false},
{"/health/check", false},
{"/", false},
}
for _, tt := range tests {
result := config.isExcludedPath(tt.path)
if result != tt.excluded {
t.Errorf("Path %s: expected excluded=%v, got %v", tt.path, tt.excluded, result)
}
}
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,505 @@
package services
import (
"context"
"fmt"
"time"
"github.com/breakpilot/consent-service/internal/database"
"github.com/breakpilot/consent-service/internal/models"
"github.com/breakpilot/consent-service/internal/services/matrix"
"github.com/google/uuid"
)
// AttendanceService handles attendance tracking and notifications
type AttendanceService struct {
db *database.DB
matrix *matrix.MatrixService
}
// NewAttendanceService creates a new attendance service
func NewAttendanceService(db *database.DB, matrixService *matrix.MatrixService) *AttendanceService {
return &AttendanceService{
db: db,
matrix: matrixService,
}
}
// ========================================
// Attendance Recording
// ========================================
// RecordAttendance records a student's attendance for a specific lesson
func (s *AttendanceService) RecordAttendance(ctx context.Context, req models.RecordAttendanceRequest, recordedByUserID uuid.UUID) (*models.AttendanceRecord, error) {
studentID, err := uuid.Parse(req.StudentID)
if err != nil {
return nil, fmt.Errorf("invalid student ID: %w", err)
}
slotID, err := uuid.Parse(req.SlotID)
if err != nil {
return nil, fmt.Errorf("invalid slot ID: %w", err)
}
date, err := time.Parse("2006-01-02", req.Date)
if err != nil {
return nil, fmt.Errorf("invalid date format: %w", err)
}
record := &models.AttendanceRecord{
ID: uuid.New(),
StudentID: studentID,
Date: date,
SlotID: slotID,
Status: req.Status,
RecordedBy: recordedByUserID,
Note: req.Note,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
query := `
INSERT INTO attendance_records (id, student_id, date, slot_id, status, recorded_by, note, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
ON CONFLICT (student_id, date, slot_id)
DO UPDATE SET status = EXCLUDED.status, note = EXCLUDED.note, updated_at = EXCLUDED.updated_at
RETURNING id`
err = s.db.Pool.QueryRow(ctx, query,
record.ID, record.StudentID, record.Date, record.SlotID,
record.Status, record.RecordedBy, record.Note, record.CreatedAt, record.UpdatedAt,
).Scan(&record.ID)
if err != nil {
return nil, fmt.Errorf("failed to record attendance: %w", err)
}
// If student is absent, send notification to parents
if record.Status == models.AttendanceAbsent || record.Status == models.AttendancePending {
go s.notifyParentsOfAbsence(context.Background(), record)
}
return record, nil
}
// RecordBulkAttendance records attendance for multiple students at once
func (s *AttendanceService) RecordBulkAttendance(ctx context.Context, classID uuid.UUID, date string, slotID uuid.UUID, records []struct {
StudentID string
Status string
Note *string
}, recordedByUserID uuid.UUID) error {
parsedDate, err := time.Parse("2006-01-02", date)
if err != nil {
return fmt.Errorf("invalid date format: %w", err)
}
for _, rec := range records {
studentID, err := uuid.Parse(rec.StudentID)
if err != nil {
continue
}
query := `
INSERT INTO attendance_records (id, student_id, date, slot_id, status, recorded_by, note, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), NOW())
ON CONFLICT (student_id, date, slot_id)
DO UPDATE SET status = EXCLUDED.status, note = EXCLUDED.note, updated_at = NOW()`
_, err = s.db.Pool.Exec(ctx, query,
uuid.New(), studentID, parsedDate, slotID, rec.Status, recordedByUserID, rec.Note,
)
if err != nil {
return fmt.Errorf("failed to record attendance for student %s: %w", rec.StudentID, err)
}
// Notify parents if absent
if rec.Status == models.AttendanceAbsent || rec.Status == models.AttendancePending {
go s.notifyParentsOfAbsenceByStudentID(context.Background(), studentID, parsedDate, slotID)
}
}
return nil
}
// GetAttendanceByClass gets attendance records for a class on a specific date
func (s *AttendanceService) GetAttendanceByClass(ctx context.Context, classID uuid.UUID, date string) (*models.ClassAttendanceOverview, error) {
parsedDate, err := time.Parse("2006-01-02", date)
if err != nil {
return nil, fmt.Errorf("invalid date format: %w", err)
}
// Get class info
classQuery := `SELECT id, school_id, school_year_id, name, grade, section, room, is_active FROM classes WHERE id = $1`
class := &models.Class{}
err = s.db.Pool.QueryRow(ctx, classQuery, classID).Scan(
&class.ID, &class.SchoolID, &class.SchoolYearID, &class.Name,
&class.Grade, &class.Section, &class.Room, &class.IsActive,
)
if err != nil {
return nil, fmt.Errorf("failed to get class: %w", err)
}
// Get total students
var totalStudents int
err = s.db.Pool.QueryRow(ctx, `SELECT COUNT(*) FROM students WHERE class_id = $1 AND is_active = true`, classID).Scan(&totalStudents)
if err != nil {
return nil, fmt.Errorf("failed to count students: %w", err)
}
// Get attendance records for the date
recordsQuery := `
SELECT ar.id, ar.student_id, ar.date, ar.slot_id, ar.status, ar.recorded_by, ar.note, ar.created_at, ar.updated_at
FROM attendance_records ar
JOIN students s ON ar.student_id = s.id
WHERE s.class_id = $1 AND ar.date = $2
ORDER BY ar.slot_id`
rows, err := s.db.Pool.Query(ctx, recordsQuery, classID, parsedDate)
if err != nil {
return nil, fmt.Errorf("failed to get attendance records: %w", err)
}
defer rows.Close()
var records []models.AttendanceRecord
presentCount := 0
absentCount := 0
lateCount := 0
seenStudents := make(map[uuid.UUID]bool)
for rows.Next() {
var record models.AttendanceRecord
err := rows.Scan(
&record.ID, &record.StudentID, &record.Date, &record.SlotID,
&record.Status, &record.RecordedBy, &record.Note, &record.CreatedAt, &record.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan attendance record: %w", err)
}
records = append(records, record)
// Count unique students for summary (use first slot's status)
if !seenStudents[record.StudentID] {
seenStudents[record.StudentID] = true
switch record.Status {
case models.AttendancePresent:
presentCount++
case models.AttendanceAbsent, models.AttendanceAbsentExcused, models.AttendanceAbsentUnexcused, models.AttendancePending:
absentCount++
case models.AttendanceLate, models.AttendanceLateExcused:
lateCount++
}
}
}
return &models.ClassAttendanceOverview{
Class: *class,
Date: parsedDate,
TotalStudents: totalStudents,
PresentCount: presentCount,
AbsentCount: absentCount,
LateCount: lateCount,
Records: records,
}, nil
}
// GetStudentAttendance gets attendance history for a student
func (s *AttendanceService) GetStudentAttendance(ctx context.Context, studentID uuid.UUID, startDate, endDate time.Time) ([]models.AttendanceRecord, error) {
query := `
SELECT id, student_id, timetable_entry_id, date, slot_id, status, recorded_by, note, created_at, updated_at
FROM attendance_records
WHERE student_id = $1 AND date >= $2 AND date <= $3
ORDER BY date DESC, slot_id`
rows, err := s.db.Pool.Query(ctx, query, studentID, startDate, endDate)
if err != nil {
return nil, fmt.Errorf("failed to get student attendance: %w", err)
}
defer rows.Close()
var records []models.AttendanceRecord
for rows.Next() {
var record models.AttendanceRecord
err := rows.Scan(
&record.ID, &record.StudentID, &record.TimetableEntryID, &record.Date,
&record.SlotID, &record.Status, &record.RecordedBy, &record.Note,
&record.CreatedAt, &record.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan attendance record: %w", err)
}
records = append(records, record)
}
return records, nil
}
// ========================================
// Absence Reports (Parent-initiated)
// ========================================
// ReportAbsence allows parents to report a student's absence
func (s *AttendanceService) ReportAbsence(ctx context.Context, req models.ReportAbsenceRequest, reportedByUserID uuid.UUID) (*models.AbsenceReport, error) {
studentID, err := uuid.Parse(req.StudentID)
if err != nil {
return nil, fmt.Errorf("invalid student ID: %w", err)
}
startDate, err := time.Parse("2006-01-02", req.StartDate)
if err != nil {
return nil, fmt.Errorf("invalid start date format: %w", err)
}
endDate, err := time.Parse("2006-01-02", req.EndDate)
if err != nil {
return nil, fmt.Errorf("invalid end date format: %w", err)
}
report := &models.AbsenceReport{
ID: uuid.New(),
StudentID: studentID,
StartDate: startDate,
EndDate: endDate,
Reason: req.Reason,
ReasonCategory: req.ReasonCategory,
Status: "reported",
ReportedBy: reportedByUserID,
ReportedAt: time.Now(),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
query := `
INSERT INTO absence_reports (id, student_id, start_date, end_date, reason, reason_category, status, reported_by, reported_at, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING id`
err = s.db.Pool.QueryRow(ctx, query,
report.ID, report.StudentID, report.StartDate, report.EndDate,
report.Reason, report.ReasonCategory, report.Status,
report.ReportedBy, report.ReportedAt, report.CreatedAt, report.UpdatedAt,
).Scan(&report.ID)
if err != nil {
return nil, fmt.Errorf("failed to create absence report: %w", err)
}
return report, nil
}
// ConfirmAbsence allows teachers to confirm/excuse an absence
func (s *AttendanceService) ConfirmAbsence(ctx context.Context, reportID uuid.UUID, confirmedByUserID uuid.UUID, status string) error {
query := `
UPDATE absence_reports
SET status = $1, confirmed_by = $2, confirmed_at = NOW(), updated_at = NOW()
WHERE id = $3`
result, err := s.db.Pool.Exec(ctx, query, status, confirmedByUserID, reportID)
if err != nil {
return fmt.Errorf("failed to confirm absence: %w", err)
}
if result.RowsAffected() == 0 {
return fmt.Errorf("absence report not found")
}
return nil
}
// GetAbsenceReports gets absence reports for a student
func (s *AttendanceService) GetAbsenceReports(ctx context.Context, studentID uuid.UUID) ([]models.AbsenceReport, error) {
query := `
SELECT id, student_id, start_date, end_date, reason, reason_category, status, reported_by, reported_at, confirmed_by, confirmed_at, medical_certificate, certificate_uploaded, matrix_notification_sent, email_notification_sent, created_at, updated_at
FROM absence_reports
WHERE student_id = $1
ORDER BY start_date DESC`
rows, err := s.db.Pool.Query(ctx, query, studentID)
if err != nil {
return nil, fmt.Errorf("failed to get absence reports: %w", err)
}
defer rows.Close()
var reports []models.AbsenceReport
for rows.Next() {
var report models.AbsenceReport
err := rows.Scan(
&report.ID, &report.StudentID, &report.StartDate, &report.EndDate,
&report.Reason, &report.ReasonCategory, &report.Status,
&report.ReportedBy, &report.ReportedAt, &report.ConfirmedBy, &report.ConfirmedAt,
&report.MedicalCertificate, &report.CertificateUploaded,
&report.MatrixNotificationSent, &report.EmailNotificationSent,
&report.CreatedAt, &report.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan absence report: %w", err)
}
reports = append(reports, report)
}
return reports, nil
}
// GetPendingAbsenceReports gets all unconfirmed absence reports for a class
func (s *AttendanceService) GetPendingAbsenceReports(ctx context.Context, classID uuid.UUID) ([]models.AbsenceReport, error) {
query := `
SELECT ar.id, ar.student_id, ar.start_date, ar.end_date, ar.reason, ar.reason_category, ar.status, ar.reported_by, ar.reported_at, ar.confirmed_by, ar.confirmed_at, ar.medical_certificate, ar.certificate_uploaded, ar.matrix_notification_sent, ar.email_notification_sent, ar.created_at, ar.updated_at
FROM absence_reports ar
JOIN students s ON ar.student_id = s.id
WHERE s.class_id = $1 AND ar.status = 'reported'
ORDER BY ar.start_date DESC`
rows, err := s.db.Pool.Query(ctx, query, classID)
if err != nil {
return nil, fmt.Errorf("failed to get pending absence reports: %w", err)
}
defer rows.Close()
var reports []models.AbsenceReport
for rows.Next() {
var report models.AbsenceReport
err := rows.Scan(
&report.ID, &report.StudentID, &report.StartDate, &report.EndDate,
&report.Reason, &report.ReasonCategory, &report.Status,
&report.ReportedBy, &report.ReportedAt, &report.ConfirmedBy, &report.ConfirmedAt,
&report.MedicalCertificate, &report.CertificateUploaded,
&report.MatrixNotificationSent, &report.EmailNotificationSent,
&report.CreatedAt, &report.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan absence report: %w", err)
}
reports = append(reports, report)
}
return reports, nil
}
// ========================================
// Attendance Statistics
// ========================================
// GetStudentAttendanceStats gets attendance statistics for a student
func (s *AttendanceService) GetStudentAttendanceStats(ctx context.Context, studentID uuid.UUID, schoolYearID uuid.UUID) (map[string]interface{}, error) {
query := `
SELECT
COUNT(*) as total_records,
COUNT(CASE WHEN status = 'present' THEN 1 END) as present_count,
COUNT(CASE WHEN status IN ('absent', 'excused', 'unexcused', 'pending_confirmation') THEN 1 END) as absent_count,
COUNT(CASE WHEN status = 'unexcused' THEN 1 END) as unexcused_count,
COUNT(CASE WHEN status IN ('late', 'late_excused') THEN 1 END) as late_count
FROM attendance_records ar
JOIN timetable_slots ts ON ar.slot_id = ts.id
JOIN schools sch ON ts.school_id = sch.id
JOIN school_years sy ON sy.school_id = sch.id AND sy.id = $2
WHERE ar.student_id = $1 AND ar.date >= sy.start_date AND ar.date <= sy.end_date`
var totalRecords, presentCount, absentCount, unexcusedCount, lateCount int
err := s.db.Pool.QueryRow(ctx, query, studentID, schoolYearID).Scan(
&totalRecords, &presentCount, &absentCount, &unexcusedCount, &lateCount,
)
if err != nil {
return nil, fmt.Errorf("failed to get attendance stats: %w", err)
}
var attendanceRate float64
if totalRecords > 0 {
attendanceRate = float64(presentCount) / float64(totalRecords) * 100
}
return map[string]interface{}{
"total_records": totalRecords,
"present_count": presentCount,
"absent_count": absentCount,
"unexcused_count": unexcusedCount,
"late_count": lateCount,
"attendance_rate": attendanceRate,
}, nil
}
// ========================================
// Parent Notifications
// ========================================
func (s *AttendanceService) notifyParentsOfAbsence(ctx context.Context, record *models.AttendanceRecord) {
if s.matrix == nil {
return
}
// Get student info
var studentFirstName, studentLastName, matrixDMRoom string
err := s.db.Pool.QueryRow(ctx, `
SELECT first_name, last_name, matrix_dm_room
FROM students
WHERE id = $1`, record.StudentID).Scan(&studentFirstName, &studentLastName, &matrixDMRoom)
if err != nil || matrixDMRoom == "" {
return
}
// Get slot info
var slotNumber int
err = s.db.Pool.QueryRow(ctx, `SELECT slot_number FROM timetable_slots WHERE id = $1`, record.SlotID).Scan(&slotNumber)
if err != nil {
return
}
studentName := studentFirstName + " " + studentLastName
dateStr := record.Date.Format("02.01.2006")
// Send Matrix notification
err = s.matrix.SendAbsenceNotification(ctx, matrixDMRoom, studentName, dateStr, slotNumber)
if err != nil {
fmt.Printf("Failed to send absence notification: %v\n", err)
return
}
// Update notification status
s.db.Pool.Exec(ctx, `
UPDATE attendance_records
SET updated_at = NOW()
WHERE id = $1`, record.ID)
// Log the notification
s.createAbsenceNotificationLog(ctx, record.ID, studentName, dateStr, slotNumber)
}
func (s *AttendanceService) notifyParentsOfAbsenceByStudentID(ctx context.Context, studentID uuid.UUID, date time.Time, slotID uuid.UUID) {
record := &models.AttendanceRecord{
StudentID: studentID,
Date: date,
SlotID: slotID,
}
s.notifyParentsOfAbsence(ctx, record)
}
func (s *AttendanceService) createAbsenceNotificationLog(ctx context.Context, recordID uuid.UUID, studentName, dateStr string, slotNumber int) {
// Get parent IDs for this student
query := `
SELECT p.id
FROM parents p
JOIN student_parents sp ON p.id = sp.parent_id
JOIN attendance_records ar ON sp.student_id = ar.student_id
WHERE ar.id = $1`
rows, err := s.db.Pool.Query(ctx, query, recordID)
if err != nil {
return
}
defer rows.Close()
message := fmt.Sprintf("Abwesenheitsmeldung: %s war am %s in der %d. Stunde nicht anwesend.", studentName, dateStr, slotNumber)
for rows.Next() {
var parentID uuid.UUID
if err := rows.Scan(&parentID); err != nil {
continue
}
// Insert notification log
s.db.Pool.Exec(ctx, `
INSERT INTO absence_notifications (id, attendance_record_id, parent_id, channel, message_content, sent_at, created_at)
VALUES ($1, $2, $3, 'matrix', $4, NOW(), NOW())`,
uuid.New(), recordID, parentID, message)
}
}
@@ -0,0 +1,388 @@
package services
import (
"testing"
"time"
"github.com/breakpilot/consent-service/internal/models"
"github.com/google/uuid"
)
// TestValidateAttendanceRecord tests attendance record validation
func TestValidateAttendanceRecord(t *testing.T) {
slotID := uuid.New()
tests := []struct {
name string
record models.AttendanceRecord
expectValid bool
}{
{
name: "valid present record",
record: models.AttendanceRecord{
StudentID: uuid.New(),
SlotID: slotID,
Date: time.Now(),
Status: models.AttendancePresent,
RecordedBy: uuid.New(),
},
expectValid: true,
},
{
name: "valid absent record",
record: models.AttendanceRecord{
StudentID: uuid.New(),
SlotID: slotID,
Date: time.Now(),
Status: models.AttendanceAbsent,
RecordedBy: uuid.New(),
},
expectValid: true,
},
{
name: "valid late record",
record: models.AttendanceRecord{
StudentID: uuid.New(),
SlotID: slotID,
Date: time.Now(),
Status: models.AttendanceLate,
RecordedBy: uuid.New(),
},
expectValid: true,
},
{
name: "missing student ID",
record: models.AttendanceRecord{
StudentID: uuid.Nil,
SlotID: slotID,
Date: time.Now(),
Status: models.AttendancePresent,
RecordedBy: uuid.New(),
},
expectValid: false,
},
{
name: "invalid status",
record: models.AttendanceRecord{
StudentID: uuid.New(),
SlotID: slotID,
Date: time.Now(),
Status: "invalid_status",
RecordedBy: uuid.New(),
},
expectValid: false,
},
{
name: "future date",
record: models.AttendanceRecord{
StudentID: uuid.New(),
SlotID: slotID,
Date: time.Now().AddDate(0, 0, 7),
Status: models.AttendancePresent,
RecordedBy: uuid.New(),
},
expectValid: false,
},
{
name: "missing slot ID",
record: models.AttendanceRecord{
StudentID: uuid.New(),
SlotID: uuid.Nil,
Date: time.Now(),
Status: models.AttendancePresent,
RecordedBy: uuid.New(),
},
expectValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := validateAttendanceRecord(tt.record)
if isValid != tt.expectValid {
t.Errorf("expected valid=%v, got valid=%v", tt.expectValid, isValid)
}
})
}
}
// validateAttendanceRecord validates an attendance record
func validateAttendanceRecord(record models.AttendanceRecord) bool {
if record.StudentID == uuid.Nil {
return false
}
if record.SlotID == uuid.Nil {
return false
}
if record.RecordedBy == uuid.Nil {
return false
}
if record.Date.After(time.Now().AddDate(0, 0, 1)) {
return false
}
// Validate status
validStatuses := map[string]bool{
models.AttendancePresent: true,
models.AttendanceAbsent: true,
models.AttendanceAbsentExcused: true,
models.AttendanceAbsentUnexcused: true,
models.AttendanceLate: true,
models.AttendanceLateExcused: true,
models.AttendancePending: true,
}
if !validStatuses[record.Status] {
return false
}
return true
}
// TestValidateAbsenceReport tests absence report validation
func TestValidateAbsenceReport(t *testing.T) {
now := time.Now()
today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC)
reason := "Krankheit"
medicalReason := "Arzttermin"
tests := []struct {
name string
report models.AbsenceReport
expectValid bool
}{
{
name: "valid single day absence",
report: models.AbsenceReport{
StudentID: uuid.New(),
ReportedBy: uuid.New(),
StartDate: today,
EndDate: today,
Reason: &reason,
ReasonCategory: "illness",
Status: "reported",
},
expectValid: true,
},
{
name: "valid multi-day absence",
report: models.AbsenceReport{
StudentID: uuid.New(),
ReportedBy: uuid.New(),
StartDate: today,
EndDate: today.AddDate(0, 0, 3),
Reason: &medicalReason,
ReasonCategory: "appointment",
Status: "reported",
},
expectValid: true,
},
{
name: "end before start",
report: models.AbsenceReport{
StudentID: uuid.New(),
ReportedBy: uuid.New(),
StartDate: today.AddDate(0, 0, 3),
EndDate: today,
Reason: &reason,
ReasonCategory: "illness",
Status: "reported",
},
expectValid: false,
},
{
name: "missing reason category",
report: models.AbsenceReport{
StudentID: uuid.New(),
ReportedBy: uuid.New(),
StartDate: today,
EndDate: today,
Reason: &reason,
ReasonCategory: "",
Status: "reported",
},
expectValid: false,
},
{
name: "invalid reason category",
report: models.AbsenceReport{
StudentID: uuid.New(),
ReportedBy: uuid.New(),
StartDate: today,
EndDate: today,
Reason: &reason,
ReasonCategory: "invalid_type",
Status: "reported",
},
expectValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := validateAbsenceReport(tt.report)
if isValid != tt.expectValid {
t.Errorf("expected valid=%v, got valid=%v", tt.expectValid, isValid)
}
})
}
}
// validateAbsenceReport validates an absence report
func validateAbsenceReport(report models.AbsenceReport) bool {
if report.StudentID == uuid.Nil {
return false
}
if report.ReportedBy == uuid.Nil {
return false
}
if report.EndDate.Before(report.StartDate) {
return false
}
if report.ReasonCategory == "" {
return false
}
// Validate reason category
validCategories := map[string]bool{
"illness": true,
"appointment": true,
"family": true,
"other": true,
}
if !validCategories[report.ReasonCategory] {
return false
}
return true
}
// TestCalculateAttendanceStats tests attendance statistics calculation
func TestCalculateAttendanceStats(t *testing.T) {
tests := []struct {
name string
records []models.AttendanceRecord
expectedPresent int
expectedAbsent int
expectedLate int
}{
{
name: "all present",
records: []models.AttendanceRecord{
{Status: models.AttendancePresent},
{Status: models.AttendancePresent},
{Status: models.AttendancePresent},
},
expectedPresent: 3,
expectedAbsent: 0,
expectedLate: 0,
},
{
name: "mixed attendance",
records: []models.AttendanceRecord{
{Status: models.AttendancePresent},
{Status: models.AttendanceAbsent},
{Status: models.AttendanceLate},
{Status: models.AttendancePresent},
{Status: models.AttendanceAbsentExcused},
},
expectedPresent: 2,
expectedAbsent: 2,
expectedLate: 1,
},
{
name: "empty records",
records: []models.AttendanceRecord{},
expectedPresent: 0,
expectedAbsent: 0,
expectedLate: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
present, absent, late := calculateAttendanceStats(tt.records)
if present != tt.expectedPresent {
t.Errorf("expected present=%d, got present=%d", tt.expectedPresent, present)
}
if absent != tt.expectedAbsent {
t.Errorf("expected absent=%d, got absent=%d", tt.expectedAbsent, absent)
}
if late != tt.expectedLate {
t.Errorf("expected late=%d, got late=%d", tt.expectedLate, late)
}
})
}
}
// calculateAttendanceStats calculates attendance statistics
func calculateAttendanceStats(records []models.AttendanceRecord) (present, absent, late int) {
for _, r := range records {
switch r.Status {
case models.AttendancePresent:
present++
case models.AttendanceAbsent, models.AttendanceAbsentExcused, models.AttendanceAbsentUnexcused:
absent++
case models.AttendanceLate, models.AttendanceLateExcused:
late++
}
}
return
}
// TestAttendanceRateCalculation tests attendance rate percentage calculation
func TestAttendanceRateCalculation(t *testing.T) {
tests := []struct {
name string
present int
total int
expectedRate float64
}{
{
name: "100% attendance",
present: 26,
total: 26,
expectedRate: 100.0,
},
{
name: "92.3% attendance",
present: 24,
total: 26,
expectedRate: 92.31,
},
{
name: "0% attendance",
present: 0,
total: 26,
expectedRate: 0.0,
},
{
name: "empty class",
present: 0,
total: 0,
expectedRate: 0.0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rate := calculateAttendanceRate(tt.present, tt.total)
// Allow small floating point differences
if rate < tt.expectedRate-0.1 || rate > tt.expectedRate+0.1 {
t.Errorf("expected rate=%.2f, got rate=%.2f", tt.expectedRate, rate)
}
})
}
}
// calculateAttendanceRate calculates attendance rate as percentage
func calculateAttendanceRate(present, total int) float64 {
if total == 0 {
return 0.0
}
rate := float64(present) / float64(total) * 100
// Round to 2 decimal places
return float64(int(rate*100)) / 100
}
@@ -0,0 +1,568 @@
package services
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
"golang.org/x/crypto/bcrypt"
"github.com/breakpilot/consent-service/internal/models"
)
var (
ErrInvalidCredentials = errors.New("invalid email or password")
ErrUserNotFound = errors.New("user not found")
ErrUserExists = errors.New("user with this email already exists")
ErrInvalidToken = errors.New("invalid or expired token")
ErrAccountLocked = errors.New("account is temporarily locked")
ErrAccountSuspended = errors.New("account is suspended")
ErrEmailNotVerified = errors.New("email not verified")
)
// AuthService handles authentication logic
type AuthService struct {
db *pgxpool.Pool
jwtSecret string
jwtRefreshSecret string
accessTokenExp time.Duration
refreshTokenExp time.Duration
}
// NewAuthService creates a new AuthService
func NewAuthService(db *pgxpool.Pool, jwtSecret, jwtRefreshSecret string) *AuthService {
return &AuthService{
db: db,
jwtSecret: jwtSecret,
jwtRefreshSecret: jwtRefreshSecret,
accessTokenExp: time.Hour * 1, // 1 hour
refreshTokenExp: time.Hour * 24 * 30, // 30 days
}
}
// HashPassword hashes a password using bcrypt
func (s *AuthService) HashPassword(password string) (string, error) {
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return "", fmt.Errorf("failed to hash password: %w", err)
}
return string(bytes), nil
}
// VerifyPassword verifies a password against a hash
func (s *AuthService) VerifyPassword(password, hash string) bool {
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
return err == nil
}
// GenerateSecureToken generates a cryptographically secure token
func (s *AuthService) GenerateSecureToken(length int) (string, error) {
bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("failed to generate token: %w", err)
}
return base64.URLEncoding.EncodeToString(bytes), nil
}
// HashToken creates a SHA256 hash of a token for storage
func (s *AuthService) HashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}
// JWTClaims for access tokens
type JWTClaims struct {
UserID string `json:"user_id"`
Email string `json:"email"`
Role string `json:"role"`
AccountStatus string `json:"account_status"`
jwt.RegisteredClaims
}
// GenerateAccessToken creates a new JWT access token
func (s *AuthService) GenerateAccessToken(user *models.User) (string, error) {
claims := JWTClaims{
UserID: user.ID.String(),
Email: user.Email,
Role: user.Role,
AccountStatus: user.AccountStatus,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(s.accessTokenExp)),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
Subject: user.ID.String(),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(s.jwtSecret))
}
// GenerateRefreshToken creates a new refresh token
func (s *AuthService) GenerateRefreshToken() (string, string, error) {
token, err := s.GenerateSecureToken(32)
if err != nil {
return "", "", err
}
hash := s.HashToken(token)
return token, hash, nil
}
// ValidateAccessToken validates a JWT access token
func (s *AuthService) ValidateAccessToken(tokenString string) (*JWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(s.jwtSecret), nil
})
if err != nil {
return nil, fmt.Errorf("failed to parse token: %w", err)
}
claims, ok := token.Claims.(*JWTClaims)
if !ok || !token.Valid {
return nil, ErrInvalidToken
}
return claims, nil
}
// Register creates a new user account
func (s *AuthService) Register(ctx context.Context, req *models.RegisterRequest) (*models.User, string, error) {
// Check if user already exists
var exists bool
err := s.db.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)", req.Email).Scan(&exists)
if err != nil {
return nil, "", fmt.Errorf("failed to check existing user: %w", err)
}
if exists {
return nil, "", ErrUserExists
}
// Hash password
passwordHash, err := s.HashPassword(req.Password)
if err != nil {
return nil, "", err
}
// Create user
user := &models.User{
ID: uuid.New(),
Email: req.Email,
PasswordHash: &passwordHash,
Name: req.Name,
Role: "user",
EmailVerified: false,
AccountStatus: "active",
}
_, err = s.db.Exec(ctx, `
INSERT INTO users (id, email, password_hash, name, role, email_verified, account_status, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), NOW())
`, user.ID, user.Email, user.PasswordHash, user.Name, user.Role, user.EmailVerified, user.AccountStatus)
if err != nil {
return nil, "", fmt.Errorf("failed to create user: %w", err)
}
// Generate email verification token
verificationToken, err := s.GenerateSecureToken(32)
if err != nil {
return nil, "", err
}
// Store verification token
_, err = s.db.Exec(ctx, `
INSERT INTO email_verification_tokens (user_id, token, expires_at, created_at)
VALUES ($1, $2, $3, NOW())
`, user.ID, verificationToken, time.Now().Add(24*time.Hour))
if err != nil {
return nil, "", fmt.Errorf("failed to create verification token: %w", err)
}
// Create notification preferences
_, err = s.db.Exec(ctx, `
INSERT INTO notification_preferences (user_id, email_enabled, push_enabled, in_app_enabled, reminder_frequency, created_at, updated_at)
VALUES ($1, true, true, true, 'weekly', NOW(), NOW())
`, user.ID)
if err != nil {
// Non-critical error, just log
fmt.Printf("Warning: failed to create notification preferences: %v\n", err)
}
return user, verificationToken, nil
}
// Login authenticates a user and returns tokens
func (s *AuthService) Login(ctx context.Context, req *models.LoginRequest, ipAddress, userAgent string) (*models.LoginResponse, error) {
var user models.User
var passwordHash *string
err := s.db.QueryRow(ctx, `
SELECT id, email, password_hash, name, role, email_verified, account_status,
failed_login_attempts, locked_until, created_at, updated_at
FROM users WHERE email = $1
`, req.Email).Scan(
&user.ID, &user.Email, &passwordHash, &user.Name, &user.Role, &user.EmailVerified,
&user.AccountStatus, &user.FailedLoginAttempts, &user.LockedUntil, &user.CreatedAt, &user.UpdatedAt,
)
if err != nil {
return nil, ErrInvalidCredentials
}
// Check if account is locked
if user.LockedUntil != nil && user.LockedUntil.After(time.Now()) {
return nil, ErrAccountLocked
}
// Check if account is suspended
if user.AccountStatus == "suspended" {
return nil, ErrAccountSuspended
}
// Verify password
if passwordHash == nil || !s.VerifyPassword(req.Password, *passwordHash) {
// Increment failed login attempts
_, _ = s.db.Exec(ctx, `
UPDATE users SET
failed_login_attempts = failed_login_attempts + 1,
locked_until = CASE WHEN failed_login_attempts >= 4 THEN NOW() + INTERVAL '30 minutes' ELSE locked_until END,
updated_at = NOW()
WHERE id = $1
`, user.ID)
return nil, ErrInvalidCredentials
}
// Reset failed login attempts and update last login
_, _ = s.db.Exec(ctx, `
UPDATE users SET
failed_login_attempts = 0,
locked_until = NULL,
last_login_at = NOW(),
updated_at = NOW()
WHERE id = $1
`, user.ID)
// Generate tokens
accessToken, err := s.GenerateAccessToken(&user)
if err != nil {
return nil, fmt.Errorf("failed to generate access token: %w", err)
}
refreshToken, refreshTokenHash, err := s.GenerateRefreshToken()
if err != nil {
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
}
// Store session
_, err = s.db.Exec(ctx, `
INSERT INTO user_sessions (user_id, token_hash, ip_address, user_agent, expires_at, created_at, last_activity_at)
VALUES ($1, $2, $3, $4, $5, NOW(), NOW())
`, user.ID, refreshTokenHash, ipAddress, userAgent, time.Now().Add(s.refreshTokenExp))
if err != nil {
return nil, fmt.Errorf("failed to create session: %w", err)
}
return &models.LoginResponse{
User: user,
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresIn: int(s.accessTokenExp.Seconds()),
}, nil
}
// RefreshToken refreshes the access token using a refresh token
func (s *AuthService) RefreshToken(ctx context.Context, refreshToken string) (*models.LoginResponse, error) {
tokenHash := s.HashToken(refreshToken)
var session models.UserSession
var userID uuid.UUID
err := s.db.QueryRow(ctx, `
SELECT id, user_id, expires_at, revoked_at FROM user_sessions
WHERE token_hash = $1
`, tokenHash).Scan(&session.ID, &userID, &session.ExpiresAt, &session.RevokedAt)
if err != nil {
return nil, ErrInvalidToken
}
// Check if session is expired or revoked
if session.RevokedAt != nil || session.ExpiresAt.Before(time.Now()) {
return nil, ErrInvalidToken
}
// Get user
var user models.User
err = s.db.QueryRow(ctx, `
SELECT id, email, name, role, email_verified, account_status, created_at, updated_at
FROM users WHERE id = $1
`, userID).Scan(
&user.ID, &user.Email, &user.Name, &user.Role, &user.EmailVerified,
&user.AccountStatus, &user.CreatedAt, &user.UpdatedAt,
)
if err != nil {
return nil, ErrUserNotFound
}
// Check account status
if user.AccountStatus == "suspended" {
return nil, ErrAccountSuspended
}
// Generate new access token
accessToken, err := s.GenerateAccessToken(&user)
if err != nil {
return nil, fmt.Errorf("failed to generate access token: %w", err)
}
// Update session last activity
_, _ = s.db.Exec(ctx, `
UPDATE user_sessions SET last_activity_at = NOW() WHERE id = $1
`, session.ID)
return &models.LoginResponse{
User: user,
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresIn: int(s.accessTokenExp.Seconds()),
}, nil
}
// VerifyEmail verifies a user's email address
func (s *AuthService) VerifyEmail(ctx context.Context, token string) error {
var tokenID uuid.UUID
var userID uuid.UUID
var expiresAt time.Time
var usedAt *time.Time
err := s.db.QueryRow(ctx, `
SELECT id, user_id, expires_at, used_at FROM email_verification_tokens
WHERE token = $1
`, token).Scan(&tokenID, &userID, &expiresAt, &usedAt)
if err != nil {
return ErrInvalidToken
}
if usedAt != nil || expiresAt.Before(time.Now()) {
return ErrInvalidToken
}
// Mark token as used
_, err = s.db.Exec(ctx, `UPDATE email_verification_tokens SET used_at = NOW() WHERE id = $1`, tokenID)
if err != nil {
return fmt.Errorf("failed to update token: %w", err)
}
// Verify user email
_, err = s.db.Exec(ctx, `
UPDATE users SET email_verified = true, email_verified_at = NOW(), updated_at = NOW()
WHERE id = $1
`, userID)
if err != nil {
return fmt.Errorf("failed to verify email: %w", err)
}
return nil
}
// CreatePasswordResetToken creates a password reset token
func (s *AuthService) CreatePasswordResetToken(ctx context.Context, email, ipAddress string) (string, *uuid.UUID, error) {
var userID uuid.UUID
err := s.db.QueryRow(ctx, "SELECT id FROM users WHERE email = $1", email).Scan(&userID)
if err != nil {
// Don't reveal if user exists
return "", nil, nil
}
token, err := s.GenerateSecureToken(32)
if err != nil {
return "", nil, err
}
_, err = s.db.Exec(ctx, `
INSERT INTO password_reset_tokens (user_id, token, expires_at, ip_address, created_at)
VALUES ($1, $2, $3, $4, NOW())
`, userID, token, time.Now().Add(time.Hour), ipAddress)
if err != nil {
return "", nil, fmt.Errorf("failed to create reset token: %w", err)
}
return token, &userID, nil
}
// ResetPassword resets a user's password using a reset token
func (s *AuthService) ResetPassword(ctx context.Context, token, newPassword string) error {
var tokenID uuid.UUID
var userID uuid.UUID
var expiresAt time.Time
var usedAt *time.Time
err := s.db.QueryRow(ctx, `
SELECT id, user_id, expires_at, used_at FROM password_reset_tokens
WHERE token = $1
`, token).Scan(&tokenID, &userID, &expiresAt, &usedAt)
if err != nil {
return ErrInvalidToken
}
if usedAt != nil || expiresAt.Before(time.Now()) {
return ErrInvalidToken
}
// Hash new password
passwordHash, err := s.HashPassword(newPassword)
if err != nil {
return err
}
// Mark token as used
_, err = s.db.Exec(ctx, `UPDATE password_reset_tokens SET used_at = NOW() WHERE id = $1`, tokenID)
if err != nil {
return fmt.Errorf("failed to update token: %w", err)
}
// Update password
_, err = s.db.Exec(ctx, `
UPDATE users SET password_hash = $1, updated_at = NOW() WHERE id = $2
`, passwordHash, userID)
if err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
// Revoke all sessions for security
_, err = s.db.Exec(ctx, `UPDATE user_sessions SET revoked_at = NOW() WHERE user_id = $1 AND revoked_at IS NULL`, userID)
if err != nil {
fmt.Printf("Warning: failed to revoke sessions: %v\n", err)
}
return nil
}
// ChangePassword changes a user's password (requires current password)
func (s *AuthService) ChangePassword(ctx context.Context, userID uuid.UUID, currentPassword, newPassword string) error {
var passwordHash *string
err := s.db.QueryRow(ctx, "SELECT password_hash FROM users WHERE id = $1", userID).Scan(&passwordHash)
if err != nil {
return ErrUserNotFound
}
if passwordHash == nil || !s.VerifyPassword(currentPassword, *passwordHash) {
return ErrInvalidCredentials
}
newPasswordHash, err := s.HashPassword(newPassword)
if err != nil {
return err
}
_, err = s.db.Exec(ctx, `UPDATE users SET password_hash = $1, updated_at = NOW() WHERE id = $2`, newPasswordHash, userID)
if err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
return nil
}
// GetUserByID retrieves a user by ID
func (s *AuthService) GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error) {
var user models.User
err := s.db.QueryRow(ctx, `
SELECT id, email, name, role, email_verified, email_verified_at, account_status,
last_login_at, created_at, updated_at
FROM users WHERE id = $1
`, userID).Scan(
&user.ID, &user.Email, &user.Name, &user.Role, &user.EmailVerified, &user.EmailVerifiedAt,
&user.AccountStatus, &user.LastLoginAt, &user.CreatedAt, &user.UpdatedAt,
)
if err != nil {
return nil, ErrUserNotFound
}
return &user, nil
}
// UpdateProfile updates a user's profile
func (s *AuthService) UpdateProfile(ctx context.Context, userID uuid.UUID, req *models.UpdateProfileRequest) (*models.User, error) {
_, err := s.db.Exec(ctx, `UPDATE users SET name = $1, updated_at = NOW() WHERE id = $2`, req.Name, userID)
if err != nil {
return nil, fmt.Errorf("failed to update profile: %w", err)
}
return s.GetUserByID(ctx, userID)
}
// GetActiveSessions retrieves all active sessions for a user
func (s *AuthService) GetActiveSessions(ctx context.Context, userID uuid.UUID) ([]models.UserSession, error) {
rows, err := s.db.Query(ctx, `
SELECT id, user_id, device_info, ip_address, user_agent, expires_at, created_at, last_activity_at
FROM user_sessions
WHERE user_id = $1 AND revoked_at IS NULL AND expires_at > NOW()
ORDER BY last_activity_at DESC
`, userID)
if err != nil {
return nil, fmt.Errorf("failed to get sessions: %w", err)
}
defer rows.Close()
var sessions []models.UserSession
for rows.Next() {
var session models.UserSession
err := rows.Scan(
&session.ID, &session.UserID, &session.DeviceInfo, &session.IPAddress,
&session.UserAgent, &session.ExpiresAt, &session.CreatedAt, &session.LastActivityAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan session: %w", err)
}
sessions = append(sessions, session)
}
return sessions, nil
}
// RevokeSession revokes a specific session
func (s *AuthService) RevokeSession(ctx context.Context, userID, sessionID uuid.UUID) error {
result, err := s.db.Exec(ctx, `
UPDATE user_sessions SET revoked_at = NOW() WHERE id = $1 AND user_id = $2 AND revoked_at IS NULL
`, sessionID, userID)
if err != nil {
return fmt.Errorf("failed to revoke session: %w", err)
}
if result.RowsAffected() == 0 {
return errors.New("session not found")
}
return nil
}
// Logout revokes a session by refresh token
func (s *AuthService) Logout(ctx context.Context, refreshToken string) error {
tokenHash := s.HashToken(refreshToken)
_, err := s.db.Exec(ctx, `UPDATE user_sessions SET revoked_at = NOW() WHERE token_hash = $1`, tokenHash)
return err
}
@@ -0,0 +1,367 @@
package services
import (
"testing"
"time"
"github.com/breakpilot/consent-service/internal/models"
"github.com/google/uuid"
)
// TestHashPassword tests password hashing
func TestHashPassword(t *testing.T) {
// Create service without DB for unit tests
s := &AuthService{}
password := "testPassword123!"
hash, err := s.HashPassword(password)
if err != nil {
t.Fatalf("HashPassword failed: %v", err)
}
if hash == "" {
t.Error("Hash should not be empty")
}
if hash == password {
t.Error("Hash should not equal the original password")
}
// Hash should be different each time (bcrypt uses random salt)
hash2, _ := s.HashPassword(password)
if hash == hash2 {
t.Error("Same password should produce different hashes due to salt")
}
}
// TestVerifyPassword tests password verification
func TestVerifyPassword(t *testing.T) {
s := &AuthService{}
password := "testPassword123!"
hash, _ := s.HashPassword(password)
// Should verify correct password
if !s.VerifyPassword(password, hash) {
t.Error("VerifyPassword should return true for correct password")
}
// Should reject incorrect password
if s.VerifyPassword("wrongPassword", hash) {
t.Error("VerifyPassword should return false for incorrect password")
}
// Should reject empty password
if s.VerifyPassword("", hash) {
t.Error("VerifyPassword should return false for empty password")
}
}
// TestGenerateSecureToken tests token generation
func TestGenerateSecureToken(t *testing.T) {
s := &AuthService{}
tests := []struct {
name string
length int
}{
{"short token", 16},
{"standard token", 32},
{"long token", 64},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
token, err := s.GenerateSecureToken(tt.length)
if err != nil {
t.Fatalf("GenerateSecureToken failed: %v", err)
}
if token == "" {
t.Error("Token should not be empty")
}
// Tokens should be unique
token2, _ := s.GenerateSecureToken(tt.length)
if token == token2 {
t.Error("Generated tokens should be unique")
}
})
}
}
// TestHashToken tests token hashing for storage
func TestHashToken(t *testing.T) {
s := &AuthService{}
token := "test-token-123"
hash := s.HashToken(token)
if hash == "" {
t.Error("Hash should not be empty")
}
if hash == token {
t.Error("Hash should not equal the original token")
}
// Same token should produce same hash (deterministic)
hash2 := s.HashToken(token)
if hash != hash2 {
t.Error("Same token should produce same hash")
}
// Different tokens should produce different hashes
differentHash := s.HashToken("different-token")
if hash == differentHash {
t.Error("Different tokens should produce different hashes")
}
}
// TestGenerateAccessToken tests JWT access token generation
func TestGenerateAccessToken(t *testing.T) {
s := &AuthService{
jwtSecret: "test-secret-key-for-testing-purposes",
accessTokenExp: time.Hour,
}
user := &models.User{
ID: uuid.New(),
Email: "test@example.com",
Role: "user",
AccountStatus: "active",
}
token, err := s.GenerateAccessToken(user)
if err != nil {
t.Fatalf("GenerateAccessToken failed: %v", err)
}
if token == "" {
t.Error("Token should not be empty")
}
// Token should have three parts (header.payload.signature)
parts := 0
for _, c := range token {
if c == '.' {
parts++
}
}
if parts != 2 {
t.Errorf("JWT token should have 3 parts, got %d dots", parts)
}
}
// TestValidateAccessToken tests JWT token validation
func TestValidateAccessToken(t *testing.T) {
secret := "test-secret-key-for-testing-purposes"
s := &AuthService{
jwtSecret: secret,
accessTokenExp: time.Hour,
}
user := &models.User{
ID: uuid.New(),
Email: "test@example.com",
Role: "admin",
AccountStatus: "active",
}
token, _ := s.GenerateAccessToken(user)
// Should validate valid token
claims, err := s.ValidateAccessToken(token)
if err != nil {
t.Fatalf("ValidateAccessToken failed: %v", err)
}
if claims.UserID != user.ID.String() {
t.Errorf("Expected UserID %s, got %s", user.ID.String(), claims.UserID)
}
if claims.Email != user.Email {
t.Errorf("Expected Email %s, got %s", user.Email, claims.Email)
}
if claims.Role != user.Role {
t.Errorf("Expected Role %s, got %s", user.Role, claims.Role)
}
}
// TestValidateAccessToken_Invalid tests invalid token scenarios
func TestValidateAccessToken_Invalid(t *testing.T) {
s := &AuthService{
jwtSecret: "test-secret-key-for-testing-purposes",
accessTokenExp: time.Hour,
}
tests := []struct {
name string
token string
}{
{"empty token", ""},
{"invalid format", "not-a-jwt-token"},
{"invalid signature", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiMTIzIn0.invalidsignature"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := s.ValidateAccessToken(tt.token)
if err == nil {
t.Error("ValidateAccessToken should fail for invalid token")
}
})
}
}
// TestValidateAccessToken_WrongSecret tests token with wrong secret
func TestValidateAccessToken_WrongSecret(t *testing.T) {
s1 := &AuthService{
jwtSecret: "secret-one",
accessTokenExp: time.Hour,
}
s2 := &AuthService{
jwtSecret: "secret-two",
accessTokenExp: time.Hour,
}
user := &models.User{
ID: uuid.New(),
Email: "test@example.com",
Role: "user",
AccountStatus: "active",
}
// Generate token with first secret
token, _ := s1.GenerateAccessToken(user)
// Try to validate with second secret (should fail)
_, err := s2.ValidateAccessToken(token)
if err == nil {
t.Error("ValidateAccessToken should fail when using wrong secret")
}
}
// TestGenerateRefreshToken tests refresh token generation
func TestGenerateRefreshToken(t *testing.T) {
s := &AuthService{}
token, hash, err := s.GenerateRefreshToken()
if err != nil {
t.Fatalf("GenerateRefreshToken failed: %v", err)
}
if token == "" {
t.Error("Token should not be empty")
}
if hash == "" {
t.Error("Hash should not be empty")
}
// Verify hash matches token
expectedHash := s.HashToken(token)
if hash != expectedHash {
t.Error("Returned hash should match hashed token")
}
// Tokens should be unique
token2, hash2, _ := s.GenerateRefreshToken()
if token == token2 {
t.Error("Generated tokens should be unique")
}
if hash == hash2 {
t.Error("Generated hashes should be unique")
}
}
// TestPasswordStrength tests various password scenarios
func TestPasswordStrength(t *testing.T) {
s := &AuthService{}
passwords := []struct {
password string
valid bool
}{
{"short", true}, // bcrypt accepts any length
{"12345678", true}, // numbers only
{"password", true}, // letters only
{"Pass123!", true}, // mixed
{"", true}, // empty (bcrypt allows)
{string(make([]byte, 72)), true}, // max bcrypt length
}
for _, p := range passwords {
hash, err := s.HashPassword(p.password)
if p.valid && err != nil {
t.Errorf("HashPassword failed for valid password %q: %v", p.password, err)
}
if p.valid && !s.VerifyPassword(p.password, hash) {
t.Errorf("VerifyPassword failed for password %q", p.password)
}
}
}
// BenchmarkHashPassword benchmarks password hashing
func BenchmarkHashPassword(b *testing.B) {
s := &AuthService{}
password := "testPassword123!"
for i := 0; i < b.N; i++ {
s.HashPassword(password)
}
}
// BenchmarkVerifyPassword benchmarks password verification
func BenchmarkVerifyPassword(b *testing.B) {
s := &AuthService{}
password := "testPassword123!"
hash, _ := s.HashPassword(password)
for i := 0; i < b.N; i++ {
s.VerifyPassword(password, hash)
}
}
// BenchmarkGenerateAccessToken benchmarks JWT token generation
func BenchmarkGenerateAccessToken(b *testing.B) {
s := &AuthService{
jwtSecret: "test-secret-key-for-testing-purposes",
accessTokenExp: time.Hour,
}
user := &models.User{
ID: uuid.New(),
Email: "test@example.com",
Role: "user",
AccountStatus: "active",
}
for i := 0; i < b.N; i++ {
s.GenerateAccessToken(user)
}
}
// BenchmarkValidateAccessToken benchmarks JWT token validation
func BenchmarkValidateAccessToken(b *testing.B) {
s := &AuthService{
jwtSecret: "test-secret-key-for-testing-purposes",
accessTokenExp: time.Hour,
}
user := &models.User{
ID: uuid.New(),
Email: "test@example.com",
Role: "user",
AccountStatus: "active",
}
token, _ := s.GenerateAccessToken(user)
for i := 0; i < b.N; i++ {
s.ValidateAccessToken(token)
}
}
@@ -0,0 +1,518 @@
package services
import (
"testing"
"time"
"github.com/google/uuid"
)
// TestConsentService_CreateConsent tests creating a new consent
func TestConsentService_CreateConsent(t *testing.T) {
// This is a unit test with table-driven approach
tests := []struct {
name string
userID uuid.UUID
versionID uuid.UUID
consented bool
expectError bool
errorContains string
}{
{
name: "valid consent - accepted",
userID: uuid.New(),
versionID: uuid.New(),
consented: true,
expectError: false,
},
{
name: "valid consent - declined",
userID: uuid.New(),
versionID: uuid.New(),
consented: false,
expectError: false,
},
{
name: "empty user ID",
userID: uuid.Nil,
versionID: uuid.New(),
consented: true,
expectError: true,
errorContains: "user ID",
},
{
name: "empty version ID",
userID: uuid.New(),
versionID: uuid.Nil,
consented: true,
expectError: true,
errorContains: "version ID",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Validate inputs (in real implementation this would be in the service)
var hasError bool
if tt.userID == uuid.Nil {
hasError = true
} else if tt.versionID == uuid.Nil {
hasError = true
}
// Assert
if tt.expectError && !hasError {
t.Errorf("Expected error containing '%s', got nil", tt.errorContains)
}
if !tt.expectError && hasError {
t.Error("Expected no error, got error")
}
})
}
}
// TestConsentService_WithdrawConsent tests withdrawing consent
func TestConsentService_WithdrawConsent(t *testing.T) {
tests := []struct {
name string
consentID uuid.UUID
userID uuid.UUID
expectError bool
errorContains string
}{
{
name: "valid withdrawal",
consentID: uuid.New(),
userID: uuid.New(),
expectError: false,
},
{
name: "empty consent ID",
consentID: uuid.Nil,
userID: uuid.New(),
expectError: true,
errorContains: "consent ID",
},
{
name: "empty user ID",
consentID: uuid.New(),
userID: uuid.Nil,
expectError: true,
errorContains: "user ID",
},
{
name: "both empty",
consentID: uuid.Nil,
userID: uuid.Nil,
expectError: true,
errorContains: "ID",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Validate
var hasError bool
if tt.consentID == uuid.Nil || tt.userID == uuid.Nil {
hasError = true
}
// Assert
if tt.expectError && !hasError {
t.Errorf("Expected error containing '%s', got nil", tt.errorContains)
}
if !tt.expectError && hasError {
t.Error("Expected no error, got error")
}
})
}
}
// TestConsentService_CheckConsent tests checking consent status
func TestConsentService_CheckConsent(t *testing.T) {
tests := []struct {
name string
userID uuid.UUID
documentType string
language string
hasConsent bool
needsUpdate bool
expectedConsent bool
expectedNeedsUpd bool
}{
{
name: "user has current consent",
userID: uuid.New(),
documentType: "terms",
language: "de",
hasConsent: true,
needsUpdate: false,
expectedConsent: true,
expectedNeedsUpd: false,
},
{
name: "user has outdated consent",
userID: uuid.New(),
documentType: "privacy",
language: "de",
hasConsent: true,
needsUpdate: true,
expectedConsent: true,
expectedNeedsUpd: true,
},
{
name: "user has no consent",
userID: uuid.New(),
documentType: "cookies",
language: "de",
hasConsent: false,
needsUpdate: true,
expectedConsent: false,
expectedNeedsUpd: true,
},
{
name: "english language",
userID: uuid.New(),
documentType: "terms",
language: "en",
hasConsent: true,
needsUpdate: false,
expectedConsent: true,
expectedNeedsUpd: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Simulate consent check logic
hasConsent := tt.hasConsent
needsUpdate := tt.needsUpdate
// Assert
if hasConsent != tt.expectedConsent {
t.Errorf("Expected hasConsent=%v, got %v", tt.expectedConsent, hasConsent)
}
if needsUpdate != tt.expectedNeedsUpd {
t.Errorf("Expected needsUpdate=%v, got %v", tt.expectedNeedsUpd, needsUpdate)
}
})
}
}
// TestConsentService_GetConsentHistory tests retrieving consent history
func TestConsentService_GetConsentHistory(t *testing.T) {
tests := []struct {
name string
userID uuid.UUID
expectError bool
expectEmpty bool
}{
{
name: "valid user with consents",
userID: uuid.New(),
expectError: false,
expectEmpty: false,
},
{
name: "valid user without consents",
userID: uuid.New(),
expectError: false,
expectEmpty: true,
},
{
name: "invalid user ID",
userID: uuid.Nil,
expectError: true,
expectEmpty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Validate
var err error
if tt.userID == uuid.Nil {
err = &ValidationError{Field: "user ID", Message: "required"}
}
// Assert error expectation
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestConsentService_UpdateConsent tests updating existing consent
func TestConsentService_UpdateConsent(t *testing.T) {
tests := []struct {
name string
consentID uuid.UUID
userID uuid.UUID
newConsented bool
expectError bool
}{
{
name: "update to consented",
consentID: uuid.New(),
userID: uuid.New(),
newConsented: true,
expectError: false,
},
{
name: "update to not consented",
consentID: uuid.New(),
userID: uuid.New(),
newConsented: false,
expectError: false,
},
{
name: "invalid consent ID",
consentID: uuid.Nil,
userID: uuid.New(),
newConsented: true,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.consentID == uuid.Nil {
err = &ValidationError{Field: "consent ID", Message: "required"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestConsentService_GetConsentStats tests getting consent statistics
func TestConsentService_GetConsentStats(t *testing.T) {
tests := []struct {
name string
documentType string
totalUsers int
consentedUsers int
expectedRate float64
}{
{
name: "100% consent rate",
documentType: "terms",
totalUsers: 100,
consentedUsers: 100,
expectedRate: 100.0,
},
{
name: "50% consent rate",
documentType: "privacy",
totalUsers: 100,
consentedUsers: 50,
expectedRate: 50.0,
},
{
name: "0% consent rate",
documentType: "cookies",
totalUsers: 100,
consentedUsers: 0,
expectedRate: 0.0,
},
{
name: "no users",
documentType: "terms",
totalUsers: 0,
consentedUsers: 0,
expectedRate: 0.0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Calculate consent rate
var consentRate float64
if tt.totalUsers > 0 {
consentRate = float64(tt.consentedUsers) / float64(tt.totalUsers) * 100
}
// Assert
if consentRate != tt.expectedRate {
t.Errorf("Expected consent rate %.2f%%, got %.2f%%", tt.expectedRate, consentRate)
}
})
}
}
// TestConsentService_BulkConsentCheck tests checking multiple consents at once
func TestConsentService_BulkConsentCheck(t *testing.T) {
tests := []struct {
name string
userID uuid.UUID
documentTypes []string
expectError bool
}{
{
name: "check multiple documents",
userID: uuid.New(),
documentTypes: []string{"terms", "privacy", "cookies"},
expectError: false,
},
{
name: "check single document",
userID: uuid.New(),
documentTypes: []string{"terms"},
expectError: false,
},
{
name: "empty document list",
userID: uuid.New(),
documentTypes: []string{},
expectError: false,
},
{
name: "invalid user ID",
userID: uuid.Nil,
documentTypes: []string{"terms"},
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.userID == uuid.Nil {
err = &ValidationError{Field: "user ID", Message: "required"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestConsentService_ConsentVersionComparison tests version comparison logic
func TestConsentService_ConsentVersionComparison(t *testing.T) {
tests := []struct {
name string
currentVersion string
consentedVersion string
needsUpdate bool
}{
{
name: "same version",
currentVersion: "1.0.0",
consentedVersion: "1.0.0",
needsUpdate: false,
},
{
name: "minor version update",
currentVersion: "1.1.0",
consentedVersion: "1.0.0",
needsUpdate: true,
},
{
name: "major version update",
currentVersion: "2.0.0",
consentedVersion: "1.0.0",
needsUpdate: true,
},
{
name: "patch version update",
currentVersion: "1.0.1",
consentedVersion: "1.0.0",
needsUpdate: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Simple version comparison (in real implementation use proper semver)
needsUpdate := tt.currentVersion != tt.consentedVersion
if needsUpdate != tt.needsUpdate {
t.Errorf("Expected needsUpdate=%v, got %v", tt.needsUpdate, needsUpdate)
}
})
}
}
// TestConsentService_ConsentDeadlineCheck tests deadline validation
func TestConsentService_ConsentDeadlineCheck(t *testing.T) {
now := time.Now()
tests := []struct {
name string
deadline time.Time
isOverdue bool
daysLeft int
}{
{
name: "deadline in 30 days",
deadline: now.AddDate(0, 0, 30),
isOverdue: false,
daysLeft: 30,
},
{
name: "deadline in 7 days",
deadline: now.AddDate(0, 0, 7),
isOverdue: false,
daysLeft: 7,
},
{
name: "deadline today",
deadline: now,
isOverdue: false,
daysLeft: 0,
},
{
name: "deadline 1 day overdue",
deadline: now.AddDate(0, 0, -1),
isOverdue: true,
daysLeft: -1,
},
{
name: "deadline 30 days overdue",
deadline: now.AddDate(0, 0, -30),
isOverdue: true,
daysLeft: -30,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Calculate if overdue
isOverdue := tt.deadline.Before(now)
daysLeft := int(tt.deadline.Sub(now).Hours() / 24)
if isOverdue != tt.isOverdue {
t.Errorf("Expected isOverdue=%v, got %v", tt.isOverdue, isOverdue)
}
// Allow 1 day difference due to time precision
if abs(daysLeft-tt.daysLeft) > 1 {
t.Errorf("Expected daysLeft=%d, got %d", tt.daysLeft, daysLeft)
}
})
}
}
// Helper functions
// abs returns the absolute value of an integer
func abs(n int) int {
if n < 0 {
return -n
}
return n
}
@@ -0,0 +1,434 @@
package services
import (
"context"
"fmt"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
)
// DeadlineService handles consent deadlines and account suspensions
type DeadlineService struct {
pool *pgxpool.Pool
notificationService *NotificationService
}
// ConsentDeadline represents a consent deadline for a user
type ConsentDeadline struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
DocumentVersionID uuid.UUID `json:"document_version_id"`
DeadlineAt time.Time `json:"deadline_at"`
ReminderCount int `json:"reminder_count"`
LastReminderAt *time.Time `json:"last_reminder_at"`
ConsentGivenAt *time.Time `json:"consent_given_at"`
CreatedAt time.Time `json:"created_at"`
// Joined fields
DocumentName string `json:"document_name"`
VersionNumber string `json:"version_number"`
}
// AccountSuspension represents an account suspension
type AccountSuspension struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
Reason string `json:"reason"`
Details map[string]interface{} `json:"details"`
SuspendedAt time.Time `json:"suspended_at"`
LiftedAt *time.Time `json:"lifted_at"`
LiftedBy *uuid.UUID `json:"lifted_by"`
}
// NewDeadlineService creates a new deadline service
func NewDeadlineService(pool *pgxpool.Pool, notificationService *NotificationService) *DeadlineService {
return &DeadlineService{
pool: pool,
notificationService: notificationService,
}
}
// CreateDeadlinesForPublishedVersion creates consent deadlines for all active users
// when a new mandatory document version is published
func (s *DeadlineService) CreateDeadlinesForPublishedVersion(ctx context.Context, versionID uuid.UUID) error {
// Get version info
var documentName, versionNumber string
var isMandatory bool
err := s.pool.QueryRow(ctx, `
SELECT ld.name, dv.version, ld.is_mandatory
FROM document_versions dv
JOIN legal_documents ld ON dv.document_id = ld.id
WHERE dv.id = $1
`, versionID).Scan(&documentName, &versionNumber, &isMandatory)
if err != nil {
return fmt.Errorf("failed to get version info: %w", err)
}
// Only create deadlines for mandatory documents
if !isMandatory {
return nil
}
// Deadline is 30 days from now
deadlineAt := time.Now().AddDate(0, 0, 30)
// Get all active users who haven't given consent to this version
_, err = s.pool.Exec(ctx, `
INSERT INTO consent_deadlines (user_id, document_version_id, deadline_at)
SELECT u.id, $1, $2
FROM users u
WHERE u.account_status = 'active'
AND NOT EXISTS (
SELECT 1 FROM user_consents uc
WHERE uc.user_id = u.id AND uc.document_version_id = $1 AND uc.consented = TRUE
)
ON CONFLICT (user_id, document_version_id) DO NOTHING
`, versionID, deadlineAt)
if err != nil {
return fmt.Errorf("failed to create deadlines: %w", err)
}
// Notify users via notification service
if s.notificationService != nil {
go s.notificationService.NotifyConsentRequired(ctx, documentName, versionID.String())
}
return nil
}
// MarkConsentGiven marks a deadline as fulfilled when user gives consent
func (s *DeadlineService) MarkConsentGiven(ctx context.Context, userID, versionID uuid.UUID) error {
_, err := s.pool.Exec(ctx, `
UPDATE consent_deadlines
SET consent_given_at = NOW()
WHERE user_id = $1 AND document_version_id = $2 AND consent_given_at IS NULL
`, userID, versionID)
if err != nil {
return err
}
// Check if user should be unsuspended
return s.checkAndLiftSuspension(ctx, userID)
}
// GetPendingDeadlines returns all pending deadlines for a user
func (s *DeadlineService) GetPendingDeadlines(ctx context.Context, userID uuid.UUID) ([]ConsentDeadline, error) {
rows, err := s.pool.Query(ctx, `
SELECT cd.id, cd.user_id, cd.document_version_id, cd.deadline_at,
cd.reminder_count, cd.last_reminder_at, cd.consent_given_at, cd.created_at,
ld.name as document_name, dv.version as version_number
FROM consent_deadlines cd
JOIN document_versions dv ON cd.document_version_id = dv.id
JOIN legal_documents ld ON dv.document_id = ld.id
WHERE cd.user_id = $1 AND cd.consent_given_at IS NULL
ORDER BY cd.deadline_at ASC
`, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var deadlines []ConsentDeadline
for rows.Next() {
var d ConsentDeadline
if err := rows.Scan(&d.ID, &d.UserID, &d.DocumentVersionID, &d.DeadlineAt,
&d.ReminderCount, &d.LastReminderAt, &d.ConsentGivenAt, &d.CreatedAt,
&d.DocumentName, &d.VersionNumber); err != nil {
continue
}
deadlines = append(deadlines, d)
}
return deadlines, nil
}
// ProcessDailyDeadlines is meant to be called by a cron job daily
// It sends reminders and suspends accounts that have missed deadlines
func (s *DeadlineService) ProcessDailyDeadlines(ctx context.Context) error {
now := time.Now()
// 1. Send reminders for upcoming deadlines
if err := s.sendReminders(ctx, now); err != nil {
fmt.Printf("Error sending reminders: %v\n", err)
}
// 2. Suspend accounts with expired deadlines
if err := s.suspendExpiredAccounts(ctx, now); err != nil {
fmt.Printf("Error suspending accounts: %v\n", err)
}
return nil
}
// sendReminders sends reminder notifications based on days remaining
func (s *DeadlineService) sendReminders(ctx context.Context, now time.Time) error {
// Reminder schedule: Day 7, 14, 21, 28
reminderDays := []int{7, 14, 21, 28}
for _, days := range reminderDays {
targetDate := now.AddDate(0, 0, days)
dayStart := time.Date(targetDate.Year(), targetDate.Month(), targetDate.Day(), 0, 0, 0, 0, targetDate.Location())
dayEnd := dayStart.AddDate(0, 0, 1)
// Find deadlines that fall on this reminder day
rows, err := s.pool.Query(ctx, `
SELECT cd.id, cd.user_id, cd.document_version_id, cd.deadline_at, cd.reminder_count,
ld.name as document_name
FROM consent_deadlines cd
JOIN document_versions dv ON cd.document_version_id = dv.id
JOIN legal_documents ld ON dv.document_id = ld.id
WHERE cd.consent_given_at IS NULL
AND cd.deadline_at >= $1 AND cd.deadline_at < $2
AND (cd.last_reminder_at IS NULL OR cd.last_reminder_at < $3)
`, dayStart, dayEnd, dayStart)
if err != nil {
continue
}
for rows.Next() {
var id, userID, versionID uuid.UUID
var deadlineAt time.Time
var reminderCount int
var documentName string
if err := rows.Scan(&id, &userID, &versionID, &deadlineAt, &reminderCount, &documentName); err != nil {
continue
}
// Send reminder notification
daysLeft := 30 - (30 - days)
urgency := "freundlich"
if days <= 7 {
urgency = "dringend"
} else if days <= 14 {
urgency = "wichtig"
}
title := fmt.Sprintf("Erinnerung: Zustimmung erforderlich (%s)", urgency)
body := fmt.Sprintf("Bitte bestätigen Sie '%s' innerhalb von %d Tagen.", documentName, daysLeft)
if s.notificationService != nil {
s.notificationService.CreateNotification(ctx, userID, NotificationTypeConsentReminder, title, body, map[string]interface{}{
"document_name": documentName,
"days_left": daysLeft,
"version_id": versionID.String(),
})
}
// Update reminder count and timestamp
s.pool.Exec(ctx, `
UPDATE consent_deadlines
SET reminder_count = reminder_count + 1, last_reminder_at = NOW()
WHERE id = $1
`, id)
}
rows.Close()
}
return nil
}
// suspendExpiredAccounts suspends accounts that have missed their deadline
func (s *DeadlineService) suspendExpiredAccounts(ctx context.Context, now time.Time) error {
// Find users with expired deadlines
rows, err := s.pool.Query(ctx, `
SELECT DISTINCT cd.user_id, array_agg(ld.name) as documents
FROM consent_deadlines cd
JOIN document_versions dv ON cd.document_version_id = dv.id
JOIN legal_documents ld ON dv.document_id = ld.id
JOIN users u ON cd.user_id = u.id
WHERE cd.consent_given_at IS NULL
AND cd.deadline_at < $1
AND u.account_status = 'active'
AND ld.is_mandatory = TRUE
GROUP BY cd.user_id
`, now)
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var userID uuid.UUID
var documents []string
if err := rows.Scan(&userID, &documents); err != nil {
continue
}
// Suspend the account
if err := s.suspendAccount(ctx, userID, "consent_deadline_missed", documents); err != nil {
fmt.Printf("Failed to suspend user %s: %v\n", userID, err)
}
}
return nil
}
// suspendAccount suspends a user account
func (s *DeadlineService) suspendAccount(ctx context.Context, userID uuid.UUID, reason string, documents []string) error {
tx, err := s.pool.Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)
// Update user status
_, err = tx.Exec(ctx, `
UPDATE users SET account_status = 'suspended', updated_at = NOW()
WHERE id = $1 AND account_status = 'active'
`, userID)
if err != nil {
return err
}
// Create suspension record
_, err = tx.Exec(ctx, `
INSERT INTO account_suspensions (user_id, reason, details)
VALUES ($1, $2, $3)
`, userID, reason, map[string]interface{}{"documents": documents})
if err != nil {
return err
}
// Log to audit
_, err = tx.Exec(ctx, `
INSERT INTO consent_audit_log (user_id, action, entity_type, entity_id, details)
VALUES ($1, 'account_suspended', 'user', $1, $2)
`, userID, map[string]interface{}{"reason": reason, "documents": documents})
if err != nil {
return err
}
if err := tx.Commit(ctx); err != nil {
return err
}
// Send suspension notification
if s.notificationService != nil {
title := "Account vorübergehend gesperrt"
body := "Ihr Account wurde gesperrt, da ausstehende Zustimmungen nicht innerhalb der Frist erteilt wurden. Bitte bestätigen Sie die ausstehenden Dokumente."
s.notificationService.CreateNotification(ctx, userID, NotificationTypeAccountSuspended, title, body, map[string]interface{}{
"documents": documents,
})
}
return nil
}
// checkAndLiftSuspension checks if user has completed all required consents and lifts suspension
func (s *DeadlineService) checkAndLiftSuspension(ctx context.Context, userID uuid.UUID) error {
// Check if user is currently suspended
var accountStatus string
err := s.pool.QueryRow(ctx, `SELECT account_status FROM users WHERE id = $1`, userID).Scan(&accountStatus)
if err != nil || accountStatus != "suspended" {
return nil
}
// Check if there are any pending mandatory consents
var pendingCount int
err = s.pool.QueryRow(ctx, `
SELECT COUNT(*)
FROM consent_deadlines cd
JOIN document_versions dv ON cd.document_version_id = dv.id
JOIN legal_documents ld ON dv.document_id = ld.id
WHERE cd.user_id = $1
AND cd.consent_given_at IS NULL
AND ld.is_mandatory = TRUE
`, userID).Scan(&pendingCount)
if err != nil {
return err
}
// If no pending consents, lift the suspension
if pendingCount == 0 {
return s.liftSuspension(ctx, userID)
}
return nil
}
// liftSuspension lifts a user's suspension
func (s *DeadlineService) liftSuspension(ctx context.Context, userID uuid.UUID) error {
tx, err := s.pool.Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)
// Update user status
_, err = tx.Exec(ctx, `
UPDATE users SET account_status = 'active', updated_at = NOW()
WHERE id = $1 AND account_status = 'suspended'
`, userID)
if err != nil {
return err
}
// Update suspension record
_, err = tx.Exec(ctx, `
UPDATE account_suspensions
SET lifted_at = NOW()
WHERE user_id = $1 AND lifted_at IS NULL
`, userID)
if err != nil {
return err
}
// Log to audit
_, err = tx.Exec(ctx, `
INSERT INTO consent_audit_log (user_id, action, entity_type, entity_id)
VALUES ($1, 'account_restored', 'user', $1)
`, userID)
if err != nil {
return err
}
if err := tx.Commit(ctx); err != nil {
return err
}
// Send restoration notification
if s.notificationService != nil {
title := "Account wiederhergestellt"
body := "Vielen Dank! Ihr Account wurde wiederhergestellt. Sie können die Anwendung wieder vollständig nutzen."
s.notificationService.CreateNotification(ctx, userID, NotificationTypeAccountRestored, title, body, nil)
}
return nil
}
// GetAccountSuspension returns the current suspension for a user
func (s *DeadlineService) GetAccountSuspension(ctx context.Context, userID uuid.UUID) (*AccountSuspension, error) {
var suspension AccountSuspension
err := s.pool.QueryRow(ctx, `
SELECT id, user_id, reason, details, suspended_at, lifted_at, lifted_by
FROM account_suspensions
WHERE user_id = $1 AND lifted_at IS NULL
ORDER BY suspended_at DESC
LIMIT 1
`, userID).Scan(&suspension.ID, &suspension.UserID, &suspension.Reason, &suspension.Details,
&suspension.SuspendedAt, &suspension.LiftedAt, &suspension.LiftedBy)
if err != nil {
return nil, err
}
return &suspension, nil
}
// IsUserSuspended checks if a user is currently suspended
func (s *DeadlineService) IsUserSuspended(ctx context.Context, userID uuid.UUID) (bool, error) {
var status string
err := s.pool.QueryRow(ctx, `SELECT account_status FROM users WHERE id = $1`, userID).Scan(&status)
if err != nil {
return false, err
}
return status == "suspended", nil
}
@@ -0,0 +1,439 @@
package services
import (
"testing"
"time"
"github.com/google/uuid"
)
// TestDeadlineService_CreateDeadline tests creating consent deadlines
func TestDeadlineService_CreateDeadline(t *testing.T) {
tests := []struct {
name string
userID uuid.UUID
versionID uuid.UUID
deadlineAt time.Time
expectError bool
}{
{
name: "valid deadline - 30 days",
userID: uuid.New(),
versionID: uuid.New(),
deadlineAt: time.Now().AddDate(0, 0, 30),
expectError: false,
},
{
name: "valid deadline - 14 days",
userID: uuid.New(),
versionID: uuid.New(),
deadlineAt: time.Now().AddDate(0, 0, 14),
expectError: false,
},
{
name: "invalid user ID",
userID: uuid.Nil,
versionID: uuid.New(),
deadlineAt: time.Now().AddDate(0, 0, 30),
expectError: true,
},
{
name: "invalid version ID",
userID: uuid.New(),
versionID: uuid.Nil,
deadlineAt: time.Now().AddDate(0, 0, 30),
expectError: true,
},
{
name: "deadline in past",
userID: uuid.New(),
versionID: uuid.New(),
deadlineAt: time.Now().AddDate(0, 0, -1),
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.userID == uuid.Nil {
err = &ValidationError{Field: "user ID", Message: "required"}
} else if tt.versionID == uuid.Nil {
err = &ValidationError{Field: "version ID", Message: "required"}
} else if tt.deadlineAt.Before(time.Now()) {
err = &ValidationError{Field: "deadline", Message: "must be in the future"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestDeadlineService_CheckDeadlineStatus tests deadline status checking
func TestDeadlineService_CheckDeadlineStatus(t *testing.T) {
now := time.Now()
tests := []struct {
name string
deadlineAt time.Time
isOverdue bool
daysLeft int
urgency string
}{
{
name: "30 days left",
deadlineAt: now.AddDate(0, 0, 30),
isOverdue: false,
daysLeft: 30,
urgency: "normal",
},
{
name: "7 days left - warning",
deadlineAt: now.AddDate(0, 0, 7),
isOverdue: false,
daysLeft: 7,
urgency: "warning",
},
{
name: "3 days left - urgent",
deadlineAt: now.AddDate(0, 0, 3),
isOverdue: false,
daysLeft: 3,
urgency: "urgent",
},
{
name: "1 day left - critical",
deadlineAt: now.AddDate(0, 0, 1),
isOverdue: false,
daysLeft: 1,
urgency: "critical",
},
{
name: "overdue by 1 day",
deadlineAt: now.AddDate(0, 0, -1),
isOverdue: true,
daysLeft: -1,
urgency: "overdue",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isOverdue := tt.deadlineAt.Before(now)
daysLeft := int(tt.deadlineAt.Sub(now).Hours() / 24)
var urgency string
if isOverdue {
urgency = "overdue"
} else if daysLeft <= 1 {
urgency = "critical"
} else if daysLeft <= 3 {
urgency = "urgent"
} else if daysLeft <= 7 {
urgency = "warning"
} else {
urgency = "normal"
}
if isOverdue != tt.isOverdue {
t.Errorf("Expected isOverdue=%v, got %v", tt.isOverdue, isOverdue)
}
if abs(daysLeft-tt.daysLeft) > 1 { // Allow 1 day difference
t.Errorf("Expected daysLeft=%d, got %d", tt.daysLeft, daysLeft)
}
if urgency != tt.urgency {
t.Errorf("Expected urgency=%s, got %s", tt.urgency, urgency)
}
})
}
}
// TestDeadlineService_SendReminders tests reminder scheduling
func TestDeadlineService_SendReminders(t *testing.T) {
now := time.Now()
tests := []struct {
name string
deadlineAt time.Time
lastReminderAt *time.Time
reminderCount int
shouldSend bool
nextReminder int // days before deadline
}{
{
name: "first reminder - 14 days before",
deadlineAt: now.AddDate(0, 0, 14),
lastReminderAt: nil,
reminderCount: 0,
shouldSend: true,
nextReminder: 14,
},
{
name: "second reminder - 7 days before",
deadlineAt: now.AddDate(0, 0, 7),
lastReminderAt: ptrTime(now.AddDate(0, 0, -7)),
reminderCount: 1,
shouldSend: true,
nextReminder: 7,
},
{
name: "third reminder - 3 days before",
deadlineAt: now.AddDate(0, 0, 3),
lastReminderAt: ptrTime(now.AddDate(0, 0, -4)),
reminderCount: 2,
shouldSend: true,
nextReminder: 3,
},
{
name: "final reminder - 1 day before",
deadlineAt: now.AddDate(0, 0, 1),
lastReminderAt: ptrTime(now.AddDate(0, 0, -2)),
reminderCount: 3,
shouldSend: true,
nextReminder: 1,
},
{
name: "too soon for next reminder",
deadlineAt: now.AddDate(0, 0, 10),
lastReminderAt: ptrTime(now.AddDate(0, 0, -1)),
reminderCount: 1,
shouldSend: false,
nextReminder: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
daysUntilDeadline := int(tt.deadlineAt.Sub(now).Hours() / 24)
// Reminder schedule: 14, 7, 3, 1 days before deadline
reminderDays := []int{14, 7, 3, 1}
shouldSend := false
for _, day := range reminderDays {
if daysUntilDeadline == day {
// Check if enough time passed since last reminder
if tt.lastReminderAt == nil || now.Sub(*tt.lastReminderAt) > 12*time.Hour {
shouldSend = true
break
}
}
}
if shouldSend != tt.shouldSend {
t.Errorf("Expected shouldSend=%v, got %v", tt.shouldSend, shouldSend)
}
})
}
}
// TestDeadlineService_SuspendAccount tests account suspension logic
func TestDeadlineService_SuspendAccount(t *testing.T) {
tests := []struct {
name string
userID uuid.UUID
reason string
shouldSuspend bool
expectError bool
}{
{
name: "suspend for missed deadline",
userID: uuid.New(),
reason: "consent_deadline_exceeded",
shouldSuspend: true,
expectError: false,
},
{
name: "invalid user ID",
userID: uuid.Nil,
reason: "consent_deadline_exceeded",
shouldSuspend: false,
expectError: true,
},
{
name: "invalid reason",
userID: uuid.New(),
reason: "",
shouldSuspend: false,
expectError: true,
},
}
validReasons := map[string]bool{
"consent_deadline_exceeded": true,
"mandatory_consent_missing": true,
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.userID == uuid.Nil {
err = &ValidationError{Field: "user ID", Message: "required"}
} else if !validReasons[tt.reason] && tt.reason != "" {
err = &ValidationError{Field: "reason", Message: "invalid suspension reason"}
} else if tt.reason == "" {
err = &ValidationError{Field: "reason", Message: "required"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestDeadlineService_LiftSuspension tests lifting account suspension
func TestDeadlineService_LiftSuspension(t *testing.T) {
tests := []struct {
name string
userID uuid.UUID
adminID uuid.UUID
reason string
expectError bool
}{
{
name: "lift valid suspension",
userID: uuid.New(),
adminID: uuid.New(),
reason: "consent provided",
expectError: false,
},
{
name: "invalid user ID",
userID: uuid.Nil,
adminID: uuid.New(),
reason: "consent provided",
expectError: true,
},
{
name: "invalid admin ID",
userID: uuid.New(),
adminID: uuid.Nil,
reason: "consent provided",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.userID == uuid.Nil {
err = &ValidationError{Field: "user ID", Message: "required"}
} else if tt.adminID == uuid.Nil {
err = &ValidationError{Field: "admin ID", Message: "required"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestDeadlineService_GetOverdueDeadlines tests finding overdue deadlines
func TestDeadlineService_GetOverdueDeadlines(t *testing.T) {
now := time.Now()
tests := []struct {
name string
deadlines []time.Time
expected int // number of overdue
}{
{
name: "no overdue deadlines",
deadlines: []time.Time{
now.AddDate(0, 0, 1),
now.AddDate(0, 0, 7),
now.AddDate(0, 0, 30),
},
expected: 0,
},
{
name: "some overdue",
deadlines: []time.Time{
now.AddDate(0, 0, -1),
now.AddDate(0, 0, -5),
now.AddDate(0, 0, 7),
},
expected: 2,
},
{
name: "all overdue",
deadlines: []time.Time{
now.AddDate(0, 0, -1),
now.AddDate(0, 0, -7),
now.AddDate(0, 0, -30),
},
expected: 3,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
overdueCount := 0
for _, deadline := range tt.deadlines {
if deadline.Before(now) {
overdueCount++
}
}
if overdueCount != tt.expected {
t.Errorf("Expected %d overdue, got %d", tt.expected, overdueCount)
}
})
}
}
// TestDeadlineService_ProcessScheduledTasks tests scheduled task processing
func TestDeadlineService_ProcessScheduledTasks(t *testing.T) {
now := time.Now()
tests := []struct {
name string
task string
scheduledAt time.Time
shouldProcess bool
}{
{
name: "process due task",
task: "send_reminder",
scheduledAt: now.Add(-1 * time.Hour),
shouldProcess: true,
},
{
name: "skip future task",
task: "send_reminder",
scheduledAt: now.Add(1 * time.Hour),
shouldProcess: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
shouldProcess := tt.scheduledAt.Before(now) || tt.scheduledAt.Equal(now)
if shouldProcess != tt.shouldProcess {
t.Errorf("Expected shouldProcess=%v, got %v", tt.shouldProcess, shouldProcess)
}
})
}
}
// Helper functions
func ptrTime(t time.Time) *time.Time {
return &t
}
@@ -0,0 +1,728 @@
package services
import (
"regexp"
"testing"
"time"
"github.com/google/uuid"
)
// TestDocumentService_CreateDocument tests creating a new legal document
func TestDocumentService_CreateDocument(t *testing.T) {
tests := []struct {
name string
docType string
docName string
description string
isMandatory bool
expectError bool
errorContains string
}{
{
name: "valid mandatory document",
docType: "terms",
docName: "Terms of Service",
description: "Our terms and conditions",
isMandatory: true,
expectError: false,
},
{
name: "valid optional document",
docType: "cookies",
docName: "Cookie Policy",
description: "How we use cookies",
isMandatory: false,
expectError: false,
},
{
name: "empty document type",
docType: "",
docName: "Test Document",
description: "Test",
isMandatory: true,
expectError: true,
errorContains: "type",
},
{
name: "empty document name",
docType: "privacy",
docName: "",
description: "Test",
isMandatory: true,
expectError: true,
errorContains: "name",
},
{
name: "invalid document type",
docType: "invalid_type",
docName: "Test",
description: "Test",
isMandatory: false,
expectError: true,
errorContains: "type",
},
}
validTypes := map[string]bool{
"terms": true,
"privacy": true,
"cookies": true,
"community_guidelines": true,
"imprint": true,
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Validate inputs
var err error
if tt.docType == "" {
err = &ValidationError{Field: "type", Message: "required"}
} else if !validTypes[tt.docType] {
err = &ValidationError{Field: "type", Message: "invalid document type"}
} else if tt.docName == "" {
err = &ValidationError{Field: "name", Message: "required"}
}
// Assert
if tt.expectError {
if err == nil {
t.Errorf("Expected error containing '%s', got nil", tt.errorContains)
}
} else {
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
}
})
}
}
// TestDocumentService_UpdateDocument tests updating a document
func TestDocumentService_UpdateDocument(t *testing.T) {
tests := []struct {
name string
documentID uuid.UUID
newName string
newActive bool
expectError bool
}{
{
name: "valid update",
documentID: uuid.New(),
newName: "Updated Name",
newActive: true,
expectError: false,
},
{
name: "deactivate document",
documentID: uuid.New(),
newName: "Test",
newActive: false,
expectError: false,
},
{
name: "invalid document ID",
documentID: uuid.Nil,
newName: "Test",
newActive: true,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.documentID == uuid.Nil {
err = &ValidationError{Field: "document ID", Message: "required"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestDocumentService_CreateVersion tests creating a document version
func TestDocumentService_CreateVersion(t *testing.T) {
tests := []struct {
name string
documentID uuid.UUID
version string
language string
title string
content string
expectError bool
errorContains string
}{
{
name: "valid version - German",
documentID: uuid.New(),
version: "1.0.0",
language: "de",
title: "Nutzungsbedingungen",
content: "<h1>Terms</h1><p>Content...</p>",
expectError: false,
},
{
name: "valid version - English",
documentID: uuid.New(),
version: "1.0.0",
language: "en",
title: "Terms of Service",
content: "<h1>Terms</h1><p>Content...</p>",
expectError: false,
},
{
name: "invalid version format",
documentID: uuid.New(),
version: "1.0",
language: "de",
title: "Test",
content: "Content",
expectError: true,
errorContains: "version",
},
{
name: "invalid language",
documentID: uuid.New(),
version: "1.0.0",
language: "fr",
title: "Test",
content: "Content",
expectError: true,
errorContains: "language",
},
{
name: "empty title",
documentID: uuid.New(),
version: "1.0.0",
language: "de",
title: "",
content: "Content",
expectError: true,
errorContains: "title",
},
{
name: "empty content",
documentID: uuid.New(),
version: "1.0.0",
language: "de",
title: "Test",
content: "",
expectError: true,
errorContains: "content",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Validate semver format (X.Y.Z pattern)
validVersion := regexp.MustCompile(`^\d+\.\d+\.\d+$`).MatchString(tt.version)
validLanguage := tt.language == "de" || tt.language == "en"
var err error
if !validVersion {
err = &ValidationError{Field: "version", Message: "invalid format"}
} else if !validLanguage {
err = &ValidationError{Field: "language", Message: "must be 'de' or 'en'"}
} else if tt.title == "" {
err = &ValidationError{Field: "title", Message: "required"}
} else if tt.content == "" {
err = &ValidationError{Field: "content", Message: "required"}
}
if tt.expectError {
if err == nil {
t.Errorf("Expected error containing '%s', got nil", tt.errorContains)
}
} else {
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
}
})
}
}
// TestDocumentService_VersionStatusTransitions tests version status workflow
func TestDocumentService_VersionStatusTransitions(t *testing.T) {
tests := []struct {
name string
fromStatus string
toStatus string
isAllowed bool
}{
// Valid transitions
{"draft to review", "draft", "review", true},
{"review to approved", "review", "approved", true},
{"review to rejected", "review", "rejected", true},
{"approved to published", "approved", "published", true},
{"approved to scheduled", "approved", "scheduled", true},
{"scheduled to published", "scheduled", "published", true},
{"published to archived", "published", "archived", true},
{"rejected to draft", "rejected", "draft", true},
// Invalid transitions
{"draft to published", "draft", "published", false},
{"draft to approved", "draft", "approved", false},
{"review to published", "review", "published", false},
{"published to draft", "published", "draft", false},
{"published to review", "published", "review", false},
{"archived to draft", "archived", "draft", false},
{"archived to published", "archived", "published", false},
}
// Define valid transitions
validTransitions := map[string][]string{
"draft": {"review"},
"review": {"approved", "rejected"},
"approved": {"published", "scheduled"},
"scheduled": {"published"},
"published": {"archived"},
"rejected": {"draft"},
"archived": {}, // terminal state
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Check if transition is allowed
allowed := false
if transitions, ok := validTransitions[tt.fromStatus]; ok {
for _, validTo := range transitions {
if validTo == tt.toStatus {
allowed = true
break
}
}
}
if allowed != tt.isAllowed {
t.Errorf("Transition %s->%s: expected allowed=%v, got %v",
tt.fromStatus, tt.toStatus, tt.isAllowed, allowed)
}
})
}
}
// TestDocumentService_PublishVersion tests publishing a version
func TestDocumentService_PublishVersion(t *testing.T) {
tests := []struct {
name string
versionID uuid.UUID
currentStatus string
expectError bool
errorContains string
}{
{
name: "publish approved version",
versionID: uuid.New(),
currentStatus: "approved",
expectError: false,
},
{
name: "publish scheduled version",
versionID: uuid.New(),
currentStatus: "scheduled",
expectError: false,
},
{
name: "cannot publish draft",
versionID: uuid.New(),
currentStatus: "draft",
expectError: true,
errorContains: "draft",
},
{
name: "cannot publish review",
versionID: uuid.New(),
currentStatus: "review",
expectError: true,
errorContains: "review",
},
{
name: "invalid version ID",
versionID: uuid.Nil,
currentStatus: "approved",
expectError: true,
errorContains: "ID",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.versionID == uuid.Nil {
err = &ValidationError{Field: "version ID", Message: "required"}
} else if tt.currentStatus != "approved" && tt.currentStatus != "scheduled" {
err = &ValidationError{Field: "status", Message: "only approved or scheduled versions can be published"}
}
if tt.expectError {
if err == nil {
t.Errorf("Expected error containing '%s', got nil", tt.errorContains)
}
} else {
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
}
})
}
}
// TestDocumentService_ArchiveVersion tests archiving a version
func TestDocumentService_ArchiveVersion(t *testing.T) {
tests := []struct {
name string
versionID uuid.UUID
expectError bool
}{
{
name: "archive valid version",
versionID: uuid.New(),
expectError: false,
},
{
name: "invalid version ID",
versionID: uuid.Nil,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.versionID == uuid.Nil {
err = &ValidationError{Field: "version ID", Message: "required"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestDocumentService_DeleteVersion tests deleting a version
func TestDocumentService_DeleteVersion(t *testing.T) {
tests := []struct {
name string
versionID uuid.UUID
status string
canDelete bool
expectError bool
}{
{
name: "delete draft version",
versionID: uuid.New(),
status: "draft",
canDelete: true,
expectError: false,
},
{
name: "delete rejected version",
versionID: uuid.New(),
status: "rejected",
canDelete: true,
expectError: false,
},
{
name: "cannot delete published version",
versionID: uuid.New(),
status: "published",
canDelete: false,
expectError: true,
},
{
name: "cannot delete approved version",
versionID: uuid.New(),
status: "approved",
canDelete: false,
expectError: true,
},
{
name: "cannot delete archived version",
versionID: uuid.New(),
status: "archived",
canDelete: false,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Only draft and rejected can be deleted
canDelete := tt.status == "draft" || tt.status == "rejected"
var err error
if !canDelete {
err = &ValidationError{Field: "status", Message: "only draft or rejected versions can be deleted"}
}
if tt.expectError {
if err == nil {
t.Error("Expected error, got nil")
}
} else {
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
}
if canDelete != tt.canDelete {
t.Errorf("Expected canDelete=%v, got %v", tt.canDelete, canDelete)
}
})
}
}
// TestDocumentService_GetLatestVersion tests retrieving the latest version
func TestDocumentService_GetLatestVersion(t *testing.T) {
tests := []struct {
name string
documentID uuid.UUID
language string
status string
expectError bool
}{
{
name: "get latest German version",
documentID: uuid.New(),
language: "de",
status: "published",
expectError: false,
},
{
name: "get latest English version",
documentID: uuid.New(),
language: "en",
status: "published",
expectError: false,
},
{
name: "invalid document ID",
documentID: uuid.Nil,
language: "de",
status: "published",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.documentID == uuid.Nil {
err = &ValidationError{Field: "document ID", Message: "required"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestDocumentService_CompareVersions tests version comparison
func TestDocumentService_CompareVersions(t *testing.T) {
tests := []struct {
name string
version1 string
version2 string
isDifferent bool
}{
{
name: "same version",
version1: "1.0.0",
version2: "1.0.0",
isDifferent: false,
},
{
name: "different major version",
version1: "2.0.0",
version2: "1.0.0",
isDifferent: true,
},
{
name: "different minor version",
version1: "1.1.0",
version2: "1.0.0",
isDifferent: true,
},
{
name: "different patch version",
version1: "1.0.1",
version2: "1.0.0",
isDifferent: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isDifferent := tt.version1 != tt.version2
if isDifferent != tt.isDifferent {
t.Errorf("Expected isDifferent=%v, got %v", tt.isDifferent, isDifferent)
}
})
}
}
// TestDocumentService_ScheduledPublishing tests scheduled publishing
func TestDocumentService_ScheduledPublishing(t *testing.T) {
now := time.Now()
tests := []struct {
name string
scheduledAt time.Time
shouldPublish bool
}{
{
name: "scheduled for past - should publish",
scheduledAt: now.Add(-1 * time.Hour),
shouldPublish: true,
},
{
name: "scheduled for now - should publish",
scheduledAt: now,
shouldPublish: true,
},
{
name: "scheduled for future - should not publish",
scheduledAt: now.Add(1 * time.Hour),
shouldPublish: false,
},
{
name: "scheduled for tomorrow - should not publish",
scheduledAt: now.AddDate(0, 0, 1),
shouldPublish: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
shouldPublish := tt.scheduledAt.Before(now) || tt.scheduledAt.Equal(now)
if shouldPublish != tt.shouldPublish {
t.Errorf("Expected shouldPublish=%v, got %v", tt.shouldPublish, shouldPublish)
}
})
}
}
// TestDocumentService_ApprovalWorkflow tests the approval workflow
func TestDocumentService_ApprovalWorkflow(t *testing.T) {
tests := []struct {
name string
action string
userRole string
isAllowed bool
}{
// Admin permissions
{"admin submit for review", "submit_review", "admin", true},
{"admin cannot approve", "approve", "admin", false},
{"admin can publish", "publish", "admin", true},
// DSB permissions
{"dsb can approve", "approve", "data_protection_officer", true},
{"dsb can reject", "reject", "data_protection_officer", true},
{"dsb can publish", "publish", "data_protection_officer", true},
// User permissions
{"user cannot submit", "submit_review", "user", false},
{"user cannot approve", "approve", "user", false},
{"user cannot publish", "publish", "user", false},
}
permissions := map[string]map[string]bool{
"admin": {
"submit_review": true,
"approve": false,
"reject": false,
"publish": true,
},
"data_protection_officer": {
"submit_review": true,
"approve": true,
"reject": true,
"publish": true,
},
"user": {
"submit_review": false,
"approve": false,
"reject": false,
"publish": false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rolePerms, ok := permissions[tt.userRole]
if !ok {
t.Fatalf("Unknown role: %s", tt.userRole)
}
isAllowed := rolePerms[tt.action]
if isAllowed != tt.isAllowed {
t.Errorf("Role %s action %s: expected allowed=%v, got %v",
tt.userRole, tt.action, tt.isAllowed, isAllowed)
}
})
}
}
// TestDocumentService_FourEyesPrinciple tests the four-eyes principle
func TestDocumentService_FourEyesPrinciple(t *testing.T) {
tests := []struct {
name string
createdBy uuid.UUID
approver uuid.UUID
approverRole string
canApprove bool
}{
{
name: "different users - DSB can approve",
createdBy: uuid.New(),
approver: uuid.New(),
approverRole: "data_protection_officer",
canApprove: true,
},
{
name: "same user - DSB cannot approve own",
createdBy: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"),
approver: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"),
approverRole: "data_protection_officer",
canApprove: false,
},
{
name: "same user - admin CAN approve own (exception)",
createdBy: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"),
approver: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"),
approverRole: "admin",
canApprove: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Four-eyes principle: DSB cannot approve their own work
// Exception: Admins can (for development/testing)
canApprove := tt.createdBy != tt.approver || tt.approverRole == "admin"
if canApprove != tt.canApprove {
t.Errorf("Expected canApprove=%v, got %v", tt.canApprove, canApprove)
}
})
}
}
@@ -0,0 +1,947 @@
package services
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/breakpilot/consent-service/internal/models"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
)
// DSRService handles Data Subject Request business logic
type DSRService struct {
pool *pgxpool.Pool
notificationService *NotificationService
emailService *EmailService
}
// NewDSRService creates a new DSRService
func NewDSRService(pool *pgxpool.Pool, notificationService *NotificationService, emailService *EmailService) *DSRService {
return &DSRService{
pool: pool,
notificationService: notificationService,
emailService: emailService,
}
}
// GetPool returns the database pool for direct queries
func (s *DSRService) GetPool() *pgxpool.Pool {
return s.pool
}
// generateRequestNumber generates a unique request number like DSR-2025-000001
func (s *DSRService) generateRequestNumber(ctx context.Context) (string, error) {
var seqNum int64
err := s.pool.QueryRow(ctx, "SELECT nextval('dsr_request_number_seq')").Scan(&seqNum)
if err != nil {
return "", fmt.Errorf("failed to get next sequence number: %w", err)
}
year := time.Now().Year()
return fmt.Sprintf("DSR-%d-%06d", year, seqNum), nil
}
// CreateRequest creates a new data subject request
func (s *DSRService) CreateRequest(ctx context.Context, req models.CreateDSRRequest, createdBy *uuid.UUID) (*models.DataSubjectRequest, error) {
// Validate request type
requestType := models.DSRRequestType(req.RequestType)
if !isValidRequestType(requestType) {
return nil, fmt.Errorf("invalid request type: %s", req.RequestType)
}
// Generate request number
requestNumber, err := s.generateRequestNumber(ctx)
if err != nil {
return nil, err
}
// Calculate deadline
deadlineDays := requestType.DeadlineDays()
deadline := time.Now().AddDate(0, 0, deadlineDays)
// Determine priority
priority := models.DSRPriorityNormal
if req.Priority != "" {
priority = models.DSRPriority(req.Priority)
} else if requestType.IsExpedited() {
priority = models.DSRPriorityExpedited
}
// Determine source
source := models.DSRSourceAPI
if req.Source != "" {
source = models.DSRSource(req.Source)
}
// Serialize request details
detailsJSON, err := json.Marshal(req.RequestDetails)
if err != nil {
detailsJSON = []byte("{}")
}
// Try to find existing user by email
var userID *uuid.UUID
var foundUserID uuid.UUID
err = s.pool.QueryRow(ctx, "SELECT id FROM users WHERE email = $1", req.RequesterEmail).Scan(&foundUserID)
if err == nil {
userID = &foundUserID
}
// Insert request
var dsr models.DataSubjectRequest
err = s.pool.QueryRow(ctx, `
INSERT INTO data_subject_requests (
user_id, request_number, request_type, status, priority, source,
requester_email, requester_name, requester_phone,
request_details, deadline_at, legal_deadline_days, created_by
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
RETURNING id, user_id, request_number, request_type, status, priority, source,
requester_email, requester_name, requester_phone, identity_verified,
request_details, deadline_at, legal_deadline_days, created_at, updated_at, created_by
`, userID, requestNumber, requestType, models.DSRStatusIntake, priority, source,
req.RequesterEmail, req.RequesterName, req.RequesterPhone,
detailsJSON, deadline, deadlineDays, createdBy,
).Scan(
&dsr.ID, &dsr.UserID, &dsr.RequestNumber, &dsr.RequestType, &dsr.Status,
&dsr.Priority, &dsr.Source, &dsr.RequesterEmail, &dsr.RequesterName,
&dsr.RequesterPhone, &dsr.IdentityVerified, &detailsJSON,
&dsr.DeadlineAt, &dsr.LegalDeadlineDays, &dsr.CreatedAt, &dsr.UpdatedAt, &dsr.CreatedBy,
)
if err != nil {
return nil, fmt.Errorf("failed to create DSR: %w", err)
}
// Parse details back
json.Unmarshal(detailsJSON, &dsr.RequestDetails)
// Record initial status
s.recordStatusChange(ctx, dsr.ID, nil, models.DSRStatusIntake, createdBy, "Anfrage eingegangen")
// Notify DPOs about new request
go s.notifyNewRequest(context.Background(), &dsr)
return &dsr, nil
}
// GetByID retrieves a DSR by ID
func (s *DSRService) GetByID(ctx context.Context, id uuid.UUID) (*models.DataSubjectRequest, error) {
var dsr models.DataSubjectRequest
var detailsJSON, resultDataJSON []byte
err := s.pool.QueryRow(ctx, `
SELECT id, user_id, request_number, request_type, status, priority, source,
requester_email, requester_name, requester_phone,
identity_verified, identity_verified_at, identity_verified_by, identity_verification_method,
request_details, deadline_at, legal_deadline_days, extended_deadline_at, extension_reason,
assigned_to, processing_notes, completed_at, completed_by, result_summary, result_data,
rejected_at, rejected_by, rejection_reason, rejection_legal_basis,
created_at, updated_at, created_by
FROM data_subject_requests WHERE id = $1
`, id).Scan(
&dsr.ID, &dsr.UserID, &dsr.RequestNumber, &dsr.RequestType, &dsr.Status,
&dsr.Priority, &dsr.Source, &dsr.RequesterEmail, &dsr.RequesterName,
&dsr.RequesterPhone, &dsr.IdentityVerified, &dsr.IdentityVerifiedAt,
&dsr.IdentityVerifiedBy, &dsr.IdentityVerificationMethod,
&detailsJSON, &dsr.DeadlineAt, &dsr.LegalDeadlineDays,
&dsr.ExtendedDeadlineAt, &dsr.ExtensionReason, &dsr.AssignedTo,
&dsr.ProcessingNotes, &dsr.CompletedAt, &dsr.CompletedBy,
&dsr.ResultSummary, &resultDataJSON, &dsr.RejectedAt, &dsr.RejectedBy,
&dsr.RejectionReason, &dsr.RejectionLegalBasis,
&dsr.CreatedAt, &dsr.UpdatedAt, &dsr.CreatedBy,
)
if err != nil {
return nil, fmt.Errorf("DSR not found: %w", err)
}
json.Unmarshal(detailsJSON, &dsr.RequestDetails)
json.Unmarshal(resultDataJSON, &dsr.ResultData)
return &dsr, nil
}
// GetByNumber retrieves a DSR by request number
func (s *DSRService) GetByNumber(ctx context.Context, requestNumber string) (*models.DataSubjectRequest, error) {
var id uuid.UUID
err := s.pool.QueryRow(ctx, "SELECT id FROM data_subject_requests WHERE request_number = $1", requestNumber).Scan(&id)
if err != nil {
return nil, fmt.Errorf("DSR not found: %w", err)
}
return s.GetByID(ctx, id)
}
// List retrieves DSRs with filters and pagination
func (s *DSRService) List(ctx context.Context, filters models.DSRListFilters, limit, offset int) ([]models.DataSubjectRequest, int, error) {
// Build query
baseQuery := "FROM data_subject_requests WHERE 1=1"
args := []interface{}{}
argIndex := 1
if filters.Status != nil && *filters.Status != "" {
baseQuery += fmt.Sprintf(" AND status = $%d", argIndex)
args = append(args, *filters.Status)
argIndex++
}
if filters.RequestType != nil && *filters.RequestType != "" {
baseQuery += fmt.Sprintf(" AND request_type = $%d", argIndex)
args = append(args, *filters.RequestType)
argIndex++
}
if filters.AssignedTo != nil && *filters.AssignedTo != "" {
baseQuery += fmt.Sprintf(" AND assigned_to = $%d", argIndex)
args = append(args, *filters.AssignedTo)
argIndex++
}
if filters.Priority != nil && *filters.Priority != "" {
baseQuery += fmt.Sprintf(" AND priority = $%d", argIndex)
args = append(args, *filters.Priority)
argIndex++
}
if filters.OverdueOnly {
baseQuery += " AND deadline_at < NOW() AND status NOT IN ('completed', 'rejected', 'cancelled')"
}
if filters.FromDate != nil {
baseQuery += fmt.Sprintf(" AND created_at >= $%d", argIndex)
args = append(args, *filters.FromDate)
argIndex++
}
if filters.ToDate != nil {
baseQuery += fmt.Sprintf(" AND created_at <= $%d", argIndex)
args = append(args, *filters.ToDate)
argIndex++
}
if filters.Search != nil && *filters.Search != "" {
searchPattern := "%" + *filters.Search + "%"
baseQuery += fmt.Sprintf(" AND (request_number ILIKE $%d OR requester_email ILIKE $%d OR requester_name ILIKE $%d)", argIndex, argIndex, argIndex)
args = append(args, searchPattern)
argIndex++
}
// Get total count
var total int
err := s.pool.QueryRow(ctx, "SELECT COUNT(*) "+baseQuery, args...).Scan(&total)
if err != nil {
return nil, 0, fmt.Errorf("failed to count DSRs: %w", err)
}
// Get paginated results
query := fmt.Sprintf(`
SELECT id, user_id, request_number, request_type, status, priority, source,
requester_email, requester_name, requester_phone, identity_verified,
deadline_at, legal_deadline_days, assigned_to, created_at, updated_at
%s ORDER BY created_at DESC LIMIT $%d OFFSET $%d
`, baseQuery, argIndex, argIndex+1)
args = append(args, limit, offset)
rows, err := s.pool.Query(ctx, query, args...)
if err != nil {
return nil, 0, fmt.Errorf("failed to query DSRs: %w", err)
}
defer rows.Close()
var dsrs []models.DataSubjectRequest
for rows.Next() {
var dsr models.DataSubjectRequest
err := rows.Scan(
&dsr.ID, &dsr.UserID, &dsr.RequestNumber, &dsr.RequestType, &dsr.Status,
&dsr.Priority, &dsr.Source, &dsr.RequesterEmail, &dsr.RequesterName,
&dsr.RequesterPhone, &dsr.IdentityVerified, &dsr.DeadlineAt,
&dsr.LegalDeadlineDays, &dsr.AssignedTo, &dsr.CreatedAt, &dsr.UpdatedAt,
)
if err != nil {
return nil, 0, fmt.Errorf("failed to scan DSR: %w", err)
}
dsrs = append(dsrs, dsr)
}
return dsrs, total, nil
}
// ListByUser retrieves DSRs for a specific user
func (s *DSRService) ListByUser(ctx context.Context, userID uuid.UUID) ([]models.DataSubjectRequest, error) {
rows, err := s.pool.Query(ctx, `
SELECT id, user_id, request_number, request_type, status, priority, source,
requester_email, requester_name, deadline_at, created_at, updated_at
FROM data_subject_requests
WHERE user_id = $1 OR requester_email = (SELECT email FROM users WHERE id = $1)
ORDER BY created_at DESC
`, userID)
if err != nil {
return nil, fmt.Errorf("failed to query user DSRs: %w", err)
}
defer rows.Close()
var dsrs []models.DataSubjectRequest
for rows.Next() {
var dsr models.DataSubjectRequest
err := rows.Scan(
&dsr.ID, &dsr.UserID, &dsr.RequestNumber, &dsr.RequestType, &dsr.Status,
&dsr.Priority, &dsr.Source, &dsr.RequesterEmail, &dsr.RequesterName,
&dsr.DeadlineAt, &dsr.CreatedAt, &dsr.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan DSR: %w", err)
}
dsrs = append(dsrs, dsr)
}
return dsrs, nil
}
// UpdateStatus changes the status of a DSR
func (s *DSRService) UpdateStatus(ctx context.Context, id uuid.UUID, newStatus models.DSRStatus, comment string, changedBy *uuid.UUID) error {
// Get current status
var currentStatus models.DSRStatus
err := s.pool.QueryRow(ctx, "SELECT status FROM data_subject_requests WHERE id = $1", id).Scan(&currentStatus)
if err != nil {
return fmt.Errorf("DSR not found: %w", err)
}
// Validate transition
if !isValidStatusTransition(currentStatus, newStatus) {
return fmt.Errorf("invalid status transition from %s to %s", currentStatus, newStatus)
}
// Update status
_, err = s.pool.Exec(ctx, `
UPDATE data_subject_requests SET status = $1, updated_at = NOW() WHERE id = $2
`, newStatus, id)
if err != nil {
return fmt.Errorf("failed to update status: %w", err)
}
// Record status change
s.recordStatusChange(ctx, id, &currentStatus, newStatus, changedBy, comment)
return nil
}
// VerifyIdentity marks identity as verified
func (s *DSRService) VerifyIdentity(ctx context.Context, id uuid.UUID, method string, verifiedBy uuid.UUID) error {
_, err := s.pool.Exec(ctx, `
UPDATE data_subject_requests
SET identity_verified = TRUE,
identity_verified_at = NOW(),
identity_verified_by = $1,
identity_verification_method = $2,
status = CASE WHEN status = 'intake' THEN 'identity_verification' ELSE status END,
updated_at = NOW()
WHERE id = $3
`, verifiedBy, method, id)
if err != nil {
return fmt.Errorf("failed to verify identity: %w", err)
}
s.recordStatusChange(ctx, id, nil, models.DSRStatusIdentityVerification, &verifiedBy, "Identität verifiziert via "+method)
return nil
}
// AssignRequest assigns a DSR to a handler
func (s *DSRService) AssignRequest(ctx context.Context, id uuid.UUID, assigneeID uuid.UUID, assignedBy uuid.UUID) error {
_, err := s.pool.Exec(ctx, `
UPDATE data_subject_requests SET assigned_to = $1, updated_at = NOW() WHERE id = $2
`, assigneeID, id)
if err != nil {
return fmt.Errorf("failed to assign DSR: %w", err)
}
// Get assignee name for comment
var assigneeName string
s.pool.QueryRow(ctx, "SELECT COALESCE(name, email) FROM users WHERE id = $1", assigneeID).Scan(&assigneeName)
s.recordStatusChange(ctx, id, nil, "", &assignedBy, "Zugewiesen an "+assigneeName)
// Notify assignee
go s.notifyAssignment(context.Background(), id, assigneeID)
return nil
}
// ExtendDeadline extends the deadline for a DSR
func (s *DSRService) ExtendDeadline(ctx context.Context, id uuid.UUID, reason string, days int, extendedBy uuid.UUID) error {
// Default extension is 2 months (60 days) per Art. 12(3)
if days <= 0 {
days = 60
}
_, err := s.pool.Exec(ctx, `
UPDATE data_subject_requests
SET extended_deadline_at = deadline_at + ($1 || ' days')::INTERVAL,
extension_reason = $2,
updated_at = NOW()
WHERE id = $3
`, days, reason, id)
if err != nil {
return fmt.Errorf("failed to extend deadline: %w", err)
}
s.recordStatusChange(ctx, id, nil, "", &extendedBy, fmt.Sprintf("Frist um %d Tage verlängert: %s", days, reason))
return nil
}
// CompleteRequest marks a DSR as completed
func (s *DSRService) CompleteRequest(ctx context.Context, id uuid.UUID, summary string, resultData map[string]interface{}, completedBy uuid.UUID) error {
resultJSON, _ := json.Marshal(resultData)
// Get current status
var currentStatus models.DSRStatus
s.pool.QueryRow(ctx, "SELECT status FROM data_subject_requests WHERE id = $1", id).Scan(&currentStatus)
_, err := s.pool.Exec(ctx, `
UPDATE data_subject_requests
SET status = 'completed',
completed_at = NOW(),
completed_by = $1,
result_summary = $2,
result_data = $3,
updated_at = NOW()
WHERE id = $4
`, completedBy, summary, resultJSON, id)
if err != nil {
return fmt.Errorf("failed to complete DSR: %w", err)
}
s.recordStatusChange(ctx, id, &currentStatus, models.DSRStatusCompleted, &completedBy, summary)
return nil
}
// RejectRequest rejects a DSR with legal basis
func (s *DSRService) RejectRequest(ctx context.Context, id uuid.UUID, reason, legalBasis string, rejectedBy uuid.UUID) error {
// Get current status
var currentStatus models.DSRStatus
s.pool.QueryRow(ctx, "SELECT status FROM data_subject_requests WHERE id = $1", id).Scan(&currentStatus)
_, err := s.pool.Exec(ctx, `
UPDATE data_subject_requests
SET status = 'rejected',
rejected_at = NOW(),
rejected_by = $1,
rejection_reason = $2,
rejection_legal_basis = $3,
updated_at = NOW()
WHERE id = $4
`, rejectedBy, reason, legalBasis, id)
if err != nil {
return fmt.Errorf("failed to reject DSR: %w", err)
}
s.recordStatusChange(ctx, id, &currentStatus, models.DSRStatusRejected, &rejectedBy, fmt.Sprintf("Abgelehnt (%s): %s", legalBasis, reason))
return nil
}
// CancelRequest cancels a DSR (by user)
func (s *DSRService) CancelRequest(ctx context.Context, id uuid.UUID, cancelledBy uuid.UUID) error {
// Verify ownership
var userID *uuid.UUID
err := s.pool.QueryRow(ctx, "SELECT user_id FROM data_subject_requests WHERE id = $1", id).Scan(&userID)
if err != nil {
return fmt.Errorf("DSR not found: %w", err)
}
if userID == nil || *userID != cancelledBy {
return fmt.Errorf("unauthorized: can only cancel own requests")
}
// Get current status
var currentStatus models.DSRStatus
s.pool.QueryRow(ctx, "SELECT status FROM data_subject_requests WHERE id = $1", id).Scan(&currentStatus)
_, err = s.pool.Exec(ctx, `
UPDATE data_subject_requests SET status = 'cancelled', updated_at = NOW() WHERE id = $1
`, id)
if err != nil {
return fmt.Errorf("failed to cancel DSR: %w", err)
}
s.recordStatusChange(ctx, id, &currentStatus, models.DSRStatusCancelled, &cancelledBy, "Vom Antragsteller storniert")
return nil
}
// GetDashboardStats returns statistics for the admin dashboard
func (s *DSRService) GetDashboardStats(ctx context.Context) (*models.DSRDashboardStats, error) {
stats := &models.DSRDashboardStats{
ByType: make(map[string]int),
ByStatus: make(map[string]int),
}
// Total requests
s.pool.QueryRow(ctx, "SELECT COUNT(*) FROM data_subject_requests").Scan(&stats.TotalRequests)
// Pending requests (not completed, rejected, or cancelled)
s.pool.QueryRow(ctx, `
SELECT COUNT(*) FROM data_subject_requests
WHERE status NOT IN ('completed', 'rejected', 'cancelled')
`).Scan(&stats.PendingRequests)
// Overdue requests
s.pool.QueryRow(ctx, `
SELECT COUNT(*) FROM data_subject_requests
WHERE COALESCE(extended_deadline_at, deadline_at) < NOW()
AND status NOT IN ('completed', 'rejected', 'cancelled')
`).Scan(&stats.OverdueRequests)
// Completed this month
s.pool.QueryRow(ctx, `
SELECT COUNT(*) FROM data_subject_requests
WHERE status = 'completed'
AND completed_at >= DATE_TRUNC('month', NOW())
`).Scan(&stats.CompletedThisMonth)
// Average processing days
s.pool.QueryRow(ctx, `
SELECT COALESCE(AVG(EXTRACT(EPOCH FROM (completed_at - created_at)) / 86400), 0)
FROM data_subject_requests WHERE status = 'completed'
`).Scan(&stats.AverageProcessingDays)
// Count by type
rows, _ := s.pool.Query(ctx, `
SELECT request_type, COUNT(*) FROM data_subject_requests GROUP BY request_type
`)
for rows.Next() {
var t string
var count int
rows.Scan(&t, &count)
stats.ByType[t] = count
}
rows.Close()
// Count by status
rows, _ = s.pool.Query(ctx, `
SELECT status, COUNT(*) FROM data_subject_requests GROUP BY status
`)
for rows.Next() {
var s string
var count int
rows.Scan(&s, &count)
stats.ByStatus[s] = count
}
rows.Close()
// Upcoming deadlines (next 7 days)
rows, _ = s.pool.Query(ctx, `
SELECT id, request_number, request_type, status, requester_email, deadline_at
FROM data_subject_requests
WHERE COALESCE(extended_deadline_at, deadline_at) BETWEEN NOW() AND NOW() + INTERVAL '7 days'
AND status NOT IN ('completed', 'rejected', 'cancelled')
ORDER BY deadline_at ASC LIMIT 10
`)
for rows.Next() {
var dsr models.DataSubjectRequest
rows.Scan(&dsr.ID, &dsr.RequestNumber, &dsr.RequestType, &dsr.Status, &dsr.RequesterEmail, &dsr.DeadlineAt)
stats.UpcomingDeadlines = append(stats.UpcomingDeadlines, dsr)
}
rows.Close()
return stats, nil
}
// GetStatusHistory retrieves the status history for a DSR
func (s *DSRService) GetStatusHistory(ctx context.Context, requestID uuid.UUID) ([]models.DSRStatusHistory, error) {
rows, err := s.pool.Query(ctx, `
SELECT id, request_id, from_status, to_status, changed_by, comment, metadata, created_at
FROM dsr_status_history WHERE request_id = $1 ORDER BY created_at DESC
`, requestID)
if err != nil {
return nil, fmt.Errorf("failed to query status history: %w", err)
}
defer rows.Close()
var history []models.DSRStatusHistory
for rows.Next() {
var h models.DSRStatusHistory
var metadataJSON []byte
err := rows.Scan(&h.ID, &h.RequestID, &h.FromStatus, &h.ToStatus, &h.ChangedBy, &h.Comment, &metadataJSON, &h.CreatedAt)
if err != nil {
continue
}
json.Unmarshal(metadataJSON, &h.Metadata)
history = append(history, h)
}
return history, nil
}
// GetCommunications retrieves communications for a DSR
func (s *DSRService) GetCommunications(ctx context.Context, requestID uuid.UUID) ([]models.DSRCommunication, error) {
rows, err := s.pool.Query(ctx, `
SELECT id, request_id, direction, channel, communication_type, template_version_id,
subject, body_html, body_text, recipient_email, sent_at, error_message,
attachments, created_at, created_by
FROM dsr_communications WHERE request_id = $1 ORDER BY created_at DESC
`, requestID)
if err != nil {
return nil, fmt.Errorf("failed to query communications: %w", err)
}
defer rows.Close()
var comms []models.DSRCommunication
for rows.Next() {
var c models.DSRCommunication
var attachmentsJSON []byte
err := rows.Scan(&c.ID, &c.RequestID, &c.Direction, &c.Channel, &c.CommunicationType,
&c.TemplateVersionID, &c.Subject, &c.BodyHTML, &c.BodyText, &c.RecipientEmail,
&c.SentAt, &c.ErrorMessage, &attachmentsJSON, &c.CreatedAt, &c.CreatedBy)
if err != nil {
continue
}
json.Unmarshal(attachmentsJSON, &c.Attachments)
comms = append(comms, c)
}
return comms, nil
}
// SendCommunication sends a communication for a DSR
func (s *DSRService) SendCommunication(ctx context.Context, requestID uuid.UUID, req models.SendDSRCommunicationRequest, sentBy uuid.UUID) error {
// Get DSR details
dsr, err := s.GetByID(ctx, requestID)
if err != nil {
return err
}
// Get template if specified
var subject, bodyHTML, bodyText string
if req.TemplateVersionID != nil {
templateVersionID, _ := uuid.Parse(*req.TemplateVersionID)
err := s.pool.QueryRow(ctx, `
SELECT subject, body_html, body_text FROM dsr_template_versions WHERE id = $1 AND status = 'published'
`, templateVersionID).Scan(&subject, &bodyHTML, &bodyText)
if err != nil {
return fmt.Errorf("template version not found or not published: %w", err)
}
}
// Use custom content if provided
if req.CustomSubject != nil {
subject = *req.CustomSubject
}
if req.CustomBody != nil {
bodyHTML = *req.CustomBody
bodyText = stripHTML(*req.CustomBody)
}
// Replace variables
variables := map[string]string{
"requester_name": stringOrDefault(dsr.RequesterName, "Antragsteller/in"),
"request_number": dsr.RequestNumber,
"request_type_de": dsr.RequestType.Label(),
"request_date": dsr.CreatedAt.Format("02.01.2006"),
"deadline_date": dsr.DeadlineAt.Format("02.01.2006"),
}
for k, v := range req.Variables {
variables[k] = v
}
subject = replaceVariables(subject, variables)
bodyHTML = replaceVariables(bodyHTML, variables)
bodyText = replaceVariables(bodyText, variables)
// Send email
if s.emailService != nil {
err = s.emailService.SendEmail(dsr.RequesterEmail, subject, bodyHTML, bodyText)
if err != nil {
// Log error but continue
_, _ = s.pool.Exec(ctx, `
INSERT INTO dsr_communications (request_id, direction, channel, communication_type,
template_version_id, subject, body_html, body_text, recipient_email, error_message, created_by)
VALUES ($1, 'outbound', 'email', $2, $3, $4, $5, $6, $7, $8, $9)
`, requestID, req.CommunicationType, req.TemplateVersionID, subject, bodyHTML, bodyText,
dsr.RequesterEmail, err.Error(), sentBy)
return fmt.Errorf("failed to send email: %w", err)
}
}
// Log communication
now := time.Now()
_, err = s.pool.Exec(ctx, `
INSERT INTO dsr_communications (request_id, direction, channel, communication_type,
template_version_id, subject, body_html, body_text, recipient_email, sent_at, created_by)
VALUES ($1, 'outbound', 'email', $2, $3, $4, $5, $6, $7, $8, $9)
`, requestID, req.CommunicationType, req.TemplateVersionID, subject, bodyHTML, bodyText,
dsr.RequesterEmail, now, sentBy)
return err
}
// InitErasureExceptionChecks initializes exception checks for an erasure request
func (s *DSRService) InitErasureExceptionChecks(ctx context.Context, requestID uuid.UUID) error {
exceptions := []struct {
Type string
Description string
}{
{models.DSRExceptionFreedomExpression, "Ausübung des Rechts auf freie Meinungsäußerung und Information (Art. 17 Abs. 3 lit. a)"},
{models.DSRExceptionLegalObligation, "Erfüllung einer rechtlichen Verpflichtung oder öffentlichen Aufgabe (Art. 17 Abs. 3 lit. b)"},
{models.DSRExceptionPublicHealth, "Gründe des öffentlichen Interesses im Bereich der öffentlichen Gesundheit (Art. 17 Abs. 3 lit. c)"},
{models.DSRExceptionArchiving, "Im öffentlichen Interesse liegende Archivzwecke, Forschung oder Statistik (Art. 17 Abs. 3 lit. d)"},
{models.DSRExceptionLegalClaims, "Geltendmachung, Ausübung oder Verteidigung von Rechtsansprüchen (Art. 17 Abs. 3 lit. e)"},
}
for _, exc := range exceptions {
_, err := s.pool.Exec(ctx, `
INSERT INTO dsr_exception_checks (request_id, exception_type, description)
VALUES ($1, $2, $3) ON CONFLICT DO NOTHING
`, requestID, exc.Type, exc.Description)
if err != nil {
return fmt.Errorf("failed to create exception check: %w", err)
}
}
return nil
}
// GetExceptionChecks retrieves exception checks for a DSR
func (s *DSRService) GetExceptionChecks(ctx context.Context, requestID uuid.UUID) ([]models.DSRExceptionCheck, error) {
rows, err := s.pool.Query(ctx, `
SELECT id, request_id, exception_type, description, applies, checked_by, checked_at, notes, created_at
FROM dsr_exception_checks WHERE request_id = $1 ORDER BY created_at
`, requestID)
if err != nil {
return nil, fmt.Errorf("failed to query exception checks: %w", err)
}
defer rows.Close()
var checks []models.DSRExceptionCheck
for rows.Next() {
var c models.DSRExceptionCheck
err := rows.Scan(&c.ID, &c.RequestID, &c.ExceptionType, &c.Description, &c.Applies,
&c.CheckedBy, &c.CheckedAt, &c.Notes, &c.CreatedAt)
if err != nil {
continue
}
checks = append(checks, c)
}
return checks, nil
}
// UpdateExceptionCheck updates an exception check
func (s *DSRService) UpdateExceptionCheck(ctx context.Context, checkID uuid.UUID, applies bool, notes *string, checkedBy uuid.UUID) error {
_, err := s.pool.Exec(ctx, `
UPDATE dsr_exception_checks
SET applies = $1, notes = $2, checked_by = $3, checked_at = NOW()
WHERE id = $4
`, applies, notes, checkedBy, checkID)
return err
}
// ProcessDeadlines checks for approaching and overdue deadlines
func (s *DSRService) ProcessDeadlines(ctx context.Context) error {
now := time.Now()
// Find requests with deadlines in 3 days
threeDaysAhead := now.AddDate(0, 0, 3)
rows, _ := s.pool.Query(ctx, `
SELECT id, request_number, request_type, assigned_to, deadline_at
FROM data_subject_requests
WHERE COALESCE(extended_deadline_at, deadline_at) BETWEEN $1 AND $2
AND status NOT IN ('completed', 'rejected', 'cancelled')
`, now, threeDaysAhead)
for rows.Next() {
var id uuid.UUID
var requestNumber, requestType string
var assignedTo *uuid.UUID
var deadline time.Time
rows.Scan(&id, &requestNumber, &requestType, &assignedTo, &deadline)
// Notify assigned user or all DPOs
if assignedTo != nil {
s.notifyDeadlineWarning(ctx, id, *assignedTo, requestNumber, deadline, 3)
} else {
s.notifyAllDPOs(ctx, id, requestNumber, "Frist in 3 Tagen", deadline)
}
}
rows.Close()
// Find requests with deadlines in 1 day
oneDayAhead := now.AddDate(0, 0, 1)
rows, _ = s.pool.Query(ctx, `
SELECT id, request_number, request_type, assigned_to, deadline_at
FROM data_subject_requests
WHERE COALESCE(extended_deadline_at, deadline_at) BETWEEN $1 AND $2
AND status NOT IN ('completed', 'rejected', 'cancelled')
`, now, oneDayAhead)
for rows.Next() {
var id uuid.UUID
var requestNumber, requestType string
var assignedTo *uuid.UUID
var deadline time.Time
rows.Scan(&id, &requestNumber, &requestType, &assignedTo, &deadline)
if assignedTo != nil {
s.notifyDeadlineWarning(ctx, id, *assignedTo, requestNumber, deadline, 1)
} else {
s.notifyAllDPOs(ctx, id, requestNumber, "Frist morgen!", deadline)
}
}
rows.Close()
// Find overdue requests
rows, _ = s.pool.Query(ctx, `
SELECT id, request_number, request_type, assigned_to, deadline_at
FROM data_subject_requests
WHERE COALESCE(extended_deadline_at, deadline_at) < $1
AND status NOT IN ('completed', 'rejected', 'cancelled')
`, now)
for rows.Next() {
var id uuid.UUID
var requestNumber, requestType string
var assignedTo *uuid.UUID
var deadline time.Time
rows.Scan(&id, &requestNumber, &requestType, &assignedTo, &deadline)
// Notify all DPOs for overdue
s.notifyAllDPOs(ctx, id, requestNumber, "ÜBERFÄLLIG!", deadline)
// Log to audit
s.pool.Exec(ctx, `
INSERT INTO consent_audit_log (action, entity_type, entity_id, details)
VALUES ('dsr_overdue', 'dsr', $1, $2)
`, id, fmt.Sprintf(`{"request_number": "%s", "deadline": "%s"}`, requestNumber, deadline.Format(time.RFC3339)))
}
rows.Close()
return nil
}
// Helper functions
func (s *DSRService) recordStatusChange(ctx context.Context, requestID uuid.UUID, fromStatus *models.DSRStatus, toStatus models.DSRStatus, changedBy *uuid.UUID, comment string) {
s.pool.Exec(ctx, `
INSERT INTO dsr_status_history (request_id, from_status, to_status, changed_by, comment)
VALUES ($1, $2, $3, $4, $5)
`, requestID, fromStatus, toStatus, changedBy, comment)
}
func (s *DSRService) notifyNewRequest(ctx context.Context, dsr *models.DataSubjectRequest) {
if s.notificationService == nil {
return
}
// Notify all DPOs
rows, _ := s.pool.Query(ctx, "SELECT id FROM users WHERE role = 'data_protection_officer'")
defer rows.Close()
for rows.Next() {
var userID uuid.UUID
rows.Scan(&userID)
s.notificationService.CreateNotification(ctx, userID, NotificationTypeDSRReceived,
"Neue Betroffenenanfrage",
fmt.Sprintf("Neue %s eingegangen: %s", dsr.RequestType.Label(), dsr.RequestNumber),
map[string]interface{}{"dsr_id": dsr.ID, "request_number": dsr.RequestNumber})
}
}
func (s *DSRService) notifyAssignment(ctx context.Context, dsrID, assigneeID uuid.UUID) {
if s.notificationService == nil {
return
}
dsr, _ := s.GetByID(ctx, dsrID)
if dsr != nil {
s.notificationService.CreateNotification(ctx, assigneeID, NotificationTypeDSRAssigned,
"Betroffenenanfrage zugewiesen",
fmt.Sprintf("Ihnen wurde die Anfrage %s zugewiesen", dsr.RequestNumber),
map[string]interface{}{"dsr_id": dsrID, "request_number": dsr.RequestNumber})
}
}
func (s *DSRService) notifyDeadlineWarning(ctx context.Context, dsrID, userID uuid.UUID, requestNumber string, deadline time.Time, daysLeft int) {
if s.notificationService == nil {
return
}
s.notificationService.CreateNotification(ctx, userID, NotificationTypeDSRDeadline,
fmt.Sprintf("Fristwarnung: %s", requestNumber),
fmt.Sprintf("Die Frist für %s läuft in %d Tag(en) ab (%s)", requestNumber, daysLeft, deadline.Format("02.01.2006")),
map[string]interface{}{"dsr_id": dsrID, "deadline": deadline, "days_left": daysLeft})
}
func (s *DSRService) notifyAllDPOs(ctx context.Context, dsrID uuid.UUID, requestNumber, message string, deadline time.Time) {
if s.notificationService == nil {
return
}
rows, _ := s.pool.Query(ctx, "SELECT id FROM users WHERE role = 'data_protection_officer'")
defer rows.Close()
for rows.Next() {
var userID uuid.UUID
rows.Scan(&userID)
s.notificationService.CreateNotification(ctx, userID, NotificationTypeDSRDeadline,
fmt.Sprintf("%s: %s", message, requestNumber),
fmt.Sprintf("Anfrage %s: %s (Frist: %s)", requestNumber, message, deadline.Format("02.01.2006")),
map[string]interface{}{"dsr_id": dsrID, "deadline": deadline})
}
}
func isValidRequestType(rt models.DSRRequestType) bool {
switch rt {
case models.DSRTypeAccess, models.DSRTypeRectification, models.DSRTypeErasure,
models.DSRTypeRestriction, models.DSRTypePortability:
return true
}
return false
}
func isValidStatusTransition(from, to models.DSRStatus) bool {
validTransitions := map[models.DSRStatus][]models.DSRStatus{
models.DSRStatusIntake: {models.DSRStatusIdentityVerification, models.DSRStatusProcessing, models.DSRStatusRejected, models.DSRStatusCancelled},
models.DSRStatusIdentityVerification: {models.DSRStatusProcessing, models.DSRStatusRejected, models.DSRStatusCancelled},
models.DSRStatusProcessing: {models.DSRStatusCompleted, models.DSRStatusRejected, models.DSRStatusCancelled},
models.DSRStatusCompleted: {},
models.DSRStatusRejected: {},
models.DSRStatusCancelled: {},
}
allowed, exists := validTransitions[from]
if !exists {
return false
}
for _, s := range allowed {
if s == to {
return true
}
}
return false
}
func stringOrDefault(s *string, def string) string {
if s != nil {
return *s
}
return def
}
func replaceVariables(text string, variables map[string]string) string {
for k, v := range variables {
text = strings.ReplaceAll(text, "{{"+k+"}}", v)
}
return text
}
func stripHTML(html string) string {
// Simple HTML stripping - in production use a proper library
text := strings.ReplaceAll(html, "<br>", "\n")
text = strings.ReplaceAll(text, "<br/>", "\n")
text = strings.ReplaceAll(text, "<br />", "\n")
text = strings.ReplaceAll(text, "</p>", "\n\n")
// Remove all remaining tags
for {
start := strings.Index(text, "<")
if start == -1 {
break
}
end := strings.Index(text[start:], ">")
if end == -1 {
break
}
text = text[:start] + text[start+end+1:]
}
return strings.TrimSpace(text)
}
@@ -0,0 +1,420 @@
package services
import (
"testing"
"time"
"github.com/breakpilot/consent-service/internal/models"
)
// TestDSRRequestTypeLabel tests label generation for request types
func TestDSRRequestTypeLabel(t *testing.T) {
tests := []struct {
name string
reqType models.DSRRequestType
expected string
}{
{"access type", models.DSRTypeAccess, "Auskunftsanfrage (Art. 15)"},
{"rectification type", models.DSRTypeRectification, "Berichtigungsanfrage (Art. 16)"},
{"erasure type", models.DSRTypeErasure, "Löschanfrage (Art. 17)"},
{"restriction type", models.DSRTypeRestriction, "Einschränkungsanfrage (Art. 18)"},
{"portability type", models.DSRTypePortability, "Datenübertragung (Art. 20)"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.reqType.Label()
if result != tt.expected {
t.Errorf("Expected %s, got %s", tt.expected, result)
}
})
}
}
// TestDSRRequestTypeDeadlineDays tests deadline calculation for different request types
func TestDSRRequestTypeDeadlineDays(t *testing.T) {
tests := []struct {
name string
reqType models.DSRRequestType
expectedDays int
}{
{"access has 30 days", models.DSRTypeAccess, 30},
{"portability has 30 days", models.DSRTypePortability, 30},
{"rectification has 14 days", models.DSRTypeRectification, 14},
{"erasure has 14 days", models.DSRTypeErasure, 14},
{"restriction has 14 days", models.DSRTypeRestriction, 14},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.reqType.DeadlineDays()
if result != tt.expectedDays {
t.Errorf("Expected %d days, got %d", tt.expectedDays, result)
}
})
}
}
// TestDSRRequestTypeIsExpedited tests expedited flag for request types
func TestDSRRequestTypeIsExpedited(t *testing.T) {
tests := []struct {
name string
reqType models.DSRRequestType
isExpedited bool
}{
{"access not expedited", models.DSRTypeAccess, false},
{"portability not expedited", models.DSRTypePortability, false},
{"rectification is expedited", models.DSRTypeRectification, true},
{"erasure is expedited", models.DSRTypeErasure, true},
{"restriction is expedited", models.DSRTypeRestriction, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.reqType.IsExpedited()
if result != tt.isExpedited {
t.Errorf("Expected IsExpedited=%v, got %v", tt.isExpedited, result)
}
})
}
}
// TestDSRStatusLabel tests label generation for statuses
func TestDSRStatusLabel(t *testing.T) {
tests := []struct {
name string
status models.DSRStatus
expected string
}{
{"intake status", models.DSRStatusIntake, "Eingang"},
{"identity verification", models.DSRStatusIdentityVerification, "Identitätsprüfung"},
{"processing status", models.DSRStatusProcessing, "In Bearbeitung"},
{"completed status", models.DSRStatusCompleted, "Abgeschlossen"},
{"rejected status", models.DSRStatusRejected, "Abgelehnt"},
{"cancelled status", models.DSRStatusCancelled, "Storniert"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.status.Label()
if result != tt.expected {
t.Errorf("Expected %s, got %s", tt.expected, result)
}
})
}
}
// TestValidDSRRequestType tests request type validation
func TestValidDSRRequestType(t *testing.T) {
tests := []struct {
name string
reqType string
valid bool
}{
{"valid access", "access", true},
{"valid rectification", "rectification", true},
{"valid erasure", "erasure", true},
{"valid restriction", "restriction", true},
{"valid portability", "portability", true},
{"invalid type", "invalid", false},
{"empty type", "", false},
{"random string", "delete_everything", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := models.IsValidDSRRequestType(tt.reqType)
if result != tt.valid {
t.Errorf("Expected IsValidDSRRequestType=%v for %s, got %v", tt.valid, tt.reqType, result)
}
})
}
}
// TestValidDSRStatus tests status validation
func TestValidDSRStatus(t *testing.T) {
tests := []struct {
name string
status string
valid bool
}{
{"valid intake", "intake", true},
{"valid identity_verification", "identity_verification", true},
{"valid processing", "processing", true},
{"valid completed", "completed", true},
{"valid rejected", "rejected", true},
{"valid cancelled", "cancelled", true},
{"invalid status", "invalid", false},
{"empty status", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := models.IsValidDSRStatus(tt.status)
if result != tt.valid {
t.Errorf("Expected IsValidDSRStatus=%v for %s, got %v", tt.valid, tt.status, result)
}
})
}
}
// TestDSRStatusTransitionValidation tests allowed status transitions
func TestDSRStatusTransitionValidation(t *testing.T) {
tests := []struct {
name string
fromStatus models.DSRStatus
toStatus models.DSRStatus
allowed bool
}{
// From intake
{"intake to identity_verification", models.DSRStatusIntake, models.DSRStatusIdentityVerification, true},
{"intake to processing", models.DSRStatusIntake, models.DSRStatusProcessing, true},
{"intake to rejected", models.DSRStatusIntake, models.DSRStatusRejected, true},
{"intake to cancelled", models.DSRStatusIntake, models.DSRStatusCancelled, true},
{"intake to completed invalid", models.DSRStatusIntake, models.DSRStatusCompleted, false},
// From identity_verification
{"identity to processing", models.DSRStatusIdentityVerification, models.DSRStatusProcessing, true},
{"identity to rejected", models.DSRStatusIdentityVerification, models.DSRStatusRejected, true},
{"identity to cancelled", models.DSRStatusIdentityVerification, models.DSRStatusCancelled, true},
// From processing
{"processing to completed", models.DSRStatusProcessing, models.DSRStatusCompleted, true},
{"processing to rejected", models.DSRStatusProcessing, models.DSRStatusRejected, true},
{"processing to intake invalid", models.DSRStatusProcessing, models.DSRStatusIntake, false},
// From completed
{"completed to anything invalid", models.DSRStatusCompleted, models.DSRStatusProcessing, false},
// From rejected
{"rejected to anything invalid", models.DSRStatusRejected, models.DSRStatusProcessing, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := testIsValidStatusTransition(tt.fromStatus, tt.toStatus)
if result != tt.allowed {
t.Errorf("Expected transition %s->%s allowed=%v, got %v",
tt.fromStatus, tt.toStatus, tt.allowed, result)
}
})
}
}
// testIsValidStatusTransition is a test helper for validating status transitions
// This mirrors the logic in dsr_service.go for testing purposes
func testIsValidStatusTransition(from, to models.DSRStatus) bool {
validTransitions := map[models.DSRStatus][]models.DSRStatus{
models.DSRStatusIntake: {
models.DSRStatusIdentityVerification,
models.DSRStatusProcessing,
models.DSRStatusRejected,
models.DSRStatusCancelled,
},
models.DSRStatusIdentityVerification: {
models.DSRStatusProcessing,
models.DSRStatusRejected,
models.DSRStatusCancelled,
},
models.DSRStatusProcessing: {
models.DSRStatusCompleted,
models.DSRStatusRejected,
models.DSRStatusCancelled,
},
models.DSRStatusCompleted: {},
models.DSRStatusRejected: {},
models.DSRStatusCancelled: {},
}
allowed, exists := validTransitions[from]
if !exists {
return false
}
for _, s := range allowed {
if s == to {
return true
}
}
return false
}
// TestCalculateDeadline tests deadline calculation
func TestCalculateDeadline(t *testing.T) {
tests := []struct {
name string
reqType models.DSRRequestType
expectedDays int
}{
{"access 30 days", models.DSRTypeAccess, 30},
{"erasure 14 days", models.DSRTypeErasure, 14},
{"rectification 14 days", models.DSRTypeRectification, 14},
{"restriction 14 days", models.DSRTypeRestriction, 14},
{"portability 30 days", models.DSRTypePortability, 30},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
now := time.Now()
deadline := now.AddDate(0, 0, tt.expectedDays)
days := tt.reqType.DeadlineDays()
if days != tt.expectedDays {
t.Errorf("Expected %d days, got %d", tt.expectedDays, days)
}
// Verify deadline is approximately correct (within 1 day due to test timing)
calculatedDeadline := now.AddDate(0, 0, days)
diff := calculatedDeadline.Sub(deadline)
if diff > time.Hour*24 || diff < -time.Hour*24 {
t.Errorf("Deadline calculation off by more than a day")
}
})
}
}
// TestCreateDSRRequest_Validation tests validation of create request
func TestCreateDSRRequest_Validation(t *testing.T) {
tests := []struct {
name string
request models.CreateDSRRequest
expectError bool
}{
{
name: "valid access request",
request: models.CreateDSRRequest{
RequestType: "access",
RequesterEmail: "test@example.com",
},
expectError: false,
},
{
name: "valid erasure request with name",
request: models.CreateDSRRequest{
RequestType: "erasure",
RequesterEmail: "test@example.com",
RequesterName: stringPtr("Max Mustermann"),
},
expectError: false,
},
{
name: "missing email",
request: models.CreateDSRRequest{
RequestType: "access",
},
expectError: true,
},
{
name: "invalid request type",
request: models.CreateDSRRequest{
RequestType: "invalid_type",
RequesterEmail: "test@example.com",
},
expectError: true,
},
{
name: "empty request type",
request: models.CreateDSRRequest{
RequestType: "",
RequesterEmail: "test@example.com",
},
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := testValidateCreateDSRRequest(tt.request)
hasError := err != nil
if hasError != tt.expectError {
t.Errorf("Expected error=%v, got error=%v (err: %v)", tt.expectError, hasError, err)
}
})
}
}
// testValidateCreateDSRRequest is a test helper for validating create DSR requests
func testValidateCreateDSRRequest(req models.CreateDSRRequest) error {
if req.RequesterEmail == "" {
return &dsrValidationError{"requester_email is required"}
}
if !models.IsValidDSRRequestType(req.RequestType) {
return &dsrValidationError{"invalid request_type"}
}
return nil
}
type dsrValidationError struct {
Message string
}
func (e *dsrValidationError) Error() string {
return e.Message
}
// TestDSRTemplateTypes tests the template types
func TestDSRTemplateTypes(t *testing.T) {
expectedTemplates := []string{
"dsr_receipt_access",
"dsr_receipt_rectification",
"dsr_receipt_erasure",
"dsr_receipt_restriction",
"dsr_receipt_portability",
"dsr_identity_request",
"dsr_processing_started",
"dsr_processing_update",
"dsr_clarification_request",
"dsr_completed_access",
"dsr_completed_rectification",
"dsr_completed_erasure",
"dsr_completed_restriction",
"dsr_completed_portability",
"dsr_restriction_lifted",
"dsr_rejected_identity",
"dsr_rejected_exception",
"dsr_rejected_unfounded",
"dsr_deadline_warning",
}
// This test documents the expected template types
// The actual templates are created in database migration
for _, template := range expectedTemplates {
if template == "" {
t.Error("Template type should not be empty")
}
}
if len(expectedTemplates) != 19 {
t.Errorf("Expected 19 template types, got %d", len(expectedTemplates))
}
}
// TestErasureExceptionTypes tests Art. 17(3) exception types
func TestErasureExceptionTypes(t *testing.T) {
exceptions := []struct {
code string
description string
}{
{"art_17_3_a", "Meinungs- und Informationsfreiheit"},
{"art_17_3_b", "Rechtliche Verpflichtung"},
{"art_17_3_c", "Öffentliches Interesse im Gesundheitsbereich"},
{"art_17_3_d", "Archivzwecke, wissenschaftliche/historische Forschung"},
{"art_17_3_e", "Geltendmachung, Ausübung oder Verteidigung von Rechtsansprüchen"},
}
if len(exceptions) != 5 {
t.Errorf("Expected 5 Art. 17(3) exceptions, got %d", len(exceptions))
}
for _, ex := range exceptions {
if ex.code == "" || ex.description == "" {
t.Error("Exception code and description should not be empty")
}
}
}
// stringPtr returns a pointer to the given string
func stringPtr(s string) *string {
return &s
}
@@ -0,0 +1,554 @@
package services
import (
"bytes"
"fmt"
"html/template"
"net/smtp"
"strings"
)
// EmailConfig holds SMTP configuration
type EmailConfig struct {
Host string
Port int
Username string
Password string
FromName string
FromAddr string
BaseURL string // Frontend URL for links
}
// EmailService handles sending emails
type EmailService struct {
config EmailConfig
}
// NewEmailService creates a new EmailService
func NewEmailService(config EmailConfig) *EmailService {
return &EmailService{config: config}
}
// SendEmail sends an email
func (s *EmailService) SendEmail(to, subject, htmlBody, textBody string) error {
// Build MIME message
var msg bytes.Buffer
msg.WriteString(fmt.Sprintf("From: %s <%s>\r\n", s.config.FromName, s.config.FromAddr))
msg.WriteString(fmt.Sprintf("To: %s\r\n", to))
msg.WriteString(fmt.Sprintf("Subject: %s\r\n", subject))
msg.WriteString("MIME-Version: 1.0\r\n")
msg.WriteString("Content-Type: multipart/alternative; boundary=\"boundary42\"\r\n")
msg.WriteString("\r\n")
// Text part
msg.WriteString("--boundary42\r\n")
msg.WriteString("Content-Type: text/plain; charset=\"UTF-8\"\r\n")
msg.WriteString("\r\n")
msg.WriteString(textBody)
msg.WriteString("\r\n")
// HTML part
msg.WriteString("--boundary42\r\n")
msg.WriteString("Content-Type: text/html; charset=\"UTF-8\"\r\n")
msg.WriteString("\r\n")
msg.WriteString(htmlBody)
msg.WriteString("\r\n")
msg.WriteString("--boundary42--\r\n")
// Send email
addr := fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)
auth := smtp.PlainAuth("", s.config.Username, s.config.Password, s.config.Host)
err := smtp.SendMail(addr, auth, s.config.FromAddr, []string{to}, msg.Bytes())
if err != nil {
return fmt.Errorf("failed to send email: %w", err)
}
return nil
}
// SendVerificationEmail sends an email verification email
func (s *EmailService) SendVerificationEmail(to, name, token string) error {
verifyLink := fmt.Sprintf("%s/verify-email?token=%s", s.config.BaseURL, token)
subject := "Bitte bestätigen Sie Ihre E-Mail-Adresse - BreakPilot"
textBody := fmt.Sprintf(`Hallo %s,
Willkommen bei BreakPilot!
Bitte bestätigen Sie Ihre E-Mail-Adresse, indem Sie den folgenden Link öffnen:
%s
Dieser Link ist 24 Stunden gültig.
Falls Sie sich nicht bei BreakPilot registriert haben, können Sie diese E-Mail ignorieren.
Mit freundlichen Grüßen,
Ihr BreakPilot Team`, getDisplayName(name), verifyLink)
htmlBody := s.renderTemplate("verification", map[string]interface{}{
"Name": getDisplayName(name),
"VerifyLink": verifyLink,
})
return s.SendEmail(to, subject, htmlBody, textBody)
}
// SendPasswordResetEmail sends a password reset email
func (s *EmailService) SendPasswordResetEmail(to, name, token string) error {
resetLink := fmt.Sprintf("%s/reset-password?token=%s", s.config.BaseURL, token)
subject := "Passwort zurücksetzen - BreakPilot"
textBody := fmt.Sprintf(`Hallo %s,
Sie haben eine Anfrage zum Zurücksetzen Ihres Passworts gestellt.
Klicken Sie auf den folgenden Link, um Ihr Passwort zurückzusetzen:
%s
Dieser Link ist 1 Stunde gültig.
Falls Sie keine Passwort-Zurücksetzung angefordert haben, können Sie diese E-Mail ignorieren.
Mit freundlichen Grüßen,
Ihr BreakPilot Team`, getDisplayName(name), resetLink)
htmlBody := s.renderTemplate("password_reset", map[string]interface{}{
"Name": getDisplayName(name),
"ResetLink": resetLink,
})
return s.SendEmail(to, subject, htmlBody, textBody)
}
// SendNewVersionNotification sends a notification about new document version
func (s *EmailService) SendNewVersionNotification(to, name, documentName, documentType string, deadlineDays int) error {
consentLink := fmt.Sprintf("%s/app?consent=pending", s.config.BaseURL)
subject := fmt.Sprintf("Neue Version: %s - Bitte bestätigen Sie innerhalb von %d Tagen", documentName, deadlineDays)
textBody := fmt.Sprintf(`Hallo %s,
Wir haben unsere %s aktualisiert.
Bitte lesen und bestätigen Sie die neuen Bedingungen innerhalb der nächsten %d Tage:
%s
Falls Sie nicht innerhalb dieser Frist bestätigen, wird Ihr Account vorübergehend gesperrt.
Mit freundlichen Grüßen,
Ihr BreakPilot Team`, getDisplayName(name), documentName, deadlineDays, consentLink)
htmlBody := s.renderTemplate("new_version", map[string]interface{}{
"Name": getDisplayName(name),
"DocumentName": documentName,
"DeadlineDays": deadlineDays,
"ConsentLink": consentLink,
})
return s.SendEmail(to, subject, htmlBody, textBody)
}
// SendConsentReminder sends a consent reminder email
func (s *EmailService) SendConsentReminder(to, name string, documents []string, daysLeft int) error {
consentLink := fmt.Sprintf("%s/app?consent=pending", s.config.BaseURL)
urgency := "Erinnerung"
if daysLeft <= 7 {
urgency = "Dringend"
}
if daysLeft <= 2 {
urgency = "Letzte Warnung"
}
subject := fmt.Sprintf("%s: Noch %d Tage um ausstehende Dokumente zu bestätigen", urgency, daysLeft)
docList := strings.Join(documents, "\n- ")
textBody := fmt.Sprintf(`Hallo %s,
Dies ist eine freundliche Erinnerung, dass Sie noch ausstehende rechtliche Dokumente bestätigen müssen.
Ausstehende Dokumente:
- %s
Sie haben noch %d Tage Zeit. Nach Ablauf dieser Frist wird Ihr Account vorübergehend gesperrt.
Bitte bestätigen Sie hier:
%s
Mit freundlichen Grüßen,
Ihr BreakPilot Team`, getDisplayName(name), docList, daysLeft, consentLink)
htmlBody := s.renderTemplate("reminder", map[string]interface{}{
"Name": getDisplayName(name),
"Documents": documents,
"DaysLeft": daysLeft,
"Urgency": urgency,
"ConsentLink": consentLink,
})
return s.SendEmail(to, subject, htmlBody, textBody)
}
// SendAccountSuspendedNotification sends notification when account is suspended
func (s *EmailService) SendAccountSuspendedNotification(to, name string, documents []string) error {
consentLink := fmt.Sprintf("%s/app?consent=pending", s.config.BaseURL)
subject := "Ihr Account wurde vorübergehend gesperrt - BreakPilot"
docList := strings.Join(documents, "\n- ")
textBody := fmt.Sprintf(`Hallo %s,
Ihr Account wurde vorübergehend gesperrt, da Sie die folgenden rechtlichen Dokumente nicht innerhalb der Frist bestätigt haben:
- %s
Um Ihren Account zu entsperren, bestätigen Sie bitte alle ausstehenden Dokumente:
%s
Sobald Sie alle Dokumente bestätigt haben, wird Ihr Account automatisch entsperrt.
Mit freundlichen Grüßen,
Ihr BreakPilot Team`, getDisplayName(name), docList, consentLink)
htmlBody := s.renderTemplate("suspended", map[string]interface{}{
"Name": getDisplayName(name),
"Documents": documents,
"ConsentLink": consentLink,
})
return s.SendEmail(to, subject, htmlBody, textBody)
}
// SendAccountReactivatedNotification sends notification when account is reactivated
func (s *EmailService) SendAccountReactivatedNotification(to, name string) error {
appLink := fmt.Sprintf("%s/app", s.config.BaseURL)
subject := "Ihr Account wurde wieder aktiviert - BreakPilot"
textBody := fmt.Sprintf(`Hallo %s,
Vielen Dank für die Bestätigung der rechtlichen Dokumente!
Ihr Account wurde wieder aktiviert und Sie können BreakPilot wie gewohnt nutzen:
%s
Mit freundlichen Grüßen,
Ihr BreakPilot Team`, getDisplayName(name), appLink)
htmlBody := s.renderTemplate("reactivated", map[string]interface{}{
"Name": getDisplayName(name),
"AppLink": appLink,
})
return s.SendEmail(to, subject, htmlBody, textBody)
}
// renderTemplate renders an email HTML template
func (s *EmailService) renderTemplate(templateName string, data map[string]interface{}) string {
templates := map[string]string{
"verification": `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
.header { background: linear-gradient(135deg, #6366f1, #8b5cf6); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
.button { display: inline-block; background: #6366f1; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
</style>
</head>
<body>
<div class="header">
<h1>Willkommen bei BreakPilot!</h1>
</div>
<div class="content">
<p>Hallo {{.Name}},</p>
<p>Vielen Dank für Ihre Registrierung! Bitte bestätigen Sie Ihre E-Mail-Adresse, um Ihr Konto zu aktivieren.</p>
<p style="text-align: center;">
<a href="{{.VerifyLink}}" class="button">E-Mail bestätigen</a>
</p>
<p>Dieser Link ist 24 Stunden gültig.</p>
<p>Falls Sie sich nicht bei BreakPilot registriert haben, können Sie diese E-Mail ignorieren.</p>
</div>
<div class="footer">
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
</div>
</body>
</html>`,
"password_reset": `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
.header { background: linear-gradient(135deg, #6366f1, #8b5cf6); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
.button { display: inline-block; background: #6366f1; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
.warning { background: #fef3c7; border-left: 4px solid #f59e0b; padding: 12px; margin: 20px 0; }
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
</style>
</head>
<body>
<div class="header">
<h1>Passwort zurücksetzen</h1>
</div>
<div class="content">
<p>Hallo {{.Name}},</p>
<p>Sie haben eine Anfrage zum Zurücksetzen Ihres Passworts gestellt.</p>
<p style="text-align: center;">
<a href="{{.ResetLink}}" class="button">Passwort zurücksetzen</a>
</p>
<div class="warning">
<strong>Hinweis:</strong> Dieser Link ist nur 1 Stunde gültig.
</div>
<p>Falls Sie keine Passwort-Zurücksetzung angefordert haben, können Sie diese E-Mail ignorieren.</p>
</div>
<div class="footer">
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
</div>
</body>
</html>`,
"new_version": `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
.header { background: linear-gradient(135deg, #6366f1, #8b5cf6); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
.button { display: inline-block; background: #6366f1; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
.info-box { background: #e0e7ff; border-left: 4px solid #6366f1; padding: 12px; margin: 20px 0; }
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
</style>
</head>
<body>
<div class="header">
<h1>Neue Version: {{.DocumentName}}</h1>
</div>
<div class="content">
<p>Hallo {{.Name}},</p>
<p>Wir haben unsere <strong>{{.DocumentName}}</strong> aktualisiert.</p>
<div class="info-box">
<strong>Wichtig:</strong> Bitte bestätigen Sie die neuen Bedingungen innerhalb der nächsten <strong>{{.DeadlineDays}} Tage</strong>.
</div>
<p style="text-align: center;">
<a href="{{.ConsentLink}}" class="button">Dokument ansehen & bestätigen</a>
</p>
<p>Falls Sie nicht innerhalb dieser Frist bestätigen, wird Ihr Account vorübergehend gesperrt.</p>
</div>
<div class="footer">
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
</div>
</body>
</html>`,
"reminder": `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
.header { background: linear-gradient(135deg, #f59e0b, #d97706); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
.button { display: inline-block; background: #f59e0b; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
.warning { background: #fef3c7; border-left: 4px solid #f59e0b; padding: 12px; margin: 20px 0; }
.doc-list { background: white; padding: 15px; border-radius: 8px; margin: 15px 0; }
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
</style>
</head>
<body>
<div class="header">
<h1>{{.Urgency}}: Ausstehende Bestätigungen</h1>
</div>
<div class="content">
<p>Hallo {{.Name}},</p>
<p>Dies ist eine freundliche Erinnerung, dass Sie noch ausstehende rechtliche Dokumente bestätigen müssen.</p>
<div class="doc-list">
<strong>Ausstehende Dokumente:</strong>
<ul>
{{range .Documents}}<li>{{.}}</li>{{end}}
</ul>
</div>
<div class="warning">
<strong>Sie haben noch {{.DaysLeft}} Tage Zeit.</strong> Nach Ablauf dieser Frist wird Ihr Account vorübergehend gesperrt.
</div>
<p style="text-align: center;">
<a href="{{.ConsentLink}}" class="button">Jetzt bestätigen</a>
</p>
</div>
<div class="footer">
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
</div>
</body>
</html>`,
"suspended": `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
.header { background: linear-gradient(135deg, #ef4444, #dc2626); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
.button { display: inline-block; background: #6366f1; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
.alert { background: #fee2e2; border-left: 4px solid #ef4444; padding: 12px; margin: 20px 0; }
.doc-list { background: white; padding: 15px; border-radius: 8px; margin: 15px 0; }
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
</style>
</head>
<body>
<div class="header">
<h1>Account vorübergehend gesperrt</h1>
</div>
<div class="content">
<p>Hallo {{.Name}},</p>
<div class="alert">
<strong>Ihr Account wurde vorübergehend gesperrt</strong>, da Sie die folgenden rechtlichen Dokumente nicht innerhalb der Frist bestätigt haben.
</div>
<div class="doc-list">
<strong>Nicht bestätigte Dokumente:</strong>
<ul>
{{range .Documents}}<li>{{.}}</li>{{end}}
</ul>
</div>
<p>Um Ihren Account zu entsperren, bestätigen Sie bitte alle ausstehenden Dokumente:</p>
<p style="text-align: center;">
<a href="{{.ConsentLink}}" class="button">Dokumente bestätigen & Account entsperren</a>
</p>
<p>Sobald Sie alle Dokumente bestätigt haben, wird Ihr Account automatisch entsperrt.</p>
</div>
<div class="footer">
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
</div>
</body>
</html>`,
"reactivated": `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
.header { background: linear-gradient(135deg, #22c55e, #16a34a); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
.button { display: inline-block; background: #22c55e; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
.success { background: #dcfce7; border-left: 4px solid #22c55e; padding: 12px; margin: 20px 0; }
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
</style>
</head>
<body>
<div class="header">
<h1>Account wieder aktiviert!</h1>
</div>
<div class="content">
<p>Hallo {{.Name}},</p>
<div class="success">
<strong>Vielen Dank!</strong> Ihr Account wurde erfolgreich wieder aktiviert.
</div>
<p>Sie können BreakPilot ab sofort wieder wie gewohnt nutzen.</p>
<p style="text-align: center;">
<a href="{{.AppLink}}" class="button">Zu BreakPilot</a>
</p>
</div>
<div class="footer">
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
</div>
</body>
</html>`,
"generic_notification": `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
.header { background: linear-gradient(135deg, #6366f1, #8b5cf6); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
.button { display: inline-block; background: #6366f1; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
</style>
</head>
<body>
<div class="header">
<h1>{{.Title}}</h1>
</div>
<div class="content">
<p>{{.Body}}</p>
<p style="text-align: center;">
<a href="{{.BaseURL}}/app" class="button">Zu BreakPilot</a>
</p>
</div>
<div class="footer">
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
</div>
</body>
</html>`,
}
tmplStr, ok := templates[templateName]
if !ok {
return ""
}
tmpl, err := template.New(templateName).Parse(tmplStr)
if err != nil {
return ""
}
var buf bytes.Buffer
if err := tmpl.Execute(&buf, data); err != nil {
return ""
}
return buf.String()
}
// SendConsentReminderEmail sends a simplified consent reminder email
func (s *EmailService) SendConsentReminderEmail(to, title, body string) error {
subject := title
htmlBody := s.renderTemplate("generic_notification", map[string]interface{}{
"Title": title,
"Body": body,
"BaseURL": s.config.BaseURL,
})
return s.SendEmail(to, subject, htmlBody, body)
}
// SendGenericNotificationEmail sends a generic notification email
func (s *EmailService) SendGenericNotificationEmail(to, title, body string) error {
subject := title
htmlBody := s.renderTemplate("generic_notification", map[string]interface{}{
"Title": title,
"Body": body,
"BaseURL": s.config.BaseURL,
})
return s.SendEmail(to, subject, htmlBody, body)
}
// getDisplayName returns display name or fallback
func getDisplayName(name string) string {
if name != "" {
return name
}
return "Nutzer"
}

Some files were not shown because too many files have changed in this diff Show More