Initial commit: breakpilot-core - Shared Infrastructure
Docker Compose with 24+ services: - PostgreSQL (PostGIS), Valkey, MinIO, Qdrant - Vault (PKI/TLS), Nginx (Reverse Proxy) - Backend Core API, Consent Service, Billing Service - RAG Service, Embedding Service - Gitea, Woodpecker CI/CD - Night Scheduler, Health Aggregator - Jitsi (Web/XMPP/JVB/Jicofo), Mailpit Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,65 @@
|
||||
# BreakPilot Core — Shared Infrastructure
|
||||
|
||||
## Entwicklungsumgebung
|
||||
|
||||
### Zwei-Rechner-Setup
|
||||
| Gerät | Rolle |
|
||||
|-------|-------|
|
||||
| **MacBook** | Client/Terminal |
|
||||
| **Mac Mini** | Server/Docker/Git |
|
||||
|
||||
```bash
|
||||
ssh macmini "cd /Users/benjaminadmin/Projekte/breakpilot-core && <cmd>"
|
||||
```
|
||||
|
||||
## Projektübersicht
|
||||
|
||||
**breakpilot-core** ist das Infrastruktur-Projekt der BreakPilot-Plattform. Es stellt alle gemeinsamen Services bereit, die von **breakpilot-lehrer** und **breakpilot-compliance** genutzt werden.
|
||||
|
||||
### Enthaltene Services (~28 Container)
|
||||
|
||||
| Service | Port | Beschreibung |
|
||||
|---------|------|--------------|
|
||||
| nginx | 80/443 | Reverse Proxy (SSL) |
|
||||
| postgres | 5432 | PostGIS 16 (3 Schemas: core, lehrer, compliance) |
|
||||
| valkey | 6379 | Session-Cache |
|
||||
| vault | 8200 | Secrets Management |
|
||||
| qdrant | 6333 | Vektordatenbank |
|
||||
| minio | 9000 | S3 Storage |
|
||||
| backend-core | 8000 | Shared APIs (Auth, RBAC, Notifications) |
|
||||
| rag-service | 8097 | RAG: Dokumente, Suche, Embeddings |
|
||||
| embedding-service | 8087 | Text-Embeddings |
|
||||
| consent-service | 8081 | Consent-Management |
|
||||
| health-aggregator | 8099 | Health-Check aller Services |
|
||||
| gitea | 3003 | Git-Server |
|
||||
| woodpecker | 8090 | CI/CD |
|
||||
| camunda | 8089 | BPMN |
|
||||
| synapse | 8008 | Matrix Chat |
|
||||
| jitsi | 8443 | Video |
|
||||
| mailpit | 8025 | E-Mail (Dev) |
|
||||
|
||||
### Docker-Netzwerk
|
||||
Alle 3 Projekte teilen sich das `breakpilot-network`:
|
||||
```yaml
|
||||
networks:
|
||||
breakpilot-network:
|
||||
driver: bridge
|
||||
name: breakpilot-network
|
||||
```
|
||||
|
||||
### Start-Reihenfolge
|
||||
```bash
|
||||
# 1. Core MUSS zuerst starten
|
||||
docker compose up -d
|
||||
# 2. Dann Lehrer und Compliance (warten auf Core Health)
|
||||
```
|
||||
|
||||
### DB-Schemas
|
||||
- `core` — users, sessions, auth, rbac, notifications
|
||||
- `lehrer` — classroom, units, klausuren, game
|
||||
- `compliance` — compliance, dsr, gdpr, sdk
|
||||
|
||||
## Git Remotes
|
||||
Immer zu BEIDEN pushen:
|
||||
- `origin`: lokale Gitea (macmini:3003)
|
||||
- `gitea`: gitea.meghsakha.com
|
||||
@@ -0,0 +1,58 @@
|
||||
# =========================================================
|
||||
# BreakPilot Core — Environment Variables
|
||||
# =========================================================
|
||||
# Copy to .env and adjust values
|
||||
|
||||
# Database
|
||||
POSTGRES_USER=breakpilot
|
||||
POSTGRES_PASSWORD=breakpilot123
|
||||
POSTGRES_DB=breakpilot_db
|
||||
|
||||
# Security
|
||||
JWT_SECRET=your-super-secret-jwt-key-change-in-production
|
||||
VAULT_TOKEN=breakpilot-dev-token
|
||||
|
||||
# MinIO (S3-compatible storage)
|
||||
MINIO_ROOT_USER=breakpilot
|
||||
MINIO_ROOT_PASSWORD=breakpilot123
|
||||
MINIO_BUCKET=breakpilot-rag
|
||||
|
||||
# Environment
|
||||
ENVIRONMENT=development
|
||||
TZ=Europe/Berlin
|
||||
|
||||
# Embedding Service
|
||||
EMBEDDING_BACKEND=local
|
||||
LOCAL_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
|
||||
LOCAL_RERANKER_MODEL=cross-encoder/ms-marco-MiniLM-L-6-v2
|
||||
|
||||
# SMTP (Mailpit in dev)
|
||||
SMTP_HOST=mailpit
|
||||
SMTP_PORT=1025
|
||||
|
||||
# Synapse
|
||||
SYNAPSE_SERVER_NAME=macmini
|
||||
SYNAPSE_DB_PASSWORD=synapse_secret
|
||||
|
||||
# Jitsi
|
||||
JICOFO_AUTH_PASSWORD=jicofo_secret
|
||||
JVB_AUTH_PASSWORD=jvb_secret
|
||||
JIBRI_XMPP_PASSWORD=jibri_secret
|
||||
JIBRI_RECORDER_PASSWORD=recorder_secret
|
||||
JITSI_PUBLIC_URL=https://macmini:8443
|
||||
|
||||
# ERPNext
|
||||
ERPNEXT_DB_ROOT_PASSWORD=erpnext_root
|
||||
ERPNEXT_DB_PASSWORD=erpnext_secret
|
||||
ERPNEXT_ADMIN_PASSWORD=admin
|
||||
|
||||
# Woodpecker CI
|
||||
WOODPECKER_HOST=http://macmini:8090
|
||||
WOODPECKER_ADMIN=pilotadmin
|
||||
WOODPECKER_AGENT_SECRET=woodpecker-secret
|
||||
|
||||
# Gitea Runner
|
||||
GITEA_RUNNER_TOKEN=
|
||||
|
||||
# Session
|
||||
SESSION_TTL_HOURS=24
|
||||
+63
@@ -0,0 +1,63 @@
|
||||
# Environment
|
||||
.env
|
||||
.env.local
|
||||
.env.backup
|
||||
|
||||
# Secrets
|
||||
secrets/
|
||||
*.pem
|
||||
*.key
|
||||
|
||||
# Node
|
||||
node_modules/
|
||||
.next/
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.pyc
|
||||
venv/
|
||||
.venv/
|
||||
|
||||
# Docker
|
||||
backups/*.backup
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
.DS_Store
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
|
||||
# Large files
|
||||
*.pdf
|
||||
*.docx
|
||||
*.xlsx
|
||||
*.pptx
|
||||
*.mp4
|
||||
*.mp3
|
||||
*.wav
|
||||
|
||||
# Compiled binaries
|
||||
billing-service/billing-service
|
||||
consent-service/server
|
||||
*.exe
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
|
||||
# Large files
|
||||
*.zip
|
||||
*.gz
|
||||
*.tar
|
||||
*.sql.gz
|
||||
*.pdf
|
||||
*.docx
|
||||
*.xlsx
|
||||
*.pptx
|
||||
|
||||
# Coverage
|
||||
coverage/
|
||||
*.coverage
|
||||
@@ -0,0 +1,15 @@
|
||||
__pycache__
|
||||
*.pyc
|
||||
*.pyo
|
||||
.git
|
||||
.env
|
||||
.env.*
|
||||
.pytest_cache
|
||||
venv
|
||||
.venv
|
||||
*.egg-info
|
||||
.DS_Store
|
||||
security-reports
|
||||
scripts
|
||||
tests
|
||||
docs
|
||||
@@ -0,0 +1,64 @@
|
||||
# ============================================================
|
||||
# BreakPilot Core Backend -- Multi-stage Docker build
|
||||
# ============================================================
|
||||
|
||||
# ---------- Build stage ----------
|
||||
FROM python:3.12-slim-bookworm AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Build-time system libs (needed for asyncpg / psycopg2)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
libpq-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
|
||||
RUN python -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
RUN pip install --no-cache-dir --upgrade pip && \
|
||||
pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# ---------- Runtime stage ----------
|
||||
FROM python:3.12-slim-bookworm
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Runtime system libs
|
||||
# - libpango / libgdk-pixbuf / shared-mime-info -> WeasyPrint (pdf_service)
|
||||
# - libgl1 / libglib2.0-0 -> OpenCV (file_processor)
|
||||
# - curl -> healthcheck
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libpango-1.0-0 \
|
||||
libpangocairo-1.0-0 \
|
||||
libgdk-pixbuf-2.0-0 \
|
||||
libffi-dev \
|
||||
shared-mime-info \
|
||||
libgl1 \
|
||||
libglib2.0-0 \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy virtualenv from builder
|
||||
COPY --from=builder /opt/venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# Non-root user
|
||||
RUN useradd --create-home --shell /bin/bash appuser
|
||||
|
||||
# Copy application code
|
||||
COPY --chown=appuser:appuser . .
|
||||
|
||||
USER appuser
|
||||
|
||||
# Python tweaks
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
||||
CMD curl -f http://127.0.0.1:8000/health || exit 1
|
||||
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
BreakPilot Authentication Module
|
||||
|
||||
Hybrid authentication supporting both Keycloak and local JWT tokens.
|
||||
"""
|
||||
|
||||
from .keycloak_auth import (
|
||||
# Config
|
||||
KeycloakConfig,
|
||||
KeycloakUser,
|
||||
|
||||
# Authenticators
|
||||
KeycloakAuthenticator,
|
||||
HybridAuthenticator,
|
||||
|
||||
# Exceptions
|
||||
KeycloakAuthError,
|
||||
TokenExpiredError,
|
||||
TokenInvalidError,
|
||||
KeycloakConfigError,
|
||||
|
||||
# Factory functions
|
||||
get_keycloak_config_from_env,
|
||||
get_authenticator,
|
||||
get_auth,
|
||||
|
||||
# FastAPI dependencies
|
||||
get_current_user,
|
||||
require_role,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Config
|
||||
"KeycloakConfig",
|
||||
"KeycloakUser",
|
||||
|
||||
# Authenticators
|
||||
"KeycloakAuthenticator",
|
||||
"HybridAuthenticator",
|
||||
|
||||
# Exceptions
|
||||
"KeycloakAuthError",
|
||||
"TokenExpiredError",
|
||||
"TokenInvalidError",
|
||||
"KeycloakConfigError",
|
||||
|
||||
# Factory functions
|
||||
"get_keycloak_config_from_env",
|
||||
"get_authenticator",
|
||||
"get_auth",
|
||||
|
||||
# FastAPI dependencies
|
||||
"get_current_user",
|
||||
"require_role",
|
||||
]
|
||||
@@ -0,0 +1,515 @@
|
||||
"""
|
||||
Keycloak Authentication Module
|
||||
|
||||
Implements token validation against Keycloak JWKS endpoint.
|
||||
This module handles authentication (who is the user?), while
|
||||
rbac.py handles authorization (what can the user do?).
|
||||
|
||||
Architecture:
|
||||
- Keycloak validates JWT tokens and provides basic identity
|
||||
- Our custom rbac.py handles fine-grained permissions
|
||||
"""
|
||||
|
||||
import os
|
||||
import httpx
|
||||
import jwt
|
||||
from jwt import PyJWKClient
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Dict, Any, List
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class KeycloakConfig:
|
||||
"""Keycloak connection configuration."""
|
||||
server_url: str
|
||||
realm: str
|
||||
client_id: str
|
||||
client_secret: Optional[str] = None
|
||||
verify_ssl: bool = True
|
||||
|
||||
@property
|
||||
def issuer_url(self) -> str:
|
||||
return f"{self.server_url}/realms/{self.realm}"
|
||||
|
||||
@property
|
||||
def jwks_url(self) -> str:
|
||||
return f"{self.issuer_url}/protocol/openid-connect/certs"
|
||||
|
||||
@property
|
||||
def token_url(self) -> str:
|
||||
return f"{self.issuer_url}/protocol/openid-connect/token"
|
||||
|
||||
@property
|
||||
def userinfo_url(self) -> str:
|
||||
return f"{self.issuer_url}/protocol/openid-connect/userinfo"
|
||||
|
||||
|
||||
@dataclass
|
||||
class KeycloakUser:
|
||||
"""User information extracted from Keycloak token."""
|
||||
user_id: str # Keycloak subject (sub)
|
||||
email: str
|
||||
email_verified: bool
|
||||
name: Optional[str]
|
||||
given_name: Optional[str]
|
||||
family_name: Optional[str]
|
||||
realm_roles: List[str] # Keycloak realm roles
|
||||
client_roles: Dict[str, List[str]] # Client-specific roles
|
||||
groups: List[str] # Keycloak groups
|
||||
tenant_id: Optional[str] # Custom claim for school/tenant
|
||||
raw_claims: Dict[str, Any] # All claims for debugging
|
||||
|
||||
def has_realm_role(self, role: str) -> bool:
|
||||
"""Check if user has a specific realm role."""
|
||||
return role in self.realm_roles
|
||||
|
||||
def has_client_role(self, client_id: str, role: str) -> bool:
|
||||
"""Check if user has a specific client role."""
|
||||
client_roles = self.client_roles.get(client_id, [])
|
||||
return role in client_roles
|
||||
|
||||
def is_admin(self) -> bool:
|
||||
"""Check if user has admin role."""
|
||||
return self.has_realm_role("admin") or self.has_realm_role("schul_admin")
|
||||
|
||||
def is_teacher(self) -> bool:
|
||||
"""Check if user is a teacher."""
|
||||
return self.has_realm_role("teacher") or self.has_realm_role("lehrer")
|
||||
|
||||
|
||||
class KeycloakAuthError(Exception):
|
||||
"""Base exception for Keycloak authentication errors."""
|
||||
pass
|
||||
|
||||
|
||||
class TokenExpiredError(KeycloakAuthError):
|
||||
"""Token has expired."""
|
||||
pass
|
||||
|
||||
|
||||
class TokenInvalidError(KeycloakAuthError):
|
||||
"""Token is invalid."""
|
||||
pass
|
||||
|
||||
|
||||
class KeycloakConfigError(KeycloakAuthError):
|
||||
"""Keycloak configuration error."""
|
||||
pass
|
||||
|
||||
|
||||
class KeycloakAuthenticator:
|
||||
"""
|
||||
Validates JWT tokens against Keycloak.
|
||||
|
||||
Usage:
|
||||
config = KeycloakConfig(
|
||||
server_url="https://keycloak.example.com",
|
||||
realm="breakpilot",
|
||||
client_id="breakpilot-backend"
|
||||
)
|
||||
auth = KeycloakAuthenticator(config)
|
||||
|
||||
user = await auth.validate_token(token)
|
||||
if user.is_teacher():
|
||||
# Grant access
|
||||
"""
|
||||
|
||||
def __init__(self, config: KeycloakConfig):
|
||||
self.config = config
|
||||
self._jwks_client: Optional[PyJWKClient] = None
|
||||
self._http_client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
@property
|
||||
def jwks_client(self) -> PyJWKClient:
|
||||
"""Lazy-load JWKS client."""
|
||||
if self._jwks_client is None:
|
||||
self._jwks_client = PyJWKClient(
|
||||
self.config.jwks_url,
|
||||
cache_keys=True,
|
||||
lifespan=3600 # Cache keys for 1 hour
|
||||
)
|
||||
return self._jwks_client
|
||||
|
||||
async def get_http_client(self) -> httpx.AsyncClient:
|
||||
"""Get or create async HTTP client."""
|
||||
if self._http_client is None or self._http_client.is_closed:
|
||||
self._http_client = httpx.AsyncClient(
|
||||
verify=self.config.verify_ssl,
|
||||
timeout=30.0
|
||||
)
|
||||
return self._http_client
|
||||
|
||||
async def close(self):
|
||||
"""Close HTTP client."""
|
||||
if self._http_client and not self._http_client.is_closed:
|
||||
await self._http_client.aclose()
|
||||
|
||||
def validate_token_sync(self, token: str) -> KeycloakUser:
|
||||
"""
|
||||
Synchronously validate a JWT token against Keycloak JWKS.
|
||||
|
||||
Args:
|
||||
token: The JWT access token
|
||||
|
||||
Returns:
|
||||
KeycloakUser with extracted claims
|
||||
|
||||
Raises:
|
||||
TokenExpiredError: If token has expired
|
||||
TokenInvalidError: If token signature is invalid
|
||||
"""
|
||||
try:
|
||||
# Get signing key from JWKS
|
||||
signing_key = self.jwks_client.get_signing_key_from_jwt(token)
|
||||
|
||||
# Decode and validate token
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
signing_key.key,
|
||||
algorithms=["RS256"],
|
||||
audience=self.config.client_id,
|
||||
issuer=self.config.issuer_url,
|
||||
options={
|
||||
"verify_exp": True,
|
||||
"verify_iat": True,
|
||||
"verify_aud": True,
|
||||
"verify_iss": True
|
||||
}
|
||||
)
|
||||
|
||||
return self._extract_user(payload)
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise TokenExpiredError("Token has expired")
|
||||
except jwt.InvalidAudienceError:
|
||||
raise TokenInvalidError("Invalid token audience")
|
||||
except jwt.InvalidIssuerError:
|
||||
raise TokenInvalidError("Invalid token issuer")
|
||||
except jwt.InvalidTokenError as e:
|
||||
raise TokenInvalidError(f"Invalid token: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Token validation failed: {e}")
|
||||
raise TokenInvalidError(f"Token validation failed: {e}")
|
||||
|
||||
async def validate_token(self, token: str) -> KeycloakUser:
|
||||
"""
|
||||
Asynchronously validate a JWT token.
|
||||
|
||||
Note: JWKS fetching is synchronous due to PyJWKClient limitations,
|
||||
but this wrapper allows async context usage.
|
||||
"""
|
||||
return self.validate_token_sync(token)
|
||||
|
||||
async def get_userinfo(self, token: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Fetch user info from Keycloak userinfo endpoint.
|
||||
|
||||
This provides additional user claims not in the access token.
|
||||
"""
|
||||
client = await self.get_http_client()
|
||||
|
||||
try:
|
||||
response = await client.get(
|
||||
self.config.userinfo_url,
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise TokenExpiredError("Token is invalid or expired")
|
||||
raise TokenInvalidError(f"Failed to fetch userinfo: {e}")
|
||||
|
||||
def _extract_user(self, payload: Dict[str, Any]) -> KeycloakUser:
|
||||
"""Extract KeycloakUser from JWT payload."""
|
||||
|
||||
# Extract realm roles
|
||||
realm_access = payload.get("realm_access", {})
|
||||
realm_roles = realm_access.get("roles", [])
|
||||
|
||||
# Extract client roles
|
||||
resource_access = payload.get("resource_access", {})
|
||||
client_roles = {}
|
||||
for client_id, access in resource_access.items():
|
||||
client_roles[client_id] = access.get("roles", [])
|
||||
|
||||
# Extract groups
|
||||
groups = payload.get("groups", [])
|
||||
|
||||
# Extract custom tenant claim (if configured in Keycloak)
|
||||
tenant_id = payload.get("tenant_id") or payload.get("school_id")
|
||||
|
||||
return KeycloakUser(
|
||||
user_id=payload.get("sub", ""),
|
||||
email=payload.get("email", ""),
|
||||
email_verified=payload.get("email_verified", False),
|
||||
name=payload.get("name"),
|
||||
given_name=payload.get("given_name"),
|
||||
family_name=payload.get("family_name"),
|
||||
realm_roles=realm_roles,
|
||||
client_roles=client_roles,
|
||||
groups=groups,
|
||||
tenant_id=tenant_id,
|
||||
raw_claims=payload
|
||||
)
|
||||
|
||||
|
||||
# =============================================
|
||||
# HYBRID AUTH: Keycloak + Local JWT
|
||||
# =============================================
|
||||
|
||||
class HybridAuthenticator:
|
||||
"""
|
||||
Hybrid authenticator supporting both Keycloak and local JWT tokens.
|
||||
|
||||
This allows gradual migration from local JWT to Keycloak:
|
||||
1. Development: Use local JWT (fast, no external dependencies)
|
||||
2. Production: Use Keycloak for full IAM capabilities
|
||||
|
||||
Token type detection:
|
||||
- Keycloak tokens: Have 'iss' claim matching Keycloak URL
|
||||
- Local tokens: Have 'iss' claim as 'breakpilot' or no 'iss'
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
keycloak_config: Optional[KeycloakConfig] = None,
|
||||
local_jwt_secret: Optional[str] = None,
|
||||
environment: str = "development"
|
||||
):
|
||||
self.environment = environment
|
||||
self.keycloak_enabled = keycloak_config is not None
|
||||
self.local_jwt_secret = local_jwt_secret
|
||||
|
||||
if keycloak_config:
|
||||
self.keycloak_auth = KeycloakAuthenticator(keycloak_config)
|
||||
else:
|
||||
self.keycloak_auth = None
|
||||
|
||||
async def validate_token(self, token: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate token using appropriate method.
|
||||
|
||||
Returns a unified user dict compatible with existing code.
|
||||
"""
|
||||
if not token:
|
||||
raise TokenInvalidError("No token provided")
|
||||
|
||||
# Try to peek at the token to determine type
|
||||
try:
|
||||
# Decode without verification to check issuer
|
||||
unverified = jwt.decode(token, options={"verify_signature": False})
|
||||
issuer = unverified.get("iss", "")
|
||||
except jwt.InvalidTokenError:
|
||||
raise TokenInvalidError("Cannot decode token")
|
||||
|
||||
# Check if it's a Keycloak token
|
||||
if self.keycloak_auth and self.keycloak_auth.config.issuer_url in issuer:
|
||||
# Validate with Keycloak
|
||||
kc_user = await self.keycloak_auth.validate_token(token)
|
||||
return self._keycloak_user_to_dict(kc_user)
|
||||
|
||||
# Fall back to local JWT validation
|
||||
if self.local_jwt_secret:
|
||||
return self._validate_local_token(token)
|
||||
|
||||
raise TokenInvalidError("No valid authentication method available")
|
||||
|
||||
def _validate_local_token(self, token: str) -> Dict[str, Any]:
|
||||
"""Validate token with local JWT secret."""
|
||||
if not self.local_jwt_secret:
|
||||
raise KeycloakConfigError("Local JWT secret not configured")
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
self.local_jwt_secret,
|
||||
algorithms=["HS256"]
|
||||
)
|
||||
|
||||
# Map local token claims to unified format
|
||||
return {
|
||||
"user_id": payload.get("user_id", payload.get("sub", "")),
|
||||
"email": payload.get("email", ""),
|
||||
"name": payload.get("name", ""),
|
||||
"role": payload.get("role", "user"),
|
||||
"realm_roles": [payload.get("role", "user")],
|
||||
"tenant_id": payload.get("tenant_id", payload.get("school_id")),
|
||||
"auth_method": "local_jwt"
|
||||
}
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise TokenExpiredError("Token has expired")
|
||||
except jwt.InvalidTokenError as e:
|
||||
raise TokenInvalidError(f"Invalid local token: {e}")
|
||||
|
||||
def _keycloak_user_to_dict(self, user: KeycloakUser) -> Dict[str, Any]:
|
||||
"""Convert KeycloakUser to dict compatible with existing code."""
|
||||
# Map Keycloak roles to our role system
|
||||
role = "user"
|
||||
if user.is_admin():
|
||||
role = "admin"
|
||||
elif user.is_teacher():
|
||||
role = "teacher"
|
||||
|
||||
return {
|
||||
"user_id": user.user_id,
|
||||
"email": user.email,
|
||||
"name": user.name or f"{user.given_name or ''} {user.family_name or ''}".strip(),
|
||||
"role": role,
|
||||
"realm_roles": user.realm_roles,
|
||||
"client_roles": user.client_roles,
|
||||
"groups": user.groups,
|
||||
"tenant_id": user.tenant_id,
|
||||
"email_verified": user.email_verified,
|
||||
"auth_method": "keycloak"
|
||||
}
|
||||
|
||||
async def close(self):
|
||||
"""Cleanup resources."""
|
||||
if self.keycloak_auth:
|
||||
await self.keycloak_auth.close()
|
||||
|
||||
|
||||
# =============================================
|
||||
# FACTORY FUNCTIONS
|
||||
# =============================================
|
||||
|
||||
def get_keycloak_config_from_env() -> Optional[KeycloakConfig]:
|
||||
"""
|
||||
Create KeycloakConfig from environment variables.
|
||||
|
||||
Required env vars:
|
||||
- KEYCLOAK_SERVER_URL: e.g., https://keycloak.breakpilot.app
|
||||
- KEYCLOAK_REALM: e.g., breakpilot
|
||||
- KEYCLOAK_CLIENT_ID: e.g., breakpilot-backend
|
||||
|
||||
Optional:
|
||||
- KEYCLOAK_CLIENT_SECRET: For confidential clients
|
||||
- KEYCLOAK_VERIFY_SSL: Default true
|
||||
"""
|
||||
server_url = os.environ.get("KEYCLOAK_SERVER_URL")
|
||||
realm = os.environ.get("KEYCLOAK_REALM")
|
||||
client_id = os.environ.get("KEYCLOAK_CLIENT_ID")
|
||||
|
||||
if not all([server_url, realm, client_id]):
|
||||
logger.info("Keycloak not configured, using local JWT only")
|
||||
return None
|
||||
|
||||
return KeycloakConfig(
|
||||
server_url=server_url,
|
||||
realm=realm,
|
||||
client_id=client_id,
|
||||
client_secret=os.environ.get("KEYCLOAK_CLIENT_SECRET"),
|
||||
verify_ssl=os.environ.get("KEYCLOAK_VERIFY_SSL", "true").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
def get_authenticator() -> HybridAuthenticator:
|
||||
"""
|
||||
Get configured authenticator instance.
|
||||
|
||||
Uses environment variables to determine configuration.
|
||||
"""
|
||||
keycloak_config = get_keycloak_config_from_env()
|
||||
|
||||
# JWT_SECRET is required - no default fallback in production
|
||||
jwt_secret = os.environ.get("JWT_SECRET")
|
||||
environment = os.environ.get("ENVIRONMENT", "development")
|
||||
|
||||
if not jwt_secret and environment == "production":
|
||||
raise KeycloakConfigError(
|
||||
"JWT_SECRET environment variable is required in production"
|
||||
)
|
||||
|
||||
return HybridAuthenticator(
|
||||
keycloak_config=keycloak_config,
|
||||
local_jwt_secret=jwt_secret,
|
||||
environment=environment
|
||||
)
|
||||
|
||||
|
||||
# =============================================
|
||||
# FASTAPI DEPENDENCY
|
||||
# =============================================
|
||||
|
||||
from fastapi import Request, HTTPException, Depends
|
||||
|
||||
# Global authenticator instance (lazy-initialized)
|
||||
_authenticator: Optional[HybridAuthenticator] = None
|
||||
|
||||
|
||||
def get_auth() -> HybridAuthenticator:
|
||||
"""Get or create global authenticator."""
|
||||
global _authenticator
|
||||
if _authenticator is None:
|
||||
_authenticator = get_authenticator()
|
||||
return _authenticator
|
||||
|
||||
|
||||
async def get_current_user(request: Request) -> Dict[str, Any]:
|
||||
"""
|
||||
FastAPI dependency to get current authenticated user.
|
||||
|
||||
Usage:
|
||||
@app.get("/api/protected")
|
||||
async def protected_endpoint(user: dict = Depends(get_current_user)):
|
||||
return {"user_id": user["user_id"]}
|
||||
"""
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
|
||||
if not auth_header.startswith("Bearer "):
|
||||
# Check for development mode
|
||||
environment = os.environ.get("ENVIRONMENT", "development")
|
||||
if environment == "development":
|
||||
# Return demo user in development without token
|
||||
return {
|
||||
"user_id": "10000000-0000-0000-0000-000000000024",
|
||||
"email": "demo@breakpilot.app",
|
||||
"role": "admin",
|
||||
"realm_roles": ["admin"],
|
||||
"tenant_id": "a0000000-0000-0000-0000-000000000001",
|
||||
"auth_method": "development_bypass"
|
||||
}
|
||||
raise HTTPException(status_code=401, detail="Missing authorization header")
|
||||
|
||||
token = auth_header.split(" ")[1]
|
||||
|
||||
try:
|
||||
auth = get_auth()
|
||||
return await auth.validate_token(token)
|
||||
except TokenExpiredError:
|
||||
raise HTTPException(status_code=401, detail="Token expired")
|
||||
except TokenInvalidError as e:
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication failed: {e}")
|
||||
raise HTTPException(status_code=401, detail="Authentication failed")
|
||||
|
||||
|
||||
async def require_role(required_role: str):
|
||||
"""
|
||||
FastAPI dependency factory for role-based access.
|
||||
|
||||
Usage:
|
||||
@app.get("/api/admin-only")
|
||||
async def admin_endpoint(user: dict = Depends(require_role("admin"))):
|
||||
return {"message": "Admin access granted"}
|
||||
"""
|
||||
async def role_checker(user: dict = Depends(get_current_user)) -> dict:
|
||||
user_role = user.get("role", "user")
|
||||
realm_roles = user.get("realm_roles", [])
|
||||
|
||||
if user_role == required_role or required_role in realm_roles:
|
||||
return user
|
||||
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Role '{required_role}' required"
|
||||
)
|
||||
|
||||
return role_checker
|
||||
@@ -0,0 +1,373 @@
|
||||
"""
|
||||
Authentication API Endpoints für BreakPilot
|
||||
Proxy für den Go Consent Service Authentication
|
||||
"""
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException, Header, Request, Response
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, EmailStr
|
||||
import os
|
||||
|
||||
# Consent Service URL
|
||||
CONSENT_SERVICE_URL = os.getenv("CONSENT_SERVICE_URL", "http://localhost:8081")
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["authentication"])
|
||||
|
||||
|
||||
# ==========================================
|
||||
# Request/Response Models
|
||||
# ==========================================
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class VerifyEmailRequest(BaseModel):
|
||||
token: str
|
||||
|
||||
|
||||
class ForgotPasswordRequest(BaseModel):
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class ResetPasswordRequest(BaseModel):
|
||||
token: str
|
||||
new_password: str
|
||||
|
||||
|
||||
class ChangePasswordRequest(BaseModel):
|
||||
current_password: str
|
||||
new_password: str
|
||||
|
||||
|
||||
class UpdateProfileRequest(BaseModel):
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class LogoutRequest(BaseModel):
|
||||
refresh_token: Optional[str] = None
|
||||
|
||||
|
||||
# ==========================================
|
||||
# Helper Functions
|
||||
# ==========================================
|
||||
|
||||
def get_auth_headers(authorization: Optional[str]) -> dict:
|
||||
"""Erstellt Header mit Authorization Token"""
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if authorization:
|
||||
headers["Authorization"] = authorization
|
||||
return headers
|
||||
|
||||
|
||||
async def proxy_to_consent_service(
|
||||
method: str,
|
||||
path: str,
|
||||
json_data: dict = None,
|
||||
headers: dict = None,
|
||||
params: dict = None
|
||||
) -> dict:
|
||||
"""
|
||||
Proxy-Aufruf zum Go Consent Service.
|
||||
Wirft HTTPException bei Fehlern.
|
||||
"""
|
||||
url = f"{CONSENT_SERVICE_URL}/api/v1{path}"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
if method == "GET":
|
||||
response = await client.get(url, headers=headers, params=params, timeout=10.0)
|
||||
elif method == "POST":
|
||||
response = await client.post(url, headers=headers, json=json_data, timeout=10.0)
|
||||
elif method == "PUT":
|
||||
response = await client.put(url, headers=headers, json=json_data, timeout=10.0)
|
||||
elif method == "DELETE":
|
||||
response = await client.delete(url, headers=headers, params=params, timeout=10.0)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
# Parse JSON response
|
||||
try:
|
||||
data = response.json()
|
||||
except:
|
||||
data = {"message": response.text}
|
||||
|
||||
# Handle error responses
|
||||
if response.status_code >= 400:
|
||||
error_msg = data.get("error", "Unknown error")
|
||||
raise HTTPException(status_code=response.status_code, detail=error_msg)
|
||||
|
||||
return data
|
||||
|
||||
except httpx.RequestError as e:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=f"Consent Service nicht erreichbar: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# ==========================================
|
||||
# Public Auth Endpoints (No Auth Required)
|
||||
# ==========================================
|
||||
|
||||
@router.post("/register")
|
||||
async def register(request: RegisterRequest, req: Request):
|
||||
"""
|
||||
Registriert einen neuen Benutzer.
|
||||
Sendet eine Verifizierungs-E-Mail.
|
||||
"""
|
||||
data = await proxy_to_consent_service(
|
||||
"POST",
|
||||
"/auth/register",
|
||||
json_data={
|
||||
"email": request.email,
|
||||
"password": request.password,
|
||||
"name": request.name
|
||||
}
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
async def login(request: LoginRequest, req: Request):
|
||||
"""
|
||||
Meldet einen Benutzer an.
|
||||
Gibt Access Token und Refresh Token zurück.
|
||||
"""
|
||||
# Get client info for session tracking
|
||||
client_ip = req.client.host if req.client else "unknown"
|
||||
user_agent = req.headers.get("user-agent", "unknown")
|
||||
|
||||
data = await proxy_to_consent_service(
|
||||
"POST",
|
||||
"/auth/login",
|
||||
json_data={
|
||||
"email": request.email,
|
||||
"password": request.password
|
||||
},
|
||||
headers={
|
||||
"X-Forwarded-For": client_ip,
|
||||
"User-Agent": user_agent
|
||||
}
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(request: LogoutRequest):
|
||||
"""
|
||||
Meldet den Benutzer ab und invalidiert den Refresh Token.
|
||||
"""
|
||||
data = await proxy_to_consent_service(
|
||||
"POST",
|
||||
"/auth/logout",
|
||||
json_data={"refresh_token": request.refresh_token} if request.refresh_token else {}
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh_token(request: RefreshTokenRequest):
|
||||
"""
|
||||
Erneuert den Access Token mit einem gültigen Refresh Token.
|
||||
"""
|
||||
data = await proxy_to_consent_service(
|
||||
"POST",
|
||||
"/auth/refresh",
|
||||
json_data={"refresh_token": request.refresh_token}
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@router.post("/verify-email")
|
||||
async def verify_email(request: VerifyEmailRequest):
|
||||
"""
|
||||
Verifiziert die E-Mail-Adresse mit dem Token aus der E-Mail.
|
||||
"""
|
||||
data = await proxy_to_consent_service(
|
||||
"POST",
|
||||
"/auth/verify-email",
|
||||
json_data={"token": request.token}
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@router.post("/resend-verification")
|
||||
async def resend_verification(email: EmailStr):
|
||||
"""
|
||||
Sendet die Verifizierungs-E-Mail erneut.
|
||||
"""
|
||||
data = await proxy_to_consent_service(
|
||||
"POST",
|
||||
"/auth/resend-verification",
|
||||
json_data={"email": email}
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@router.post("/forgot-password")
|
||||
async def forgot_password(request: ForgotPasswordRequest, req: Request):
|
||||
"""
|
||||
Initiiert den Passwort-Reset-Prozess.
|
||||
Sendet eine E-Mail mit Reset-Link.
|
||||
"""
|
||||
client_ip = req.client.host if req.client else "unknown"
|
||||
|
||||
data = await proxy_to_consent_service(
|
||||
"POST",
|
||||
"/auth/forgot-password",
|
||||
json_data={"email": request.email},
|
||||
headers={"X-Forwarded-For": client_ip}
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@router.post("/reset-password")
|
||||
async def reset_password(request: ResetPasswordRequest):
|
||||
"""
|
||||
Setzt das Passwort mit dem Token aus der E-Mail zurück.
|
||||
"""
|
||||
data = await proxy_to_consent_service(
|
||||
"POST",
|
||||
"/auth/reset-password",
|
||||
json_data={
|
||||
"token": request.token,
|
||||
"new_password": request.new_password
|
||||
}
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
# ==========================================
|
||||
# Protected Profile Endpoints (Auth Required)
|
||||
# ==========================================
|
||||
|
||||
@router.get("/profile")
|
||||
async def get_profile(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
Gibt das Profil des angemeldeten Benutzers zurück.
|
||||
"""
|
||||
if not authorization:
|
||||
raise HTTPException(status_code=401, detail="Authorization header required")
|
||||
|
||||
data = await proxy_to_consent_service(
|
||||
"GET",
|
||||
"/profile",
|
||||
headers=get_auth_headers(authorization)
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@router.put("/profile")
|
||||
async def update_profile(
|
||||
request: UpdateProfileRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
Aktualisiert das Profil des angemeldeten Benutzers.
|
||||
"""
|
||||
if not authorization:
|
||||
raise HTTPException(status_code=401, detail="Authorization header required")
|
||||
|
||||
data = await proxy_to_consent_service(
|
||||
"PUT",
|
||||
"/profile",
|
||||
json_data={"name": request.name},
|
||||
headers=get_auth_headers(authorization)
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@router.put("/profile/password")
|
||||
async def change_password(
|
||||
request: ChangePasswordRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
Ändert das Passwort des angemeldeten Benutzers.
|
||||
"""
|
||||
if not authorization:
|
||||
raise HTTPException(status_code=401, detail="Authorization header required")
|
||||
|
||||
data = await proxy_to_consent_service(
|
||||
"PUT",
|
||||
"/profile/password",
|
||||
json_data={
|
||||
"current_password": request.current_password,
|
||||
"new_password": request.new_password
|
||||
},
|
||||
headers=get_auth_headers(authorization)
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@router.get("/profile/sessions")
|
||||
async def get_sessions(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
Gibt alle aktiven Sessions des Benutzers zurück.
|
||||
"""
|
||||
if not authorization:
|
||||
raise HTTPException(status_code=401, detail="Authorization header required")
|
||||
|
||||
data = await proxy_to_consent_service(
|
||||
"GET",
|
||||
"/profile/sessions",
|
||||
headers=get_auth_headers(authorization)
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@router.delete("/profile/sessions/{session_id}")
|
||||
async def revoke_session(
|
||||
session_id: str,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
Beendet eine bestimmte Session.
|
||||
"""
|
||||
if not authorization:
|
||||
raise HTTPException(status_code=401, detail="Authorization header required")
|
||||
|
||||
data = await proxy_to_consent_service(
|
||||
"DELETE",
|
||||
f"/profile/sessions/{session_id}",
|
||||
headers=get_auth_headers(authorization)
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
# ==========================================
|
||||
# Health Check
|
||||
# ==========================================
|
||||
|
||||
@router.get("/health")
|
||||
async def auth_health():
|
||||
"""
|
||||
Prüft die Verbindung zum Auth Service.
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{CONSENT_SERVICE_URL}/health",
|
||||
timeout=5.0
|
||||
)
|
||||
is_healthy = response.status_code == 200
|
||||
except:
|
||||
is_healthy = False
|
||||
|
||||
return {
|
||||
"auth_service": "healthy" if is_healthy else "unavailable",
|
||||
"connected": is_healthy
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
from pathlib import Path
|
||||
|
||||
BASE_DIR = Path.home() / "Arbeitsblaetter"
|
||||
EINGANG_DIR = BASE_DIR / "Eingang"
|
||||
BEREINIGT_DIR = BASE_DIR / "Bereinigt"
|
||||
EDITIERBAR_DIR = BASE_DIR / "Editierbar"
|
||||
NEU_GENERIERT_DIR = BASE_DIR / "Neu_generiert"
|
||||
|
||||
VALID_SUFFIXES = {".jpg", ".jpeg", ".png", ".pdf", ".JPG", ".JPEG", ".PNG", ".PDF"}
|
||||
|
||||
# Ordner sicherstellen
|
||||
for d in [EINGANG_DIR, BEREINIGT_DIR, EDITIERBAR_DIR, NEU_GENERIERT_DIR]:
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def is_valid_input_file(path: Path) -> bool:
|
||||
"""Gemeinsame Filterlogik für Eingangsdateien."""
|
||||
return path.is_file() and not path.name.startswith(".") and path.suffix in VALID_SUFFIXES
|
||||
@@ -0,0 +1,359 @@
|
||||
"""
|
||||
Consent Service Client für BreakPilot
|
||||
Kommuniziert mit dem Consent Management Service für GDPR-Compliance
|
||||
"""
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import os
|
||||
import uuid
|
||||
|
||||
# Consent Service URL (aus Umgebungsvariable oder Standard für lokale Entwicklung)
|
||||
CONSENT_SERVICE_URL = os.getenv("CONSENT_SERVICE_URL", "http://localhost:8081")
|
||||
|
||||
# JWT Secret - MUSS mit dem Go Consent Service übereinstimmen!
|
||||
JWT_SECRET = os.getenv("JWT_SECRET", "breakpilot-dev-jwt-secret-2024")
|
||||
|
||||
|
||||
def generate_jwt_token(
|
||||
user_id: str = None,
|
||||
email: str = "demo@breakpilot.app",
|
||||
role: str = "user",
|
||||
expires_hours: int = 24
|
||||
) -> str:
|
||||
"""
|
||||
Generiert einen JWT Token für die Authentifizierung beim Consent Service.
|
||||
|
||||
Args:
|
||||
user_id: Die User-ID (wird generiert falls nicht angegeben)
|
||||
email: Die E-Mail-Adresse des Benutzers
|
||||
role: Die Rolle (user, admin, super_admin)
|
||||
expires_hours: Gültigkeitsdauer in Stunden
|
||||
|
||||
Returns:
|
||||
JWT Token als String
|
||||
"""
|
||||
if user_id is None:
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
payload = {
|
||||
"user_id": user_id,
|
||||
"email": email,
|
||||
"role": role,
|
||||
"exp": datetime.utcnow() + timedelta(hours=expires_hours),
|
||||
"iat": datetime.utcnow(),
|
||||
}
|
||||
|
||||
return jwt.encode(payload, JWT_SECRET, algorithm="HS256")
|
||||
|
||||
|
||||
def generate_demo_token() -> str:
|
||||
"""Generiert einen Demo-Token für nicht-authentifizierte Benutzer"""
|
||||
return generate_jwt_token(
|
||||
user_id="demo-user-" + str(uuid.uuid4())[:8],
|
||||
email="demo@breakpilot.app",
|
||||
role="user"
|
||||
)
|
||||
|
||||
|
||||
class DocumentType(str, Enum):
|
||||
TERMS = "terms"
|
||||
PRIVACY = "privacy"
|
||||
COOKIES = "cookies"
|
||||
COMMUNITY = "community"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConsentStatus:
|
||||
has_consent: bool
|
||||
current_version_id: Optional[str] = None
|
||||
consented_version: Optional[str] = None
|
||||
needs_update: bool = False
|
||||
consented_at: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentVersion:
|
||||
id: str
|
||||
document_id: str
|
||||
version: str
|
||||
language: str
|
||||
title: str
|
||||
content: str
|
||||
summary: Optional[str] = None
|
||||
|
||||
|
||||
class ConsentClient:
|
||||
"""Client für die Kommunikation mit dem Consent Service"""
|
||||
|
||||
def __init__(self, base_url: str = CONSENT_SERVICE_URL):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.api_url = f"{self.base_url}/api/v1"
|
||||
|
||||
def _get_headers(self, jwt_token: str) -> Dict[str, str]:
|
||||
"""Erstellt die Header mit JWT Token"""
|
||||
return {
|
||||
"Authorization": f"Bearer {jwt_token}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
async def check_consent(
|
||||
self,
|
||||
jwt_token: str,
|
||||
document_type: DocumentType,
|
||||
language: str = "de"
|
||||
) -> ConsentStatus:
|
||||
"""
|
||||
Prüft ob der Benutzer dem Dokument zugestimmt hat.
|
||||
Gibt zurück ob eine Zustimmung vorliegt und ob sie aktuell ist.
|
||||
"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.get(
|
||||
f"{self.api_url}/consent/check/{document_type.value}",
|
||||
headers=self._get_headers(jwt_token),
|
||||
params={"language": language},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
return ConsentStatus(
|
||||
has_consent=data.get("has_consent", False),
|
||||
current_version_id=data.get("current_version_id"),
|
||||
consented_version=data.get("consented_version"),
|
||||
needs_update=data.get("needs_update", False),
|
||||
consented_at=data.get("consented_at")
|
||||
)
|
||||
else:
|
||||
return ConsentStatus(has_consent=False, needs_update=True)
|
||||
|
||||
except httpx.RequestError:
|
||||
# Bei Verbindungsproblemen: Consent nicht erzwingen
|
||||
return ConsentStatus(has_consent=True, needs_update=False)
|
||||
|
||||
async def check_all_mandatory_consents(
|
||||
self,
|
||||
jwt_token: str,
|
||||
language: str = "de"
|
||||
) -> Dict[str, ConsentStatus]:
|
||||
"""
|
||||
Prüft alle verpflichtenden Dokumente (Terms, Privacy).
|
||||
Gibt ein Dictionary mit dem Status für jedes Dokument zurück.
|
||||
"""
|
||||
mandatory_docs = [DocumentType.TERMS, DocumentType.PRIVACY]
|
||||
results = {}
|
||||
|
||||
for doc_type in mandatory_docs:
|
||||
results[doc_type.value] = await self.check_consent(jwt_token, doc_type, language)
|
||||
|
||||
return results
|
||||
|
||||
async def get_pending_consents(
|
||||
self,
|
||||
jwt_token: str,
|
||||
language: str = "de"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Gibt eine Liste aller Dokumente zurück, die noch Zustimmung benötigen.
|
||||
Nützlich für die Anzeige beim Login/Registration.
|
||||
"""
|
||||
pending = []
|
||||
statuses = await self.check_all_mandatory_consents(jwt_token, language)
|
||||
|
||||
for doc_type, status in statuses.items():
|
||||
if not status.has_consent or status.needs_update:
|
||||
# Hole das aktuelle Dokument
|
||||
doc = await self.get_latest_document(jwt_token, doc_type, language)
|
||||
if doc:
|
||||
pending.append({
|
||||
"type": doc_type,
|
||||
"version_id": status.current_version_id,
|
||||
"title": doc.title,
|
||||
"content": doc.content,
|
||||
"summary": doc.summary,
|
||||
"is_update": status.has_consent and status.needs_update
|
||||
})
|
||||
|
||||
return pending
|
||||
|
||||
async def get_latest_document(
|
||||
self,
|
||||
jwt_token: str,
|
||||
document_type: str,
|
||||
language: str = "de"
|
||||
) -> Optional[DocumentVersion]:
|
||||
"""Holt die aktuellste Version eines Dokuments"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.get(
|
||||
f"{self.api_url}/documents/{document_type}/latest",
|
||||
headers=self._get_headers(jwt_token),
|
||||
params={"language": language},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
return DocumentVersion(
|
||||
id=data["id"],
|
||||
document_id=data["document_id"],
|
||||
version=data["version"],
|
||||
language=data["language"],
|
||||
title=data["title"],
|
||||
content=data["content"],
|
||||
summary=data.get("summary")
|
||||
)
|
||||
return None
|
||||
|
||||
except httpx.RequestError:
|
||||
return None
|
||||
|
||||
async def give_consent(
|
||||
self,
|
||||
jwt_token: str,
|
||||
document_type: str,
|
||||
version_id: str,
|
||||
consented: bool = True
|
||||
) -> bool:
|
||||
"""
|
||||
Speichert die Zustimmung des Benutzers.
|
||||
Gibt True zurück bei Erfolg.
|
||||
"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
f"{self.api_url}/consent",
|
||||
headers=self._get_headers(jwt_token),
|
||||
json={
|
||||
"document_type": document_type,
|
||||
"version_id": version_id,
|
||||
"consented": consented
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
return response.status_code == 201
|
||||
|
||||
except httpx.RequestError:
|
||||
return False
|
||||
|
||||
async def get_cookie_categories(
|
||||
self,
|
||||
jwt_token: str,
|
||||
language: str = "de"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Holt alle Cookie-Kategorien für das Cookie-Banner"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.get(
|
||||
f"{self.api_url}/cookies/categories",
|
||||
headers=self._get_headers(jwt_token),
|
||||
params={"language": language},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json().get("categories", [])
|
||||
return []
|
||||
|
||||
except httpx.RequestError:
|
||||
return []
|
||||
|
||||
async def set_cookie_consent(
|
||||
self,
|
||||
jwt_token: str,
|
||||
categories: List[Dict[str, Any]]
|
||||
) -> bool:
|
||||
"""
|
||||
Speichert die Cookie-Präferenzen.
|
||||
categories: [{"category_id": "...", "consented": true/false}, ...]
|
||||
"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
f"{self.api_url}/cookies/consent",
|
||||
headers=self._get_headers(jwt_token),
|
||||
json={"categories": categories},
|
||||
timeout=10.0
|
||||
)
|
||||
return response.status_code == 200
|
||||
|
||||
except httpx.RequestError:
|
||||
return False
|
||||
|
||||
async def get_my_data(self, jwt_token: str) -> Optional[Dict[str, Any]]:
|
||||
"""GDPR Art. 15: Holt alle Daten des Benutzers"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.get(
|
||||
f"{self.api_url}/privacy/my-data",
|
||||
headers=self._get_headers(jwt_token),
|
||||
timeout=30.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
return None
|
||||
|
||||
except httpx.RequestError:
|
||||
return None
|
||||
|
||||
async def request_data_export(self, jwt_token: str) -> Optional[str]:
|
||||
"""GDPR Art. 20: Fordert einen Datenexport an"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
f"{self.api_url}/privacy/export",
|
||||
headers=self._get_headers(jwt_token),
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
if response.status_code == 202:
|
||||
return response.json().get("request_id")
|
||||
return None
|
||||
|
||||
except httpx.RequestError:
|
||||
return None
|
||||
|
||||
async def request_data_deletion(
|
||||
self,
|
||||
jwt_token: str,
|
||||
reason: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
"""GDPR Art. 17: Fordert Löschung aller Daten an"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
f"{self.api_url}/privacy/delete",
|
||||
headers=self._get_headers(jwt_token),
|
||||
json={"reason": reason} if reason else {},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
if response.status_code == 202:
|
||||
return response.json().get("request_id")
|
||||
return None
|
||||
|
||||
except httpx.RequestError:
|
||||
return None
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""Prüft ob der Consent Service erreichbar ist"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/health",
|
||||
timeout=5.0
|
||||
)
|
||||
return response.status_code == 200
|
||||
|
||||
except httpx.RequestError:
|
||||
return False
|
||||
|
||||
|
||||
# Singleton-Instanz für einfachen Zugriff
|
||||
consent_client = ConsentClient()
|
||||
@@ -0,0 +1,252 @@
|
||||
"""
|
||||
E-Mail Template API für BreakPilot
|
||||
Proxy für den Go Consent Service E-Mail Template Management
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
import httpx
|
||||
from typing import Optional
|
||||
import os
|
||||
|
||||
from consent_client import CONSENT_SERVICE_URL, generate_jwt_token
|
||||
|
||||
router = APIRouter(prefix="/api/consent/admin/email-templates", tags=["Email Templates"])
|
||||
|
||||
# Base URL für E-Mail-Template-Endpunkte im Go Consent Service
|
||||
EMAIL_TEMPLATE_BASE = f"{CONSENT_SERVICE_URL}/api/v1/admin"
|
||||
|
||||
|
||||
async def get_admin_token() -> str:
|
||||
"""Generiert einen Admin-Token für API-Calls zum Consent Service"""
|
||||
return generate_jwt_token(
|
||||
user_id="a0000000-0000-0000-0000-000000000001",
|
||||
email="admin@breakpilot.app",
|
||||
role="admin",
|
||||
expires_hours=1
|
||||
)
|
||||
|
||||
|
||||
async def proxy_request(
|
||||
method: str,
|
||||
path: str,
|
||||
token: str,
|
||||
json_data: dict = None,
|
||||
params: dict = None
|
||||
) -> dict:
|
||||
"""Proxy-Funktion für API-Calls zum Go Consent Service"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
url = f"{EMAIL_TEMPLATE_BASE}{path}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
try:
|
||||
if method == "GET":
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
elif method == "POST":
|
||||
response = await client.post(url, headers=headers, json=json_data)
|
||||
elif method == "PUT":
|
||||
response = await client.put(url, headers=headers, json=json_data)
|
||||
elif method == "DELETE":
|
||||
response = await client.delete(url, headers=headers)
|
||||
else:
|
||||
raise ValueError(f"Unsupported method: {method}")
|
||||
|
||||
if response.status_code >= 400:
|
||||
error_detail = response.text
|
||||
try:
|
||||
error_detail = response.json().get("error", response.text)
|
||||
except:
|
||||
pass
|
||||
raise HTTPException(status_code=response.status_code, detail=error_detail)
|
||||
|
||||
if response.status_code == 204:
|
||||
return {"success": True}
|
||||
|
||||
return response.json()
|
||||
|
||||
except httpx.RequestError as e:
|
||||
raise HTTPException(status_code=503, detail=f"Consent Service nicht erreichbar: {str(e)}")
|
||||
|
||||
|
||||
# ==========================================
|
||||
# E-Mail Template Typen
|
||||
# ==========================================
|
||||
|
||||
@router.get("/types")
|
||||
async def get_all_template_types():
|
||||
"""Gibt alle verfügbaren E-Mail-Template-Typen zurück"""
|
||||
token = await get_admin_token()
|
||||
return await proxy_request("GET", "/email-templates/types", token)
|
||||
|
||||
|
||||
# ==========================================
|
||||
# E-Mail Templates
|
||||
# ==========================================
|
||||
|
||||
@router.get("")
|
||||
async def get_all_templates():
|
||||
"""Gibt alle E-Mail-Templates zurück"""
|
||||
token = await get_admin_token()
|
||||
return await proxy_request("GET", "/email-templates", token)
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_template(request: Request):
|
||||
"""Erstellt ein neues E-Mail-Template"""
|
||||
token = await get_admin_token()
|
||||
data = await request.json()
|
||||
return await proxy_request("POST", "/email-templates", token, json_data=data)
|
||||
|
||||
|
||||
@router.get("/settings")
|
||||
async def get_settings():
|
||||
"""Gibt die E-Mail-Einstellungen zurück"""
|
||||
token = await get_admin_token()
|
||||
return await proxy_request("GET", "/email-templates/settings", token)
|
||||
|
||||
|
||||
@router.put("/settings")
|
||||
async def update_settings(request: Request):
|
||||
"""Aktualisiert die E-Mail-Einstellungen"""
|
||||
token = await get_admin_token()
|
||||
data = await request.json()
|
||||
return await proxy_request("PUT", "/email-templates/settings", token, json_data=data)
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_email_stats():
|
||||
"""Gibt E-Mail-Statistiken zurück"""
|
||||
token = await get_admin_token()
|
||||
return await proxy_request("GET", "/email-templates/stats", token)
|
||||
|
||||
|
||||
@router.get("/logs")
|
||||
async def get_send_logs(
|
||||
template_id: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0
|
||||
):
|
||||
"""Gibt E-Mail-Send-Logs zurück"""
|
||||
token = await get_admin_token()
|
||||
params = {"limit": limit, "offset": offset}
|
||||
if template_id:
|
||||
params["template_id"] = template_id
|
||||
if status:
|
||||
params["status"] = status
|
||||
return await proxy_request("GET", "/email-templates/logs", token, params=params)
|
||||
|
||||
|
||||
@router.get("/default/{template_type}")
|
||||
async def get_default_content(template_type: str):
|
||||
"""Gibt den Default-Inhalt für einen Template-Typ zurück"""
|
||||
token = await get_admin_token()
|
||||
return await proxy_request("GET", f"/email-templates/default/{template_type}", token)
|
||||
|
||||
|
||||
@router.post("/initialize")
|
||||
async def initialize_templates():
|
||||
"""Initialisiert alle Standard-Templates"""
|
||||
token = await get_admin_token()
|
||||
return await proxy_request("POST", "/email-templates/initialize", token)
|
||||
|
||||
|
||||
@router.get("/{template_id}")
|
||||
async def get_template(template_id: str):
|
||||
"""Gibt ein einzelnes E-Mail-Template zurück"""
|
||||
token = await get_admin_token()
|
||||
return await proxy_request("GET", f"/email-templates/{template_id}", token)
|
||||
|
||||
|
||||
@router.get("/{template_id}/versions")
|
||||
async def get_template_versions(template_id: str):
|
||||
"""Gibt alle Versionen eines Templates zurück"""
|
||||
token = await get_admin_token()
|
||||
return await proxy_request("GET", f"/email-templates/{template_id}/versions", token)
|
||||
|
||||
|
||||
# ==========================================
|
||||
# E-Mail Template Versionen
|
||||
# ==========================================
|
||||
|
||||
versions_router = APIRouter(prefix="/api/consent/admin/email-template-versions", tags=["Email Template Versions"])
|
||||
|
||||
|
||||
@versions_router.get("/{version_id}")
|
||||
async def get_version(version_id: str):
|
||||
"""Gibt eine einzelne Version zurück"""
|
||||
token = await get_admin_token()
|
||||
return await proxy_request("GET", f"/email-template-versions/{version_id}", token)
|
||||
|
||||
|
||||
@versions_router.post("")
|
||||
async def create_version(request: Request):
|
||||
"""Erstellt eine neue Version"""
|
||||
token = await get_admin_token()
|
||||
data = await request.json()
|
||||
return await proxy_request("POST", "/email-template-versions", token, json_data=data)
|
||||
|
||||
|
||||
@versions_router.put("/{version_id}")
|
||||
async def update_version(version_id: str, request: Request):
|
||||
"""Aktualisiert eine Version"""
|
||||
token = await get_admin_token()
|
||||
data = await request.json()
|
||||
return await proxy_request("PUT", f"/email-template-versions/{version_id}", token, json_data=data)
|
||||
|
||||
|
||||
@versions_router.post("/{version_id}/submit")
|
||||
async def submit_for_review(version_id: str):
|
||||
"""Sendet eine Version zur Überprüfung"""
|
||||
token = await get_admin_token()
|
||||
return await proxy_request("POST", f"/email-template-versions/{version_id}/submit", token)
|
||||
|
||||
|
||||
@versions_router.post("/{version_id}/approve")
|
||||
async def approve_version(version_id: str, request: Request):
|
||||
"""Genehmigt eine Version"""
|
||||
token = await get_admin_token()
|
||||
data = await request.json()
|
||||
return await proxy_request("POST", f"/email-template-versions/{version_id}/approve", token, json_data=data)
|
||||
|
||||
|
||||
@versions_router.post("/{version_id}/reject")
|
||||
async def reject_version(version_id: str, request: Request):
|
||||
"""Lehnt eine Version ab"""
|
||||
token = await get_admin_token()
|
||||
data = await request.json()
|
||||
return await proxy_request("POST", f"/email-template-versions/{version_id}/reject", token, json_data=data)
|
||||
|
||||
|
||||
@versions_router.post("/{version_id}/publish")
|
||||
async def publish_version(version_id: str):
|
||||
"""Veröffentlicht eine Version"""
|
||||
token = await get_admin_token()
|
||||
return await proxy_request("POST", f"/email-template-versions/{version_id}/publish", token)
|
||||
|
||||
|
||||
@versions_router.get("/{version_id}/approvals")
|
||||
async def get_approvals(version_id: str):
|
||||
"""Gibt die Genehmigungshistorie einer Version zurück"""
|
||||
token = await get_admin_token()
|
||||
return await proxy_request("GET", f"/email-template-versions/{version_id}/approvals", token)
|
||||
|
||||
|
||||
@versions_router.post("/{version_id}/preview")
|
||||
async def preview_version(version_id: str, request: Request):
|
||||
"""Generiert eine Vorschau einer Version"""
|
||||
token = await get_admin_token()
|
||||
data = await request.json()
|
||||
return await proxy_request("POST", f"/email-template-versions/{version_id}/preview", token, json_data=data)
|
||||
|
||||
|
||||
@versions_router.post("/{version_id}/send-test")
|
||||
async def send_test_email(version_id: str, request: Request):
|
||||
"""Sendet eine Test-E-Mail"""
|
||||
token = await get_admin_token()
|
||||
data = await request.json()
|
||||
return await proxy_request("POST", f"/email-template-versions/{version_id}/send-test", token, json_data=data)
|
||||
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
BreakPilot Core Backend
|
||||
|
||||
Shared APIs for authentication, RBAC, notifications, email templates,
|
||||
system health, security (DevSecOps), and common middleware.
|
||||
|
||||
This is the extracted "core" service from the monorepo backend.
|
||||
It runs on port 8000 and uses the `core` schema in PostgreSQL.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Router imports (shared APIs only)
|
||||
# ---------------------------------------------------------------------------
|
||||
from auth_api import router as auth_router
|
||||
from rbac_api import router as rbac_router
|
||||
from notification_api import router as notification_router
|
||||
from email_template_api import (
|
||||
router as email_template_router,
|
||||
versions_router as email_template_versions_router,
|
||||
)
|
||||
from system_api import router as system_router
|
||||
from security_api import router as security_router
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Middleware imports
|
||||
# ---------------------------------------------------------------------------
|
||||
from middleware import (
|
||||
RequestIDMiddleware,
|
||||
SecurityHeadersMiddleware,
|
||||
RateLimiterMiddleware,
|
||||
PIIRedactor,
|
||||
InputGateMiddleware,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Logging
|
||||
# ---------------------------------------------------------------------------
|
||||
logging.basicConfig(
|
||||
level=os.getenv("LOG_LEVEL", "INFO"),
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
)
|
||||
logger = logging.getLogger("backend-core")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Application
|
||||
# ---------------------------------------------------------------------------
|
||||
app = FastAPI(
|
||||
title="BreakPilot Core Backend",
|
||||
description="Shared APIs: Auth, RBAC, Notifications, Email Templates, System, Security",
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CORS
|
||||
# ---------------------------------------------------------------------------
|
||||
ALLOWED_ORIGINS = os.getenv("CORS_ORIGINS", "*").split(",")
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Custom middleware stack (order matters -- outermost first)
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. Request-ID (outermost so every response has it)
|
||||
app.add_middleware(RequestIDMiddleware)
|
||||
|
||||
# 2. Security headers
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
# 3. Input gate (body-size / content-type validation)
|
||||
app.add_middleware(InputGateMiddleware)
|
||||
|
||||
# 4. Rate limiter (Valkey-backed)
|
||||
VALKEY_URL = os.getenv("VALKEY_URL", os.getenv("REDIS_URL", "redis://valkey:6379/0"))
|
||||
app.add_middleware(RateLimiterMiddleware, valkey_url=VALKEY_URL)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Routers
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth (proxy to consent-service)
|
||||
app.include_router(auth_router, prefix="/api")
|
||||
|
||||
# RBAC (teacher / role management)
|
||||
app.include_router(rbac_router, prefix="/api")
|
||||
|
||||
# Notifications (proxy to consent-service)
|
||||
app.include_router(notification_router, prefix="/api")
|
||||
|
||||
# Email templates (proxy to consent-service)
|
||||
app.include_router(email_template_router) # already has /api/consent/admin/email-templates prefix
|
||||
app.include_router(email_template_versions_router) # already has /api/consent/admin/email-template-versions prefix
|
||||
|
||||
# System (health, local-ip)
|
||||
app.include_router(system_router) # already has paths defined in router
|
||||
|
||||
# Security / DevSecOps dashboard
|
||||
app.include_router(security_router, prefix="/api")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Startup / Shutdown events
|
||||
# ---------------------------------------------------------------------------
|
||||
@app.on_event("startup")
|
||||
async def on_startup():
|
||||
logger.info("backend-core starting up")
|
||||
# Ensure DATABASE_URL uses search_path=core,public
|
||||
db_url = os.getenv("DATABASE_URL", "")
|
||||
if db_url and "search_path" not in db_url:
|
||||
separator = "&" if "?" in db_url else "?"
|
||||
new_url = f"{db_url}{separator}search_path=core,public"
|
||||
os.environ["DATABASE_URL"] = new_url
|
||||
logger.info("DATABASE_URL updated with search_path=core,public")
|
||||
elif "search_path" in db_url:
|
||||
logger.info("DATABASE_URL already contains search_path")
|
||||
else:
|
||||
logger.warning("DATABASE_URL is not set -- database features will fail")
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def on_shutdown():
|
||||
logger.info("backend-core shutting down")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entrypoint (for `python main.py` during development)
|
||||
# ---------------------------------------------------------------------------
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(
|
||||
"main:app",
|
||||
host="0.0.0.0",
|
||||
port=int(os.getenv("PORT", "8000")),
|
||||
reload=os.getenv("ENVIRONMENT", "development") == "development",
|
||||
)
|
||||
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
BreakPilot Middleware Stack
|
||||
|
||||
This module provides middleware components for the FastAPI backend:
|
||||
- Request-ID: Adds unique request identifiers for tracing
|
||||
- Security Headers: Adds security headers to all responses
|
||||
- Rate Limiter: Protects against abuse (Valkey-based)
|
||||
- PII Redactor: Redacts sensitive data from logs
|
||||
- Input Gate: Validates request body size and content types
|
||||
"""
|
||||
|
||||
from .request_id import RequestIDMiddleware, get_request_id
|
||||
from .security_headers import SecurityHeadersMiddleware
|
||||
from .rate_limiter import RateLimiterMiddleware
|
||||
from .pii_redactor import PIIRedactor, redact_pii
|
||||
from .input_gate import InputGateMiddleware
|
||||
|
||||
__all__ = [
|
||||
"RequestIDMiddleware",
|
||||
"get_request_id",
|
||||
"SecurityHeadersMiddleware",
|
||||
"RateLimiterMiddleware",
|
||||
"PIIRedactor",
|
||||
"redact_pii",
|
||||
"InputGateMiddleware",
|
||||
]
|
||||
@@ -0,0 +1,260 @@
|
||||
"""
|
||||
Input Validation Gate Middleware
|
||||
|
||||
Validates incoming requests for:
|
||||
- Request body size limits
|
||||
- Content-Type validation
|
||||
- File upload limits
|
||||
- Malicious content detection
|
||||
|
||||
Usage:
|
||||
from middleware import InputGateMiddleware
|
||||
|
||||
app.add_middleware(
|
||||
InputGateMiddleware,
|
||||
max_body_size=10 * 1024 * 1024, # 10MB
|
||||
allowed_content_types=["application/json", "multipart/form-data"],
|
||||
)
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Set
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputGateConfig:
|
||||
"""Configuration for input validation."""
|
||||
|
||||
# Maximum request body size (default: 10MB)
|
||||
max_body_size: int = 10 * 1024 * 1024
|
||||
|
||||
# Allowed content types
|
||||
allowed_content_types: Set[str] = field(default_factory=lambda: {
|
||||
"application/json",
|
||||
"application/x-www-form-urlencoded",
|
||||
"multipart/form-data",
|
||||
"text/plain",
|
||||
})
|
||||
|
||||
# File upload specific limits
|
||||
max_file_size: int = 50 * 1024 * 1024 # 50MB for file uploads
|
||||
allowed_file_types: Set[str] = field(default_factory=lambda: {
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/webp",
|
||||
"application/pdf",
|
||||
"application/msword",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"application/vnd.ms-excel",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
"text/csv",
|
||||
})
|
||||
|
||||
# Blocked file extensions (potential malware)
|
||||
blocked_extensions: Set[str] = field(default_factory=lambda: {
|
||||
".exe", ".bat", ".cmd", ".com", ".msi",
|
||||
".dll", ".scr", ".pif", ".vbs", ".js",
|
||||
".jar", ".sh", ".ps1", ".app",
|
||||
})
|
||||
|
||||
# Paths that allow larger uploads (e.g., file upload endpoints)
|
||||
large_upload_paths: List[str] = field(default_factory=lambda: [
|
||||
"/api/files/upload",
|
||||
"/api/documents/upload",
|
||||
"/api/attachments",
|
||||
])
|
||||
|
||||
# Paths excluded from validation
|
||||
excluded_paths: List[str] = field(default_factory=lambda: [
|
||||
"/health",
|
||||
"/metrics",
|
||||
])
|
||||
|
||||
# Enable strict content type checking
|
||||
strict_content_type: bool = True
|
||||
|
||||
|
||||
class InputGateMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware that validates incoming request bodies and content types.
|
||||
|
||||
Protects against:
|
||||
- Oversized request bodies
|
||||
- Invalid content types
|
||||
- Potentially malicious file uploads
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app,
|
||||
config: Optional[InputGateConfig] = None,
|
||||
max_body_size: Optional[int] = None,
|
||||
allowed_content_types: Optional[Set[str]] = None,
|
||||
):
|
||||
super().__init__(app)
|
||||
|
||||
self.config = config or InputGateConfig()
|
||||
|
||||
# Apply overrides
|
||||
if max_body_size is not None:
|
||||
self.config.max_body_size = max_body_size
|
||||
if allowed_content_types is not None:
|
||||
self.config.allowed_content_types = allowed_content_types
|
||||
|
||||
# Auto-configure from environment
|
||||
env_max_size = os.getenv("MAX_REQUEST_BODY_SIZE")
|
||||
if env_max_size:
|
||||
try:
|
||||
self.config.max_body_size = int(env_max_size)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def _is_excluded_path(self, path: str) -> bool:
|
||||
"""Check if path is excluded from validation."""
|
||||
return path in self.config.excluded_paths
|
||||
|
||||
def _is_large_upload_path(self, path: str) -> bool:
|
||||
"""Check if path allows larger uploads."""
|
||||
for upload_path in self.config.large_upload_paths:
|
||||
if path.startswith(upload_path):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_max_size(self, path: str) -> int:
|
||||
"""Get the maximum allowed body size for this path."""
|
||||
if self._is_large_upload_path(path):
|
||||
return self.config.max_file_size
|
||||
return self.config.max_body_size
|
||||
|
||||
def _validate_content_type(self, content_type: Optional[str]) -> tuple[bool, str]:
|
||||
"""
|
||||
Validate the content type.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
if not content_type:
|
||||
# Allow requests without content type (e.g., GET requests)
|
||||
return True, ""
|
||||
|
||||
# Extract base content type (remove charset, boundary, etc.)
|
||||
base_type = content_type.split(";")[0].strip().lower()
|
||||
|
||||
if base_type not in self.config.allowed_content_types:
|
||||
return False, f"Content-Type '{base_type}' is not allowed"
|
||||
|
||||
return True, ""
|
||||
|
||||
def _check_blocked_extension(self, filename: str) -> bool:
|
||||
"""Check if filename has a blocked extension."""
|
||||
if not filename:
|
||||
return False
|
||||
|
||||
lower_filename = filename.lower()
|
||||
for ext in self.config.blocked_extensions:
|
||||
if lower_filename.endswith(ext):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
# Skip excluded paths
|
||||
if self._is_excluded_path(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# Skip validation for GET, HEAD, OPTIONS requests
|
||||
if request.method in ("GET", "HEAD", "OPTIONS"):
|
||||
return await call_next(request)
|
||||
|
||||
# Validate content type for requests with body
|
||||
content_type = request.headers.get("Content-Type")
|
||||
if self.config.strict_content_type:
|
||||
is_valid, error_msg = self._validate_content_type(content_type)
|
||||
if not is_valid:
|
||||
return JSONResponse(
|
||||
status_code=415,
|
||||
content={
|
||||
"error": "unsupported_media_type",
|
||||
"message": error_msg,
|
||||
},
|
||||
)
|
||||
|
||||
# Check Content-Length header
|
||||
content_length = request.headers.get("Content-Length")
|
||||
if content_length:
|
||||
try:
|
||||
length = int(content_length)
|
||||
max_size = self._get_max_size(request.url.path)
|
||||
|
||||
if length > max_size:
|
||||
return JSONResponse(
|
||||
status_code=413,
|
||||
content={
|
||||
"error": "payload_too_large",
|
||||
"message": f"Request body exceeds maximum size of {max_size} bytes",
|
||||
"max_size": max_size,
|
||||
},
|
||||
)
|
||||
except ValueError:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"error": "invalid_content_length",
|
||||
"message": "Invalid Content-Length header",
|
||||
},
|
||||
)
|
||||
|
||||
# For multipart uploads, check for blocked file extensions
|
||||
if content_type and "multipart/form-data" in content_type:
|
||||
# Note: Full file validation would require reading the body
|
||||
# which we avoid in middleware for performance reasons.
|
||||
# Detailed file validation should happen in the handler.
|
||||
pass
|
||||
|
||||
# Process request
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
def validate_file_upload(
|
||||
filename: str,
|
||||
content_type: str,
|
||||
size: int,
|
||||
config: Optional[InputGateConfig] = None,
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
Validate a file upload.
|
||||
|
||||
Use this in upload handlers for detailed validation.
|
||||
|
||||
Args:
|
||||
filename: Original filename
|
||||
content_type: MIME type of the file
|
||||
size: File size in bytes
|
||||
config: Optional custom configuration
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
cfg = config or InputGateConfig()
|
||||
|
||||
# Check size
|
||||
if size > cfg.max_file_size:
|
||||
return False, f"File size exceeds maximum of {cfg.max_file_size} bytes"
|
||||
|
||||
# Check extension
|
||||
if filename:
|
||||
lower_filename = filename.lower()
|
||||
for ext in cfg.blocked_extensions:
|
||||
if lower_filename.endswith(ext):
|
||||
return False, f"File extension '{ext}' is not allowed"
|
||||
|
||||
# Check content type
|
||||
if content_type and content_type not in cfg.allowed_file_types:
|
||||
return False, f"File type '{content_type}' is not allowed"
|
||||
|
||||
return True, ""
|
||||
@@ -0,0 +1,316 @@
|
||||
"""
|
||||
PII Redactor
|
||||
|
||||
Redacts Personally Identifiable Information (PII) from logs and responses.
|
||||
Essential for DSGVO/GDPR compliance in BreakPilot.
|
||||
|
||||
Redacted data types:
|
||||
- Email addresses
|
||||
- IP addresses
|
||||
- German phone numbers
|
||||
- Names (when identified)
|
||||
- Student IDs
|
||||
- Credit card numbers
|
||||
- IBAN numbers
|
||||
|
||||
Usage:
|
||||
from middleware import PIIRedactor, redact_pii
|
||||
|
||||
# Use in logging
|
||||
logger.info(redact_pii(f"User {email} logged in from {ip}"))
|
||||
|
||||
# Configure redactor
|
||||
redactor = PIIRedactor(patterns=["email", "ip", "phone"])
|
||||
safe_message = redactor.redact(sensitive_message)
|
||||
"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Pattern, Set
|
||||
|
||||
|
||||
@dataclass
|
||||
class PIIPattern:
|
||||
"""Definition of a PII pattern."""
|
||||
name: str
|
||||
pattern: Pattern
|
||||
replacement: str
|
||||
|
||||
|
||||
# Pre-compiled regex patterns for common PII
|
||||
PII_PATTERNS: Dict[str, PIIPattern] = {
|
||||
"email": PIIPattern(
|
||||
name="email",
|
||||
pattern=re.compile(
|
||||
r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
|
||||
re.IGNORECASE
|
||||
),
|
||||
replacement="[EMAIL_REDACTED]",
|
||||
),
|
||||
"ip_v4": PIIPattern(
|
||||
name="ip_v4",
|
||||
pattern=re.compile(
|
||||
r'\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b'
|
||||
),
|
||||
replacement="[IP_REDACTED]",
|
||||
),
|
||||
"ip_v6": PIIPattern(
|
||||
name="ip_v6",
|
||||
pattern=re.compile(
|
||||
r'\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b'
|
||||
),
|
||||
replacement="[IP_REDACTED]",
|
||||
),
|
||||
"phone_de": PIIPattern(
|
||||
name="phone_de",
|
||||
pattern=re.compile(
|
||||
r'(?<!\w)(?:\+49|0049|0)[\s.-]?(?:\d{2,4})[\s.-]?(?:\d{3,4})[\s.-]?(?:\d{3,4})(?!\d)'
|
||||
),
|
||||
replacement="[PHONE_REDACTED]",
|
||||
),
|
||||
"phone_intl": PIIPattern(
|
||||
name="phone_intl",
|
||||
pattern=re.compile(
|
||||
r'(?<!\w)\+?(?:\d[\s.-]?){10,15}(?!\d)'
|
||||
),
|
||||
replacement="[PHONE_REDACTED]",
|
||||
),
|
||||
"credit_card": PIIPattern(
|
||||
name="credit_card",
|
||||
pattern=re.compile(
|
||||
r'\b(?:\d{4}[\s.-]?){3}\d{4}\b'
|
||||
),
|
||||
replacement="[CC_REDACTED]",
|
||||
),
|
||||
"iban": PIIPattern(
|
||||
name="iban",
|
||||
pattern=re.compile(
|
||||
r'\b[A-Z]{2}\d{2}[\s]?(?:\d{4}[\s]?){3,5}\d{1,4}\b',
|
||||
re.IGNORECASE
|
||||
),
|
||||
replacement="[IBAN_REDACTED]",
|
||||
),
|
||||
"student_id": PIIPattern(
|
||||
name="student_id",
|
||||
pattern=re.compile(
|
||||
r'\b(?:student|schueler|schüler)[-_]?(?:id|nr)?[:\s]?\d{4,10}\b',
|
||||
re.IGNORECASE
|
||||
),
|
||||
replacement="[STUDENT_ID_REDACTED]",
|
||||
),
|
||||
"uuid": PIIPattern(
|
||||
name="uuid",
|
||||
pattern=re.compile(
|
||||
r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b',
|
||||
re.IGNORECASE
|
||||
),
|
||||
replacement="[UUID_REDACTED]",
|
||||
),
|
||||
# German names are harder to detect, but we can catch common patterns
|
||||
"name_prefix": PIIPattern(
|
||||
name="name_prefix",
|
||||
pattern=re.compile(
|
||||
r'\b(?:Herr|Frau|Hr\.|Fr\.)\s+[A-ZÄÖÜ][a-zäöüß]+(?:\s+[A-ZÄÖÜ][a-zäöüß]+)?\b'
|
||||
),
|
||||
replacement="[NAME_REDACTED]",
|
||||
),
|
||||
}
|
||||
|
||||
# Default patterns to enable
|
||||
DEFAULT_PATTERNS = ["email", "ip_v4", "ip_v6", "phone_de"]
|
||||
|
||||
|
||||
class PIIRedactor:
|
||||
"""
|
||||
Redacts PII from strings.
|
||||
|
||||
Attributes:
|
||||
patterns: List of pattern names to use (e.g., ["email", "ip_v4"])
|
||||
custom_patterns: Additional custom patterns
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patterns: Optional[List[str]] = None,
|
||||
custom_patterns: Optional[List[PIIPattern]] = None,
|
||||
preserve_format: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the PII redactor.
|
||||
|
||||
Args:
|
||||
patterns: List of pattern names to enable (default: email, ip_v4, ip_v6, phone_de)
|
||||
custom_patterns: Additional custom PIIPattern objects
|
||||
preserve_format: If True, preserve the length of redacted content
|
||||
"""
|
||||
self.patterns = patterns or DEFAULT_PATTERNS
|
||||
self.custom_patterns = custom_patterns or []
|
||||
self.preserve_format = preserve_format
|
||||
|
||||
# Build active patterns list
|
||||
self._active_patterns: List[PIIPattern] = []
|
||||
for pattern_name in self.patterns:
|
||||
if pattern_name in PII_PATTERNS:
|
||||
self._active_patterns.append(PII_PATTERNS[pattern_name])
|
||||
|
||||
# Add custom patterns
|
||||
self._active_patterns.extend(self.custom_patterns)
|
||||
|
||||
def redact(self, text: str) -> str:
|
||||
"""
|
||||
Redact PII from the given text.
|
||||
|
||||
Args:
|
||||
text: The text to redact PII from
|
||||
|
||||
Returns:
|
||||
Text with PII replaced by redaction markers
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
result = text
|
||||
for pattern in self._active_patterns:
|
||||
if self.preserve_format:
|
||||
# Replace with same-length placeholder
|
||||
def replace_preserve(match):
|
||||
length = len(match.group())
|
||||
return "*" * length
|
||||
result = pattern.pattern.sub(replace_preserve, result)
|
||||
else:
|
||||
result = pattern.pattern.sub(pattern.replacement, result)
|
||||
|
||||
return result
|
||||
|
||||
def contains_pii(self, text: str) -> bool:
|
||||
"""
|
||||
Check if text contains any PII.
|
||||
|
||||
Args:
|
||||
text: The text to check
|
||||
|
||||
Returns:
|
||||
True if PII is detected
|
||||
"""
|
||||
if not text:
|
||||
return False
|
||||
|
||||
for pattern in self._active_patterns:
|
||||
if pattern.pattern.search(text):
|
||||
return True
|
||||
return False
|
||||
|
||||
def find_pii(self, text: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Find all PII in text with their types.
|
||||
|
||||
Args:
|
||||
text: The text to search
|
||||
|
||||
Returns:
|
||||
List of dicts with 'type' and 'match' keys
|
||||
"""
|
||||
if not text:
|
||||
return []
|
||||
|
||||
findings = []
|
||||
for pattern in self._active_patterns:
|
||||
for match in pattern.pattern.finditer(text):
|
||||
findings.append({
|
||||
"type": pattern.name,
|
||||
"match": match.group(),
|
||||
"start": match.start(),
|
||||
"end": match.end(),
|
||||
})
|
||||
|
||||
return findings
|
||||
|
||||
|
||||
# Module-level default redactor instance
|
||||
_default_redactor: Optional[PIIRedactor] = None
|
||||
|
||||
|
||||
def get_default_redactor() -> PIIRedactor:
|
||||
"""Get or create the default redactor instance."""
|
||||
global _default_redactor
|
||||
if _default_redactor is None:
|
||||
_default_redactor = PIIRedactor()
|
||||
return _default_redactor
|
||||
|
||||
|
||||
def redact_pii(text: str) -> str:
|
||||
"""
|
||||
Convenience function to redact PII using the default redactor.
|
||||
|
||||
Args:
|
||||
text: Text to redact
|
||||
|
||||
Returns:
|
||||
Redacted text
|
||||
|
||||
Example:
|
||||
logger.info(redact_pii(f"User {email} logged in"))
|
||||
"""
|
||||
return get_default_redactor().redact(text)
|
||||
|
||||
|
||||
class PIIRedactingLogFilter:
|
||||
"""
|
||||
Logging filter that automatically redacts PII from log messages.
|
||||
|
||||
Usage:
|
||||
import logging
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.addFilter(PIIRedactingLogFilter())
|
||||
logger = logging.getLogger()
|
||||
logger.addHandler(handler)
|
||||
"""
|
||||
|
||||
def __init__(self, redactor: Optional[PIIRedactor] = None):
|
||||
self.redactor = redactor or get_default_redactor()
|
||||
|
||||
def filter(self, record):
|
||||
# Redact the message
|
||||
if record.msg:
|
||||
record.msg = self.redactor.redact(str(record.msg))
|
||||
|
||||
# Redact args if present
|
||||
if record.args:
|
||||
if isinstance(record.args, dict):
|
||||
record.args = {
|
||||
k: self.redactor.redact(str(v)) if isinstance(v, str) else v
|
||||
for k, v in record.args.items()
|
||||
}
|
||||
elif isinstance(record.args, tuple):
|
||||
record.args = tuple(
|
||||
self.redactor.redact(str(v)) if isinstance(v, str) else v
|
||||
for v in record.args
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def create_safe_dict(data: dict, redactor: Optional[PIIRedactor] = None) -> dict:
|
||||
"""
|
||||
Create a copy of a dictionary with PII redacted.
|
||||
|
||||
Args:
|
||||
data: Dictionary to redact
|
||||
redactor: Optional custom redactor
|
||||
|
||||
Returns:
|
||||
New dictionary with redacted values
|
||||
"""
|
||||
r = redactor or get_default_redactor()
|
||||
|
||||
def redact_value(value):
|
||||
if isinstance(value, str):
|
||||
return r.redact(value)
|
||||
elif isinstance(value, dict):
|
||||
return create_safe_dict(value, r)
|
||||
elif isinstance(value, list):
|
||||
return [redact_value(v) for v in value]
|
||||
return value
|
||||
|
||||
return {k: redact_value(v) for k, v in data.items()}
|
||||
@@ -0,0 +1,363 @@
|
||||
"""
|
||||
Rate Limiter Middleware
|
||||
|
||||
Implements distributed rate limiting using Valkey (Redis-fork).
|
||||
Supports IP-based, user-based, and endpoint-specific rate limits.
|
||||
|
||||
Features:
|
||||
- Sliding window rate limiting
|
||||
- IP-based limits for unauthenticated requests
|
||||
- User-based limits for authenticated requests
|
||||
- Stricter limits for auth endpoints (anti-brute-force)
|
||||
- IP whitelist/blacklist support
|
||||
- Graceful fallback when Valkey is unavailable
|
||||
|
||||
Usage:
|
||||
from middleware import RateLimiterMiddleware
|
||||
|
||||
app.add_middleware(
|
||||
RateLimiterMiddleware,
|
||||
valkey_url="redis://localhost:6379",
|
||||
ip_limit=100,
|
||||
user_limit=500,
|
||||
)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
# Try to import redis (valkey-compatible)
|
||||
try:
|
||||
import redis.asyncio as redis
|
||||
REDIS_AVAILABLE = True
|
||||
except ImportError:
|
||||
REDIS_AVAILABLE = False
|
||||
redis = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitConfig:
|
||||
"""Configuration for rate limiting."""
|
||||
|
||||
# Valkey/Redis connection
|
||||
valkey_url: str = "redis://localhost:6379"
|
||||
|
||||
# Default limits (requests per minute)
|
||||
ip_limit: int = 100
|
||||
user_limit: int = 500
|
||||
|
||||
# Stricter limits for auth endpoints
|
||||
auth_limit: int = 20
|
||||
auth_endpoints: List[str] = field(default_factory=lambda: [
|
||||
"/api/auth/login",
|
||||
"/api/auth/register",
|
||||
"/api/auth/password-reset",
|
||||
"/api/auth/forgot-password",
|
||||
])
|
||||
|
||||
# Window size in seconds
|
||||
window_size: int = 60
|
||||
|
||||
# IP whitelist (never rate limited)
|
||||
ip_whitelist: Set[str] = field(default_factory=lambda: {
|
||||
"127.0.0.1",
|
||||
"::1",
|
||||
})
|
||||
|
||||
# IP blacklist (always blocked)
|
||||
ip_blacklist: Set[str] = field(default_factory=set)
|
||||
|
||||
# Skip internal Docker network
|
||||
skip_internal_network: bool = True
|
||||
|
||||
# Excluded paths
|
||||
excluded_paths: List[str] = field(default_factory=lambda: [
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/api/health",
|
||||
])
|
||||
|
||||
# Fallback to in-memory when Valkey is unavailable
|
||||
fallback_enabled: bool = True
|
||||
|
||||
# Key prefix for rate limit keys
|
||||
key_prefix: str = "ratelimit"
|
||||
|
||||
|
||||
class InMemoryRateLimiter:
|
||||
"""Fallback in-memory rate limiter when Valkey is unavailable."""
|
||||
|
||||
def __init__(self):
|
||||
self._counts: Dict[str, List[float]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def check_rate_limit(self, key: str, limit: int, window: int) -> tuple[bool, int]:
|
||||
"""
|
||||
Check if rate limit is exceeded.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_allowed, remaining_requests)
|
||||
"""
|
||||
async with self._lock:
|
||||
now = time.time()
|
||||
window_start = now - window
|
||||
|
||||
# Get or create entry
|
||||
if key not in self._counts:
|
||||
self._counts[key] = []
|
||||
|
||||
# Remove old entries
|
||||
self._counts[key] = [t for t in self._counts[key] if t > window_start]
|
||||
|
||||
# Check limit
|
||||
current_count = len(self._counts[key])
|
||||
if current_count >= limit:
|
||||
return False, 0
|
||||
|
||||
# Add new request
|
||||
self._counts[key].append(now)
|
||||
return True, limit - current_count - 1
|
||||
|
||||
async def cleanup(self):
|
||||
"""Remove expired entries."""
|
||||
async with self._lock:
|
||||
now = time.time()
|
||||
for key in list(self._counts.keys()):
|
||||
self._counts[key] = [t for t in self._counts[key] if t > now - 3600]
|
||||
if not self._counts[key]:
|
||||
del self._counts[key]
|
||||
|
||||
|
||||
class RateLimiterMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware that implements distributed rate limiting.
|
||||
|
||||
Uses Valkey (Redis-fork) for distributed state, with fallback
|
||||
to in-memory rate limiting when Valkey is unavailable.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app,
|
||||
config: Optional[RateLimitConfig] = None,
|
||||
# Individual overrides
|
||||
valkey_url: Optional[str] = None,
|
||||
ip_limit: Optional[int] = None,
|
||||
user_limit: Optional[int] = None,
|
||||
auth_limit: Optional[int] = None,
|
||||
):
|
||||
super().__init__(app)
|
||||
|
||||
self.config = config or RateLimitConfig()
|
||||
|
||||
# Apply overrides
|
||||
if valkey_url is not None:
|
||||
self.config.valkey_url = valkey_url
|
||||
if ip_limit is not None:
|
||||
self.config.ip_limit = ip_limit
|
||||
if user_limit is not None:
|
||||
self.config.user_limit = user_limit
|
||||
if auth_limit is not None:
|
||||
self.config.auth_limit = auth_limit
|
||||
|
||||
# Auto-configure from environment
|
||||
self.config.valkey_url = os.getenv("VALKEY_URL", self.config.valkey_url)
|
||||
|
||||
# Initialize Valkey client
|
||||
self._redis: Optional[redis.Redis] = None
|
||||
self._fallback = InMemoryRateLimiter()
|
||||
self._valkey_available = False
|
||||
|
||||
async def _get_redis(self) -> Optional[redis.Redis]:
|
||||
"""Get or create Redis/Valkey connection."""
|
||||
if not REDIS_AVAILABLE:
|
||||
return None
|
||||
|
||||
if self._redis is None:
|
||||
try:
|
||||
self._redis = redis.from_url(
|
||||
self.config.valkey_url,
|
||||
decode_responses=True,
|
||||
socket_timeout=1.0,
|
||||
socket_connect_timeout=1.0,
|
||||
)
|
||||
await self._redis.ping()
|
||||
self._valkey_available = True
|
||||
except Exception:
|
||||
self._valkey_available = False
|
||||
self._redis = None
|
||||
|
||||
return self._redis
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""Extract client IP from request."""
|
||||
# Check X-Forwarded-For header
|
||||
xff = request.headers.get("X-Forwarded-For")
|
||||
if xff:
|
||||
return xff.split(",")[0].strip()
|
||||
|
||||
# Check X-Real-IP header
|
||||
xri = request.headers.get("X-Real-IP")
|
||||
if xri:
|
||||
return xri
|
||||
|
||||
# Fall back to direct client IP
|
||||
if request.client:
|
||||
return request.client.host
|
||||
return "unknown"
|
||||
|
||||
def _get_user_id(self, request: Request) -> Optional[str]:
|
||||
"""Extract user ID from request state (set by session middleware)."""
|
||||
if hasattr(request.state, "session") and request.state.session:
|
||||
return getattr(request.state.session, "user_id", None)
|
||||
return None
|
||||
|
||||
def _is_internal_network(self, ip: str) -> bool:
|
||||
"""Check if IP is from internal Docker network."""
|
||||
return (
|
||||
ip.startswith("172.") or
|
||||
ip.startswith("10.") or
|
||||
ip.startswith("192.168.")
|
||||
)
|
||||
|
||||
def _get_rate_limit(self, request: Request) -> int:
|
||||
"""Determine the rate limit for this request."""
|
||||
path = request.url.path
|
||||
|
||||
# Auth endpoints get stricter limits
|
||||
for auth_path in self.config.auth_endpoints:
|
||||
if path.startswith(auth_path):
|
||||
return self.config.auth_limit
|
||||
|
||||
# Authenticated users get higher limits
|
||||
if self._get_user_id(request):
|
||||
return self.config.user_limit
|
||||
|
||||
# Default IP-based limit
|
||||
return self.config.ip_limit
|
||||
|
||||
def _get_rate_limit_key(self, request: Request) -> str:
|
||||
"""Generate the rate limit key for this request."""
|
||||
# Use user ID if authenticated
|
||||
user_id = self._get_user_id(request)
|
||||
if user_id:
|
||||
identifier = f"user:{user_id}"
|
||||
else:
|
||||
ip = self._get_client_ip(request)
|
||||
# Hash IP for privacy
|
||||
ip_hash = hashlib.sha256(ip.encode()).hexdigest()[:16]
|
||||
identifier = f"ip:{ip_hash}"
|
||||
|
||||
# Include path for endpoint-specific limits
|
||||
path = request.url.path
|
||||
for auth_path in self.config.auth_endpoints:
|
||||
if path.startswith(auth_path):
|
||||
return f"{self.config.key_prefix}:auth:{identifier}"
|
||||
|
||||
return f"{self.config.key_prefix}:{identifier}"
|
||||
|
||||
async def _check_rate_limit_valkey(
|
||||
self, key: str, limit: int, window: int
|
||||
) -> tuple[bool, int]:
|
||||
"""Check rate limit using Valkey."""
|
||||
r = await self._get_redis()
|
||||
if not r:
|
||||
return await self._fallback.check_rate_limit(key, limit, window)
|
||||
|
||||
try:
|
||||
# Use sliding window with sorted set
|
||||
now = time.time()
|
||||
window_start = now - window
|
||||
|
||||
pipe = r.pipeline()
|
||||
# Remove old entries
|
||||
pipe.zremrangebyscore(key, "-inf", window_start)
|
||||
# Count current entries
|
||||
pipe.zcard(key)
|
||||
# Add new entry
|
||||
pipe.zadd(key, {str(now): now})
|
||||
# Set expiry
|
||||
pipe.expire(key, window + 10)
|
||||
|
||||
results = await pipe.execute()
|
||||
current_count = results[1]
|
||||
|
||||
if current_count >= limit:
|
||||
return False, 0
|
||||
|
||||
return True, limit - current_count - 1
|
||||
|
||||
except Exception:
|
||||
# Fallback to in-memory
|
||||
self._valkey_available = False
|
||||
return await self._fallback.check_rate_limit(key, limit, window)
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
# Skip excluded paths
|
||||
if request.url.path in self.config.excluded_paths:
|
||||
return await call_next(request)
|
||||
|
||||
# Get client IP
|
||||
ip = self._get_client_ip(request)
|
||||
|
||||
# Check blacklist
|
||||
if ip in self.config.ip_blacklist:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={
|
||||
"error": "ip_blocked",
|
||||
"message": "Your IP address has been blocked.",
|
||||
},
|
||||
)
|
||||
|
||||
# Skip whitelist
|
||||
if ip in self.config.ip_whitelist:
|
||||
return await call_next(request)
|
||||
|
||||
# Skip internal network
|
||||
if self.config.skip_internal_network and self._is_internal_network(ip):
|
||||
return await call_next(request)
|
||||
|
||||
# Get rate limit parameters
|
||||
limit = self._get_rate_limit(request)
|
||||
key = self._get_rate_limit_key(request)
|
||||
window = self.config.window_size
|
||||
|
||||
# Check rate limit
|
||||
allowed, remaining = await self._check_rate_limit_valkey(key, limit, window)
|
||||
|
||||
if not allowed:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"error": "rate_limit_exceeded",
|
||||
"message": "Too many requests. Please try again later.",
|
||||
"retry_after": window,
|
||||
},
|
||||
headers={
|
||||
"Retry-After": str(window),
|
||||
"X-RateLimit-Limit": str(limit),
|
||||
"X-RateLimit-Remaining": "0",
|
||||
"X-RateLimit-Reset": str(int(time.time()) + window),
|
||||
},
|
||||
)
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add rate limit headers
|
||||
response.headers["X-RateLimit-Limit"] = str(limit)
|
||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||
response.headers["X-RateLimit-Reset"] = str(int(time.time()) + window)
|
||||
|
||||
return response
|
||||
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
Request-ID Middleware
|
||||
|
||||
Generates and propagates unique request identifiers for distributed tracing.
|
||||
Supports both X-Request-ID and X-Correlation-ID headers.
|
||||
|
||||
Usage:
|
||||
from middleware import RequestIDMiddleware, get_request_id
|
||||
|
||||
app.add_middleware(RequestIDMiddleware)
|
||||
|
||||
@app.get("/api/example")
|
||||
async def example():
|
||||
request_id = get_request_id()
|
||||
logger.info(f"Processing request", extra={"request_id": request_id})
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
from typing import Optional
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
# Context variable to store request ID across async calls
|
||||
_request_id_ctx: ContextVar[Optional[str]] = ContextVar("request_id", default=None)
|
||||
|
||||
# Header names
|
||||
REQUEST_ID_HEADER = "X-Request-ID"
|
||||
CORRELATION_ID_HEADER = "X-Correlation-ID"
|
||||
|
||||
|
||||
def get_request_id() -> Optional[str]:
|
||||
"""
|
||||
Get the current request ID from context.
|
||||
|
||||
Returns:
|
||||
The request ID string or None if not in a request context.
|
||||
|
||||
Example:
|
||||
request_id = get_request_id()
|
||||
logger.info("Processing", extra={"request_id": request_id})
|
||||
"""
|
||||
return _request_id_ctx.get()
|
||||
|
||||
|
||||
def set_request_id(request_id: str) -> None:
|
||||
"""
|
||||
Set the request ID in the current context.
|
||||
|
||||
Args:
|
||||
request_id: The request ID to set
|
||||
"""
|
||||
_request_id_ctx.set(request_id)
|
||||
|
||||
|
||||
def generate_request_id() -> str:
|
||||
"""
|
||||
Generate a new unique request ID.
|
||||
|
||||
Returns:
|
||||
A UUID4 string
|
||||
"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class RequestIDMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware that generates and propagates request IDs.
|
||||
|
||||
For each incoming request:
|
||||
1. Check for existing X-Request-ID or X-Correlation-ID header
|
||||
2. If not present, generate a new UUID
|
||||
3. Store in context for use by handlers and logging
|
||||
4. Add to response headers
|
||||
|
||||
Attributes:
|
||||
header_name: The primary header name to use (default: X-Request-ID)
|
||||
generator: Function to generate new IDs (default: uuid4)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app,
|
||||
header_name: str = REQUEST_ID_HEADER,
|
||||
generator=generate_request_id,
|
||||
):
|
||||
super().__init__(app)
|
||||
self.header_name = header_name
|
||||
self.generator = generator
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
# Try to get existing request ID from headers
|
||||
request_id = (
|
||||
request.headers.get(REQUEST_ID_HEADER)
|
||||
or request.headers.get(CORRELATION_ID_HEADER)
|
||||
)
|
||||
|
||||
# Generate new ID if not provided
|
||||
if not request_id:
|
||||
request_id = self.generator()
|
||||
|
||||
# Store in context for logging and handlers
|
||||
set_request_id(request_id)
|
||||
|
||||
# Store in request state for direct access
|
||||
request.state.request_id = request_id
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add request ID to response headers
|
||||
response.headers[REQUEST_ID_HEADER] = request_id
|
||||
response.headers[CORRELATION_ID_HEADER] = request_id
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class RequestIDLogFilter:
|
||||
"""
|
||||
Logging filter that adds request_id to log records.
|
||||
|
||||
Usage:
|
||||
import logging
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.addFilter(RequestIDLogFilter())
|
||||
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s [%(request_id)s] %(levelname)s %(message)s'
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
"""
|
||||
|
||||
def filter(self, record):
|
||||
record.request_id = get_request_id() or "no-request-id"
|
||||
return True
|
||||
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
Security Headers Middleware
|
||||
|
||||
Adds security headers to all HTTP responses to protect against common attacks.
|
||||
|
||||
Headers added:
|
||||
- X-Content-Type-Options: nosniff
|
||||
- X-Frame-Options: DENY
|
||||
- X-XSS-Protection: 1; mode=block
|
||||
- Strict-Transport-Security (HSTS)
|
||||
- Content-Security-Policy
|
||||
- Referrer-Policy
|
||||
- Permissions-Policy
|
||||
|
||||
Usage:
|
||||
from middleware import SecurityHeadersMiddleware
|
||||
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
# Or with custom configuration:
|
||||
app.add_middleware(
|
||||
SecurityHeadersMiddleware,
|
||||
hsts_enabled=True,
|
||||
csp_policy="default-src 'self'",
|
||||
)
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
|
||||
@dataclass
|
||||
class SecurityHeadersConfig:
|
||||
"""Configuration for security headers."""
|
||||
|
||||
# X-Content-Type-Options
|
||||
content_type_options: str = "nosniff"
|
||||
|
||||
# X-Frame-Options
|
||||
frame_options: str = "DENY"
|
||||
|
||||
# X-XSS-Protection (legacy, but still useful for older browsers)
|
||||
xss_protection: str = "1; mode=block"
|
||||
|
||||
# Strict-Transport-Security
|
||||
hsts_enabled: bool = True
|
||||
hsts_max_age: int = 31536000 # 1 year
|
||||
hsts_include_subdomains: bool = True
|
||||
hsts_preload: bool = False
|
||||
|
||||
# Content-Security-Policy
|
||||
csp_enabled: bool = True
|
||||
csp_policy: str = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self' https:; frame-ancestors 'none'"
|
||||
|
||||
# Referrer-Policy
|
||||
referrer_policy: str = "strict-origin-when-cross-origin"
|
||||
|
||||
# Permissions-Policy (formerly Feature-Policy)
|
||||
permissions_policy: str = "geolocation=(), microphone=(), camera=()"
|
||||
|
||||
# Cross-Origin headers
|
||||
cross_origin_opener_policy: str = "same-origin"
|
||||
cross_origin_embedder_policy: str = "require-corp"
|
||||
cross_origin_resource_policy: str = "same-origin"
|
||||
|
||||
# Development mode (relaxes some restrictions)
|
||||
development_mode: bool = False
|
||||
|
||||
# Excluded paths (e.g., for health checks)
|
||||
excluded_paths: List[str] = field(default_factory=lambda: ["/health", "/metrics"])
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware that adds security headers to all responses.
|
||||
|
||||
Attributes:
|
||||
config: SecurityHeadersConfig instance
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app,
|
||||
config: Optional[SecurityHeadersConfig] = None,
|
||||
# Individual overrides for convenience
|
||||
hsts_enabled: Optional[bool] = None,
|
||||
csp_policy: Optional[str] = None,
|
||||
csp_enabled: Optional[bool] = None,
|
||||
development_mode: Optional[bool] = None,
|
||||
):
|
||||
super().__init__(app)
|
||||
|
||||
# Use provided config or create default
|
||||
self.config = config or SecurityHeadersConfig()
|
||||
|
||||
# Apply individual overrides
|
||||
if hsts_enabled is not None:
|
||||
self.config.hsts_enabled = hsts_enabled
|
||||
if csp_policy is not None:
|
||||
self.config.csp_policy = csp_policy
|
||||
if csp_enabled is not None:
|
||||
self.config.csp_enabled = csp_enabled
|
||||
if development_mode is not None:
|
||||
self.config.development_mode = development_mode
|
||||
|
||||
# Auto-detect development mode from environment
|
||||
if development_mode is None:
|
||||
env = os.getenv("ENVIRONMENT", "development")
|
||||
self.config.development_mode = env.lower() in ("development", "dev", "local")
|
||||
|
||||
def _build_hsts_header(self) -> str:
|
||||
"""Build the Strict-Transport-Security header value."""
|
||||
parts = [f"max-age={self.config.hsts_max_age}"]
|
||||
if self.config.hsts_include_subdomains:
|
||||
parts.append("includeSubDomains")
|
||||
if self.config.hsts_preload:
|
||||
parts.append("preload")
|
||||
return "; ".join(parts)
|
||||
|
||||
def _get_headers(self) -> Dict[str, str]:
|
||||
"""Build the security headers dictionary."""
|
||||
headers = {}
|
||||
|
||||
# Always add these headers
|
||||
headers["X-Content-Type-Options"] = self.config.content_type_options
|
||||
headers["X-Frame-Options"] = self.config.frame_options
|
||||
headers["X-XSS-Protection"] = self.config.xss_protection
|
||||
headers["Referrer-Policy"] = self.config.referrer_policy
|
||||
|
||||
# HSTS (only in production or if explicitly enabled)
|
||||
if self.config.hsts_enabled and not self.config.development_mode:
|
||||
headers["Strict-Transport-Security"] = self._build_hsts_header()
|
||||
|
||||
# Content-Security-Policy
|
||||
if self.config.csp_enabled:
|
||||
headers["Content-Security-Policy"] = self.config.csp_policy
|
||||
|
||||
# Permissions-Policy
|
||||
if self.config.permissions_policy:
|
||||
headers["Permissions-Policy"] = self.config.permissions_policy
|
||||
|
||||
# Cross-Origin headers (relaxed in development)
|
||||
if not self.config.development_mode:
|
||||
headers["Cross-Origin-Opener-Policy"] = self.config.cross_origin_opener_policy
|
||||
# Note: COEP can break loading of external resources, be careful
|
||||
# headers["Cross-Origin-Embedder-Policy"] = self.config.cross_origin_embedder_policy
|
||||
headers["Cross-Origin-Resource-Policy"] = self.config.cross_origin_resource_policy
|
||||
|
||||
return headers
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
# Skip security headers for excluded paths
|
||||
if request.url.path in self.config.excluded_paths:
|
||||
return await call_next(request)
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add security headers
|
||||
for header_name, header_value in self._get_headers().items():
|
||||
response.headers[header_name] = header_value
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def get_default_csp_for_environment(environment: str) -> str:
|
||||
"""
|
||||
Get a sensible default CSP for the given environment.
|
||||
|
||||
Args:
|
||||
environment: "development", "staging", or "production"
|
||||
|
||||
Returns:
|
||||
CSP policy string
|
||||
"""
|
||||
if environment.lower() in ("development", "dev", "local"):
|
||||
# Relaxed CSP for development
|
||||
return (
|
||||
"default-src 'self' localhost:* ws://localhost:*; "
|
||||
"script-src 'self' 'unsafe-inline' 'unsafe-eval'; "
|
||||
"style-src 'self' 'unsafe-inline'; "
|
||||
"img-src 'self' data: https: blob:; "
|
||||
"font-src 'self' data:; "
|
||||
"connect-src 'self' localhost:* ws://localhost:* https:; "
|
||||
"frame-ancestors 'self'"
|
||||
)
|
||||
else:
|
||||
# Strict CSP for production
|
||||
return (
|
||||
"default-src 'self'; "
|
||||
"script-src 'self' 'unsafe-inline'; "
|
||||
"style-src 'self' 'unsafe-inline'; "
|
||||
"img-src 'self' data: https:; "
|
||||
"font-src 'self' data:; "
|
||||
"connect-src 'self' https://breakpilot.app https://*.breakpilot.app; "
|
||||
"frame-ancestors 'none'"
|
||||
)
|
||||
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
Notification API - Proxy zu Go Consent Service für Benachrichtigungen
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query
|
||||
from typing import Optional
|
||||
import httpx
|
||||
|
||||
router = APIRouter(prefix="/v1/notifications", tags=["Notifications"])
|
||||
|
||||
CONSENT_SERVICE_URL = "http://localhost:8081"
|
||||
|
||||
|
||||
async def proxy_request(
|
||||
method: str,
|
||||
path: str,
|
||||
authorization: Optional[str] = None,
|
||||
json_data: dict = None,
|
||||
params: dict = None
|
||||
):
|
||||
"""Proxy request to Go consent service."""
|
||||
headers = {}
|
||||
if authorization:
|
||||
headers["Authorization"] = authorization
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.request(
|
||||
method,
|
||||
f"{CONSENT_SERVICE_URL}{path}",
|
||||
headers=headers,
|
||||
json=json_data,
|
||||
params=params,
|
||||
timeout=30.0
|
||||
)
|
||||
|
||||
if response.status_code >= 400:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=response.json().get("error", "Request failed")
|
||||
)
|
||||
|
||||
return response.json()
|
||||
except httpx.RequestError as e:
|
||||
raise HTTPException(status_code=503, detail=f"Consent service unavailable: {str(e)}")
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_notifications(
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
offset: int = Query(0, ge=0),
|
||||
unread_only: bool = Query(False),
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""Holt alle Benachrichtigungen des aktuellen Benutzers."""
|
||||
params = {
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"unread_only": str(unread_only).lower()
|
||||
}
|
||||
return await proxy_request(
|
||||
"GET",
|
||||
"/api/v1/notifications",
|
||||
authorization=authorization,
|
||||
params=params
|
||||
)
|
||||
|
||||
|
||||
@router.get("/unread-count")
|
||||
async def get_unread_count(
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""Gibt die Anzahl ungelesener Benachrichtigungen zurück."""
|
||||
return await proxy_request(
|
||||
"GET",
|
||||
"/api/v1/notifications/unread-count",
|
||||
authorization=authorization
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{notification_id}/read")
|
||||
async def mark_as_read(
|
||||
notification_id: str,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""Markiert eine Benachrichtigung als gelesen."""
|
||||
return await proxy_request(
|
||||
"PUT",
|
||||
f"/api/v1/notifications/{notification_id}/read",
|
||||
authorization=authorization
|
||||
)
|
||||
|
||||
|
||||
@router.put("/read-all")
|
||||
async def mark_all_as_read(
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""Markiert alle Benachrichtigungen als gelesen."""
|
||||
return await proxy_request(
|
||||
"PUT",
|
||||
"/api/v1/notifications/read-all",
|
||||
authorization=authorization
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{notification_id}")
|
||||
async def delete_notification(
|
||||
notification_id: str,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""Löscht eine Benachrichtigung."""
|
||||
return await proxy_request(
|
||||
"DELETE",
|
||||
f"/api/v1/notifications/{notification_id}",
|
||||
authorization=authorization
|
||||
)
|
||||
|
||||
|
||||
@router.get("/preferences")
|
||||
async def get_preferences(
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""Holt die Benachrichtigungseinstellungen des Benutzers."""
|
||||
return await proxy_request(
|
||||
"GET",
|
||||
"/api/v1/notifications/preferences",
|
||||
authorization=authorization
|
||||
)
|
||||
|
||||
|
||||
@router.put("/preferences")
|
||||
async def update_preferences(
|
||||
preferences: dict,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""Aktualisiert die Benachrichtigungseinstellungen."""
|
||||
return await proxy_request(
|
||||
"PUT",
|
||||
"/api/v1/notifications/preferences",
|
||||
authorization=authorization,
|
||||
json_data=preferences
|
||||
)
|
||||
@@ -0,0 +1,819 @@
|
||||
"""
|
||||
RBAC API - Teacher and Role Management Endpoints
|
||||
|
||||
Provides API endpoints for:
|
||||
- Listing all teachers
|
||||
- Listing all available roles
|
||||
- Assigning/revoking roles to teachers
|
||||
- Viewing role assignments per teacher
|
||||
|
||||
Architecture:
|
||||
- Authentication: Keycloak (when configured) or local JWT
|
||||
- Authorization: Custom rbac.py for fine-grained permissions
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncpg
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, List, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Import hybrid auth module
|
||||
try:
|
||||
from auth import get_current_user, TokenExpiredError, TokenInvalidError
|
||||
except ImportError:
|
||||
# Fallback for standalone testing
|
||||
from auth.keycloak_auth import get_current_user, TokenExpiredError, TokenInvalidError
|
||||
|
||||
# Configuration from environment - NO DEFAULT SECRETS
|
||||
ENVIRONMENT = os.environ.get("ENVIRONMENT", "development")
|
||||
|
||||
router = APIRouter(prefix="/rbac", tags=["rbac"])
|
||||
|
||||
# Connection pool
|
||||
_pool: Optional[asyncpg.Pool] = None
|
||||
|
||||
|
||||
def _get_database_url() -> str:
|
||||
"""Get DATABASE_URL from environment, raising error if not set."""
|
||||
url = os.environ.get("DATABASE_URL")
|
||||
if not url:
|
||||
raise RuntimeError("DATABASE_URL nicht konfiguriert - bitte via Vault oder Umgebungsvariable setzen")
|
||||
return url
|
||||
|
||||
|
||||
async def get_pool() -> asyncpg.Pool:
|
||||
"""Get or create database connection pool"""
|
||||
global _pool
|
||||
if _pool is None:
|
||||
database_url = _get_database_url()
|
||||
_pool = await asyncpg.create_pool(database_url, min_size=2, max_size=10)
|
||||
return _pool
|
||||
|
||||
|
||||
async def close_pool():
|
||||
"""Close database connection pool"""
|
||||
global _pool
|
||||
if _pool:
|
||||
await _pool.close()
|
||||
_pool = None
|
||||
|
||||
|
||||
# Pydantic Models
|
||||
class RoleAssignmentCreate(BaseModel):
|
||||
user_id: str
|
||||
role: str
|
||||
resource_type: str = "tenant"
|
||||
resource_id: str
|
||||
valid_to: Optional[str] = None
|
||||
|
||||
|
||||
class RoleAssignmentRevoke(BaseModel):
|
||||
assignment_id: str
|
||||
|
||||
|
||||
class TeacherCreate(BaseModel):
|
||||
email: str
|
||||
first_name: str
|
||||
last_name: str
|
||||
teacher_code: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
roles: List[str] = []
|
||||
|
||||
|
||||
class TeacherUpdate(BaseModel):
|
||||
email: Optional[str] = None
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
teacher_code: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class CustomRoleCreate(BaseModel):
|
||||
role_key: str
|
||||
display_name: str
|
||||
description: str
|
||||
category: str
|
||||
|
||||
|
||||
class CustomRoleUpdate(BaseModel):
|
||||
display_name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
category: Optional[str] = None
|
||||
|
||||
|
||||
class TeacherResponse(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
email: str
|
||||
name: str
|
||||
teacher_code: Optional[str]
|
||||
title: Optional[str]
|
||||
first_name: str
|
||||
last_name: str
|
||||
is_active: bool
|
||||
roles: List[str]
|
||||
|
||||
|
||||
class RoleInfo(BaseModel):
|
||||
role: str
|
||||
display_name: str
|
||||
description: str
|
||||
category: str
|
||||
|
||||
|
||||
class RoleAssignmentResponse(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
role: str
|
||||
resource_type: str
|
||||
resource_id: str
|
||||
valid_from: str
|
||||
valid_to: Optional[str]
|
||||
granted_at: str
|
||||
is_active: bool
|
||||
|
||||
|
||||
# Role definitions with German display names
|
||||
AVAILABLE_ROLES = {
|
||||
# Klausur-Korrekturkette
|
||||
"erstkorrektor": {
|
||||
"display_name": "Erstkorrektor",
|
||||
"description": "Fuehrt die erste Korrektur der Klausur durch",
|
||||
"category": "klausur"
|
||||
},
|
||||
"zweitkorrektor": {
|
||||
"display_name": "Zweitkorrektor",
|
||||
"description": "Fuehrt die zweite Korrektur der Klausur durch",
|
||||
"category": "klausur"
|
||||
},
|
||||
"drittkorrektor": {
|
||||
"display_name": "Drittkorrektor",
|
||||
"description": "Fuehrt die dritte Korrektur bei Notenabweichung durch",
|
||||
"category": "klausur"
|
||||
},
|
||||
# Zeugnis-Workflow
|
||||
"klassenlehrer": {
|
||||
"display_name": "Klassenlehrer/in",
|
||||
"description": "Erstellt Zeugnisse, traegt Kopfnoten und Bemerkungen ein",
|
||||
"category": "zeugnis"
|
||||
},
|
||||
"fachlehrer": {
|
||||
"display_name": "Fachlehrer/in",
|
||||
"description": "Traegt Fachnoten ein",
|
||||
"category": "zeugnis"
|
||||
},
|
||||
"zeugnisbeauftragter": {
|
||||
"display_name": "Zeugnisbeauftragte/r",
|
||||
"description": "Qualitaetskontrolle und Freigabe von Zeugnissen",
|
||||
"category": "zeugnis"
|
||||
},
|
||||
"sekretariat": {
|
||||
"display_name": "Sekretariat",
|
||||
"description": "Druck, Versand und Archivierung von Dokumenten",
|
||||
"category": "verwaltung"
|
||||
},
|
||||
# Leitung
|
||||
"fachvorsitz": {
|
||||
"display_name": "Fachvorsitz",
|
||||
"description": "Fachpruefungsleitung und Qualitaetssicherung",
|
||||
"category": "leitung"
|
||||
},
|
||||
"pruefungsvorsitz": {
|
||||
"display_name": "Pruefungsvorsitz",
|
||||
"description": "Pruefungsleitung und finale Freigabe",
|
||||
"category": "leitung"
|
||||
},
|
||||
"schulleitung": {
|
||||
"display_name": "Schulleitung",
|
||||
"description": "Finale Freigabe und Unterschrift",
|
||||
"category": "leitung"
|
||||
},
|
||||
"stufenleitung": {
|
||||
"display_name": "Stufenleitung",
|
||||
"description": "Koordination einer Jahrgangsstufe",
|
||||
"category": "leitung"
|
||||
},
|
||||
# Administration
|
||||
"schul_admin": {
|
||||
"display_name": "Schul-Administrator",
|
||||
"description": "Technische Administration der Schule",
|
||||
"category": "admin"
|
||||
},
|
||||
"teacher_assistant": {
|
||||
"display_name": "Referendar/in",
|
||||
"description": "Lehrkraft in Ausbildung mit eingeschraenkten Rechten",
|
||||
"category": "other"
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Note: get_user_from_token is replaced by the imported get_current_user dependency
|
||||
# from auth module which supports both Keycloak and local JWT authentication
|
||||
|
||||
|
||||
# API Endpoints
|
||||
|
||||
@router.get("/roles")
|
||||
async def list_available_roles() -> List[RoleInfo]:
|
||||
"""List all available roles with their descriptions"""
|
||||
return [
|
||||
RoleInfo(
|
||||
role=role_key,
|
||||
display_name=role_data["display_name"],
|
||||
description=role_data["description"],
|
||||
category=role_data["category"]
|
||||
)
|
||||
for role_key, role_data in AVAILABLE_ROLES.items()
|
||||
]
|
||||
|
||||
|
||||
@router.get("/teachers")
|
||||
async def list_teachers(user: Dict[str, Any] = Depends(get_current_user)) -> List[TeacherResponse]:
|
||||
"""List all teachers with their current roles"""
|
||||
pool = await get_pool()
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
# Get all teachers with their user info
|
||||
teachers = await conn.fetch("""
|
||||
SELECT
|
||||
t.id, t.user_id, t.teacher_code, t.title,
|
||||
t.first_name, t.last_name, t.is_active,
|
||||
u.email, u.name
|
||||
FROM teachers t
|
||||
JOIN users u ON t.user_id = u.id
|
||||
WHERE t.school_id = 'a0000000-0000-0000-0000-000000000001'
|
||||
ORDER BY t.last_name, t.first_name
|
||||
""")
|
||||
|
||||
# Get role assignments for all teachers
|
||||
role_assignments = await conn.fetch("""
|
||||
SELECT user_id, role
|
||||
FROM role_assignments
|
||||
WHERE tenant_id = 'a0000000-0000-0000-0000-000000000001'
|
||||
AND revoked_at IS NULL
|
||||
AND (valid_to IS NULL OR valid_to > NOW())
|
||||
""")
|
||||
|
||||
# Build role lookup
|
||||
role_lookup: Dict[str, List[str]] = {}
|
||||
for ra in role_assignments:
|
||||
uid = str(ra["user_id"])
|
||||
if uid not in role_lookup:
|
||||
role_lookup[uid] = []
|
||||
role_lookup[uid].append(ra["role"])
|
||||
|
||||
# Build response
|
||||
result = []
|
||||
for t in teachers:
|
||||
uid = str(t["user_id"])
|
||||
result.append(TeacherResponse(
|
||||
id=str(t["id"]),
|
||||
user_id=uid,
|
||||
email=t["email"],
|
||||
name=t["name"] or f"{t['first_name']} {t['last_name']}",
|
||||
teacher_code=t["teacher_code"],
|
||||
title=t["title"],
|
||||
first_name=t["first_name"],
|
||||
last_name=t["last_name"],
|
||||
is_active=t["is_active"],
|
||||
roles=role_lookup.get(uid, [])
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/teachers/{teacher_id}/roles")
|
||||
async def get_teacher_roles(teacher_id: str, user: Dict[str, Any] = Depends(get_current_user)) -> List[RoleAssignmentResponse]:
|
||||
"""Get all role assignments for a specific teacher"""
|
||||
pool = await get_pool()
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
# Get teacher's user_id
|
||||
teacher = await conn.fetchrow(
|
||||
"SELECT user_id FROM teachers WHERE id = $1",
|
||||
teacher_id
|
||||
)
|
||||
if not teacher:
|
||||
raise HTTPException(status_code=404, detail="Teacher not found")
|
||||
|
||||
# Get role assignments
|
||||
assignments = await conn.fetch("""
|
||||
SELECT id, user_id, role, resource_type, resource_id,
|
||||
valid_from, valid_to, granted_at, revoked_at
|
||||
FROM role_assignments
|
||||
WHERE user_id = $1
|
||||
ORDER BY granted_at DESC
|
||||
""", teacher["user_id"])
|
||||
|
||||
return [
|
||||
RoleAssignmentResponse(
|
||||
id=str(a["id"]),
|
||||
user_id=str(a["user_id"]),
|
||||
role=a["role"],
|
||||
resource_type=a["resource_type"],
|
||||
resource_id=str(a["resource_id"]),
|
||||
valid_from=a["valid_from"].isoformat() if a["valid_from"] else None,
|
||||
valid_to=a["valid_to"].isoformat() if a["valid_to"] else None,
|
||||
granted_at=a["granted_at"].isoformat() if a["granted_at"] else None,
|
||||
is_active=a["revoked_at"] is None and (
|
||||
a["valid_to"] is None or a["valid_to"] > datetime.now(timezone.utc)
|
||||
)
|
||||
)
|
||||
for a in assignments
|
||||
]
|
||||
|
||||
|
||||
@router.get("/roles/{role}/teachers")
|
||||
async def get_teachers_by_role(role: str, user: Dict[str, Any] = Depends(get_current_user)) -> List[TeacherResponse]:
|
||||
"""Get all teachers with a specific role"""
|
||||
if role not in AVAILABLE_ROLES:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown role: {role}")
|
||||
|
||||
pool = await get_pool()
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
teachers = await conn.fetch("""
|
||||
SELECT DISTINCT
|
||||
t.id, t.user_id, t.teacher_code, t.title,
|
||||
t.first_name, t.last_name, t.is_active,
|
||||
u.email, u.name
|
||||
FROM teachers t
|
||||
JOIN users u ON t.user_id = u.id
|
||||
JOIN role_assignments ra ON t.user_id = ra.user_id
|
||||
WHERE ra.role = $1
|
||||
AND ra.revoked_at IS NULL
|
||||
AND (ra.valid_to IS NULL OR ra.valid_to > NOW())
|
||||
AND t.school_id = 'a0000000-0000-0000-0000-000000000001'
|
||||
ORDER BY t.last_name, t.first_name
|
||||
""", role)
|
||||
|
||||
# Get all roles for these teachers
|
||||
if teachers:
|
||||
user_ids = [t["user_id"] for t in teachers]
|
||||
role_assignments = await conn.fetch("""
|
||||
SELECT user_id, role
|
||||
FROM role_assignments
|
||||
WHERE user_id = ANY($1)
|
||||
AND revoked_at IS NULL
|
||||
AND (valid_to IS NULL OR valid_to > NOW())
|
||||
""", user_ids)
|
||||
|
||||
role_lookup: Dict[str, List[str]] = {}
|
||||
for ra in role_assignments:
|
||||
uid = str(ra["user_id"])
|
||||
if uid not in role_lookup:
|
||||
role_lookup[uid] = []
|
||||
role_lookup[uid].append(ra["role"])
|
||||
else:
|
||||
role_lookup = {}
|
||||
|
||||
return [
|
||||
TeacherResponse(
|
||||
id=str(t["id"]),
|
||||
user_id=str(t["user_id"]),
|
||||
email=t["email"],
|
||||
name=t["name"] or f"{t['first_name']} {t['last_name']}",
|
||||
teacher_code=t["teacher_code"],
|
||||
title=t["title"],
|
||||
first_name=t["first_name"],
|
||||
last_name=t["last_name"],
|
||||
is_active=t["is_active"],
|
||||
roles=role_lookup.get(str(t["user_id"]), [])
|
||||
)
|
||||
for t in teachers
|
||||
]
|
||||
|
||||
|
||||
@router.post("/assignments")
|
||||
async def assign_role(assignment: RoleAssignmentCreate, user: Dict[str, Any] = Depends(get_current_user)) -> RoleAssignmentResponse:
|
||||
"""Assign a role to a user"""
|
||||
if assignment.role not in AVAILABLE_ROLES:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown role: {assignment.role}")
|
||||
|
||||
pool = await get_pool()
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
# Check if assignment already exists
|
||||
existing = await conn.fetchrow("""
|
||||
SELECT id FROM role_assignments
|
||||
WHERE user_id = $1 AND role = $2 AND resource_id = $3
|
||||
AND revoked_at IS NULL
|
||||
""", assignment.user_id, assignment.role, assignment.resource_id)
|
||||
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Role assignment already exists"
|
||||
)
|
||||
|
||||
# Parse valid_to if provided
|
||||
valid_to = None
|
||||
if assignment.valid_to:
|
||||
valid_to = datetime.fromisoformat(assignment.valid_to)
|
||||
|
||||
# Create assignment
|
||||
result = await conn.fetchrow("""
|
||||
INSERT INTO role_assignments
|
||||
(user_id, role, resource_type, resource_id, tenant_id, valid_to, granted_by)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING id, user_id, role, resource_type, resource_id, valid_from, valid_to, granted_at
|
||||
""",
|
||||
assignment.user_id,
|
||||
assignment.role,
|
||||
assignment.resource_type,
|
||||
assignment.resource_id,
|
||||
assignment.resource_id, # tenant_id same as resource_id for tenant-level roles
|
||||
valid_to,
|
||||
user.get("user_id")
|
||||
)
|
||||
|
||||
return RoleAssignmentResponse(
|
||||
id=str(result["id"]),
|
||||
user_id=str(result["user_id"]),
|
||||
role=result["role"],
|
||||
resource_type=result["resource_type"],
|
||||
resource_id=str(result["resource_id"]),
|
||||
valid_from=result["valid_from"].isoformat(),
|
||||
valid_to=result["valid_to"].isoformat() if result["valid_to"] else None,
|
||||
granted_at=result["granted_at"].isoformat(),
|
||||
is_active=True
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/assignments/{assignment_id}")
|
||||
async def revoke_role(assignment_id: str, user: Dict[str, Any] = Depends(get_current_user)):
|
||||
"""Revoke a role assignment"""
|
||||
pool = await get_pool()
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
result = await conn.execute("""
|
||||
UPDATE role_assignments
|
||||
SET revoked_at = NOW()
|
||||
WHERE id = $1 AND revoked_at IS NULL
|
||||
""", assignment_id)
|
||||
|
||||
if result == "UPDATE 0":
|
||||
raise HTTPException(status_code=404, detail="Assignment not found or already revoked")
|
||||
|
||||
return {"status": "revoked", "assignment_id": assignment_id}
|
||||
|
||||
|
||||
@router.get("/summary")
|
||||
async def get_role_summary(user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
|
||||
"""Get a summary of roles and their assignment counts"""
|
||||
pool = await get_pool()
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
counts = await conn.fetch("""
|
||||
SELECT role, COUNT(*) as count
|
||||
FROM role_assignments
|
||||
WHERE tenant_id = 'a0000000-0000-0000-0000-000000000001'
|
||||
AND revoked_at IS NULL
|
||||
AND (valid_to IS NULL OR valid_to > NOW())
|
||||
GROUP BY role
|
||||
ORDER BY role
|
||||
""")
|
||||
|
||||
total_teachers = await conn.fetchval("""
|
||||
SELECT COUNT(*) FROM teachers
|
||||
WHERE school_id = 'a0000000-0000-0000-0000-000000000001'
|
||||
AND is_active = true
|
||||
""")
|
||||
|
||||
role_counts = {c["role"]: c["count"] for c in counts}
|
||||
|
||||
# Also include custom roles from database
|
||||
custom_roles = await conn.fetch("""
|
||||
SELECT role_key, display_name, category
|
||||
FROM custom_roles
|
||||
WHERE tenant_id = 'a0000000-0000-0000-0000-000000000001'
|
||||
AND is_active = true
|
||||
""")
|
||||
|
||||
all_roles = [
|
||||
{
|
||||
"role": role_key,
|
||||
"display_name": role_data["display_name"],
|
||||
"category": role_data["category"],
|
||||
"count": role_counts.get(role_key, 0),
|
||||
"is_custom": False
|
||||
}
|
||||
for role_key, role_data in AVAILABLE_ROLES.items()
|
||||
]
|
||||
|
||||
for cr in custom_roles:
|
||||
all_roles.append({
|
||||
"role": cr["role_key"],
|
||||
"display_name": cr["display_name"],
|
||||
"category": cr["category"],
|
||||
"count": role_counts.get(cr["role_key"], 0),
|
||||
"is_custom": True
|
||||
})
|
||||
|
||||
return {
|
||||
"total_teachers": total_teachers,
|
||||
"roles": all_roles
|
||||
}
|
||||
|
||||
|
||||
# ==========================================
|
||||
# TEACHER MANAGEMENT ENDPOINTS
|
||||
# ==========================================
|
||||
|
||||
@router.post("/teachers")
|
||||
async def create_teacher(teacher: TeacherCreate, user: Dict[str, Any] = Depends(get_current_user)) -> TeacherResponse:
|
||||
"""Create a new teacher with optional initial roles"""
|
||||
pool = await get_pool()
|
||||
|
||||
import uuid
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
# Check if email already exists
|
||||
existing = await conn.fetchrow(
|
||||
"SELECT id FROM users WHERE email = $1",
|
||||
teacher.email
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="Email already exists")
|
||||
|
||||
# Generate UUIDs
|
||||
user_id = str(uuid.uuid4())
|
||||
teacher_id = str(uuid.uuid4())
|
||||
|
||||
# Create user first
|
||||
await conn.execute("""
|
||||
INSERT INTO users (id, email, name, password_hash, role, is_active)
|
||||
VALUES ($1, $2, $3, '', 'teacher', true)
|
||||
""", user_id, teacher.email, f"{teacher.first_name} {teacher.last_name}")
|
||||
|
||||
# Create teacher record
|
||||
await conn.execute("""
|
||||
INSERT INTO teachers (id, user_id, school_id, first_name, last_name, teacher_code, title, is_active)
|
||||
VALUES ($1, $2, 'a0000000-0000-0000-0000-000000000001', $3, $4, $5, $6, true)
|
||||
""", teacher_id, user_id, teacher.first_name, teacher.last_name,
|
||||
teacher.teacher_code, teacher.title)
|
||||
|
||||
# Assign initial roles if provided
|
||||
assigned_roles = []
|
||||
for role in teacher.roles:
|
||||
if role in AVAILABLE_ROLES or await conn.fetchrow(
|
||||
"SELECT 1 FROM custom_roles WHERE role_key = $1 AND is_active = true", role
|
||||
):
|
||||
await conn.execute("""
|
||||
INSERT INTO role_assignments (user_id, role, resource_type, resource_id, tenant_id, granted_by)
|
||||
VALUES ($1, $2, 'tenant', 'a0000000-0000-0000-0000-000000000001',
|
||||
'a0000000-0000-0000-0000-000000000001', $3)
|
||||
""", user_id, role, user.get("user_id"))
|
||||
assigned_roles.append(role)
|
||||
|
||||
return TeacherResponse(
|
||||
id=teacher_id,
|
||||
user_id=user_id,
|
||||
email=teacher.email,
|
||||
name=f"{teacher.first_name} {teacher.last_name}",
|
||||
teacher_code=teacher.teacher_code,
|
||||
title=teacher.title,
|
||||
first_name=teacher.first_name,
|
||||
last_name=teacher.last_name,
|
||||
is_active=True,
|
||||
roles=assigned_roles
|
||||
)
|
||||
|
||||
|
||||
@router.put("/teachers/{teacher_id}")
|
||||
async def update_teacher(teacher_id: str, updates: TeacherUpdate, user: Dict[str, Any] = Depends(get_current_user)) -> TeacherResponse:
|
||||
"""Update teacher information"""
|
||||
pool = await get_pool()
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
# Get current teacher data
|
||||
teacher = await conn.fetchrow("""
|
||||
SELECT t.id, t.user_id, t.teacher_code, t.title, t.first_name, t.last_name, t.is_active,
|
||||
u.email, u.name
|
||||
FROM teachers t
|
||||
JOIN users u ON t.user_id = u.id
|
||||
WHERE t.id = $1
|
||||
""", teacher_id)
|
||||
|
||||
if not teacher:
|
||||
raise HTTPException(status_code=404, detail="Teacher not found")
|
||||
|
||||
# Build update queries
|
||||
if updates.email:
|
||||
await conn.execute("UPDATE users SET email = $1 WHERE id = $2",
|
||||
updates.email, teacher["user_id"])
|
||||
|
||||
teacher_updates = []
|
||||
teacher_values = []
|
||||
idx = 1
|
||||
|
||||
if updates.first_name:
|
||||
teacher_updates.append(f"first_name = ${idx}")
|
||||
teacher_values.append(updates.first_name)
|
||||
idx += 1
|
||||
if updates.last_name:
|
||||
teacher_updates.append(f"last_name = ${idx}")
|
||||
teacher_values.append(updates.last_name)
|
||||
idx += 1
|
||||
if updates.teacher_code is not None:
|
||||
teacher_updates.append(f"teacher_code = ${idx}")
|
||||
teacher_values.append(updates.teacher_code)
|
||||
idx += 1
|
||||
if updates.title is not None:
|
||||
teacher_updates.append(f"title = ${idx}")
|
||||
teacher_values.append(updates.title)
|
||||
idx += 1
|
||||
if updates.is_active is not None:
|
||||
teacher_updates.append(f"is_active = ${idx}")
|
||||
teacher_values.append(updates.is_active)
|
||||
idx += 1
|
||||
|
||||
if teacher_updates:
|
||||
teacher_values.append(teacher_id)
|
||||
await conn.execute(
|
||||
f"UPDATE teachers SET {', '.join(teacher_updates)} WHERE id = ${idx}",
|
||||
*teacher_values
|
||||
)
|
||||
|
||||
# Update user name if first/last name changed
|
||||
if updates.first_name or updates.last_name:
|
||||
new_first = updates.first_name or teacher["first_name"]
|
||||
new_last = updates.last_name or teacher["last_name"]
|
||||
await conn.execute("UPDATE users SET name = $1 WHERE id = $2",
|
||||
f"{new_first} {new_last}", teacher["user_id"])
|
||||
|
||||
# Fetch updated data
|
||||
updated = await conn.fetchrow("""
|
||||
SELECT t.id, t.user_id, t.teacher_code, t.title, t.first_name, t.last_name, t.is_active,
|
||||
u.email, u.name
|
||||
FROM teachers t
|
||||
JOIN users u ON t.user_id = u.id
|
||||
WHERE t.id = $1
|
||||
""", teacher_id)
|
||||
|
||||
# Get roles
|
||||
roles = await conn.fetch("""
|
||||
SELECT role FROM role_assignments
|
||||
WHERE user_id = $1 AND revoked_at IS NULL
|
||||
AND (valid_to IS NULL OR valid_to > NOW())
|
||||
""", updated["user_id"])
|
||||
|
||||
return TeacherResponse(
|
||||
id=str(updated["id"]),
|
||||
user_id=str(updated["user_id"]),
|
||||
email=updated["email"],
|
||||
name=updated["name"],
|
||||
teacher_code=updated["teacher_code"],
|
||||
title=updated["title"],
|
||||
first_name=updated["first_name"],
|
||||
last_name=updated["last_name"],
|
||||
is_active=updated["is_active"],
|
||||
roles=[r["role"] for r in roles]
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/teachers/{teacher_id}")
|
||||
async def deactivate_teacher(teacher_id: str, user: Dict[str, Any] = Depends(get_current_user)):
|
||||
"""Deactivate a teacher (soft delete)"""
|
||||
pool = await get_pool()
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
result = await conn.execute("""
|
||||
UPDATE teachers SET is_active = false WHERE id = $1
|
||||
""", teacher_id)
|
||||
|
||||
if result == "UPDATE 0":
|
||||
raise HTTPException(status_code=404, detail="Teacher not found")
|
||||
|
||||
return {"status": "deactivated", "teacher_id": teacher_id}
|
||||
|
||||
|
||||
# ==========================================
|
||||
# CUSTOM ROLE MANAGEMENT ENDPOINTS
|
||||
# ==========================================
|
||||
|
||||
@router.get("/custom-roles")
|
||||
async def list_custom_roles(user: Dict[str, Any] = Depends(get_current_user)) -> List[RoleInfo]:
|
||||
"""List all custom roles"""
|
||||
pool = await get_pool()
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
roles = await conn.fetch("""
|
||||
SELECT role_key, display_name, description, category
|
||||
FROM custom_roles
|
||||
WHERE tenant_id = 'a0000000-0000-0000-0000-000000000001'
|
||||
AND is_active = true
|
||||
ORDER BY category, display_name
|
||||
""")
|
||||
|
||||
return [
|
||||
RoleInfo(
|
||||
role=r["role_key"],
|
||||
display_name=r["display_name"],
|
||||
description=r["description"],
|
||||
category=r["category"]
|
||||
)
|
||||
for r in roles
|
||||
]
|
||||
|
||||
|
||||
@router.post("/custom-roles")
|
||||
async def create_custom_role(role: CustomRoleCreate, user: Dict[str, Any] = Depends(get_current_user)) -> RoleInfo:
|
||||
"""Create a new custom role"""
|
||||
pool = await get_pool()
|
||||
|
||||
# Check if role_key conflicts with built-in roles
|
||||
if role.role_key in AVAILABLE_ROLES:
|
||||
raise HTTPException(status_code=409, detail="Role key conflicts with built-in role")
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
# Check if custom role already exists
|
||||
existing = await conn.fetchrow("""
|
||||
SELECT id FROM custom_roles
|
||||
WHERE role_key = $1 AND tenant_id = 'a0000000-0000-0000-0000-000000000001'
|
||||
""", role.role_key)
|
||||
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="Custom role already exists")
|
||||
|
||||
await conn.execute("""
|
||||
INSERT INTO custom_roles (role_key, display_name, description, category, tenant_id, created_by)
|
||||
VALUES ($1, $2, $3, $4, 'a0000000-0000-0000-0000-000000000001', $5)
|
||||
""", role.role_key, role.display_name, role.description, role.category, user.get("user_id"))
|
||||
|
||||
return RoleInfo(
|
||||
role=role.role_key,
|
||||
display_name=role.display_name,
|
||||
description=role.description,
|
||||
category=role.category
|
||||
)
|
||||
|
||||
|
||||
@router.put("/custom-roles/{role_key}")
|
||||
async def update_custom_role(role_key: str, updates: CustomRoleUpdate, user: Dict[str, Any] = Depends(get_current_user)) -> RoleInfo:
|
||||
"""Update a custom role"""
|
||||
|
||||
if role_key in AVAILABLE_ROLES:
|
||||
raise HTTPException(status_code=400, detail="Cannot modify built-in roles")
|
||||
|
||||
pool = await get_pool()
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
current = await conn.fetchrow("""
|
||||
SELECT role_key, display_name, description, category
|
||||
FROM custom_roles
|
||||
WHERE role_key = $1 AND tenant_id = 'a0000000-0000-0000-0000-000000000001'
|
||||
AND is_active = true
|
||||
""", role_key)
|
||||
|
||||
if not current:
|
||||
raise HTTPException(status_code=404, detail="Custom role not found")
|
||||
|
||||
new_display = updates.display_name or current["display_name"]
|
||||
new_desc = updates.description or current["description"]
|
||||
new_cat = updates.category or current["category"]
|
||||
|
||||
await conn.execute("""
|
||||
UPDATE custom_roles
|
||||
SET display_name = $1, description = $2, category = $3
|
||||
WHERE role_key = $4 AND tenant_id = 'a0000000-0000-0000-0000-000000000001'
|
||||
""", new_display, new_desc, new_cat, role_key)
|
||||
|
||||
return RoleInfo(
|
||||
role=role_key,
|
||||
display_name=new_display,
|
||||
description=new_desc,
|
||||
category=new_cat
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/custom-roles/{role_key}")
|
||||
async def delete_custom_role(role_key: str, user: Dict[str, Any] = Depends(get_current_user)):
|
||||
"""Delete a custom role (soft delete)"""
|
||||
|
||||
if role_key in AVAILABLE_ROLES:
|
||||
raise HTTPException(status_code=400, detail="Cannot delete built-in roles")
|
||||
|
||||
pool = await get_pool()
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
# Soft delete the role
|
||||
result = await conn.execute("""
|
||||
UPDATE custom_roles SET is_active = false
|
||||
WHERE role_key = $1 AND tenant_id = 'a0000000-0000-0000-0000-000000000001'
|
||||
""", role_key)
|
||||
|
||||
if result == "UPDATE 0":
|
||||
raise HTTPException(status_code=404, detail="Custom role not found")
|
||||
|
||||
# Also revoke all assignments with this role
|
||||
await conn.execute("""
|
||||
UPDATE role_assignments SET revoked_at = NOW()
|
||||
WHERE role = $1 AND tenant_id = 'a0000000-0000-0000-0000-000000000001'
|
||||
AND revoked_at IS NULL
|
||||
""", role_key)
|
||||
|
||||
return {"status": "deleted", "role_key": role_key}
|
||||
@@ -0,0 +1,52 @@
|
||||
# BreakPilot Core Backend Dependencies
|
||||
# Only what the shared APIs actually need.
|
||||
|
||||
# Web Framework
|
||||
fastapi==0.123.9
|
||||
uvicorn==0.38.0
|
||||
starlette==0.49.3
|
||||
|
||||
# HTTP Client (auth_api, notification_api, email_template_api proxy calls)
|
||||
httpx==0.28.1
|
||||
requests==2.32.5
|
||||
|
||||
# Validation & Types
|
||||
pydantic==2.12.5
|
||||
pydantic_core==2.41.5
|
||||
email-validator==2.3.0
|
||||
annotated-types==0.7.0
|
||||
|
||||
# Authentication (auth module, consent_client JWT)
|
||||
PyJWT==2.10.1
|
||||
python-multipart==0.0.20
|
||||
|
||||
# Database (rbac_api, middleware rate_limiter)
|
||||
asyncpg==0.30.0
|
||||
psycopg2-binary==2.9.10
|
||||
|
||||
# Cache / Rate-Limiter (Valkey/Redis)
|
||||
redis==5.2.1
|
||||
|
||||
# PDF Generation (services/pdf_service)
|
||||
weasyprint==66.0
|
||||
Jinja2==3.1.6
|
||||
|
||||
# Image Processing (services/file_processor)
|
||||
pillow==11.3.0
|
||||
opencv-python==4.12.0.88
|
||||
numpy==2.0.2
|
||||
|
||||
# Document Processing (services/file_processor)
|
||||
python-docx==1.2.0
|
||||
mammoth==1.11.0
|
||||
Markdown==3.9
|
||||
|
||||
# Secrets Management (Vault)
|
||||
hvac==2.4.0
|
||||
|
||||
# Utilities
|
||||
python-dateutil==2.9.0.post0
|
||||
|
||||
# Security: Pin transitive dependencies to patched versions
|
||||
idna>=3.7 # CVE-2024-3651
|
||||
cryptography>=42.0.0 # GHSA-h4gh-qq45-vh27
|
||||
@@ -0,0 +1,995 @@
|
||||
"""
|
||||
BreakPilot Security API
|
||||
|
||||
Endpunkte fuer das Security Dashboard:
|
||||
- Tool-Status abfragen
|
||||
- Scan-Ergebnisse abrufen
|
||||
- Scans ausloesen
|
||||
- SBOM-Daten abrufen
|
||||
- Scan-Historie anzeigen
|
||||
|
||||
Features:
|
||||
- Liest Security-Reports aus dem security-reports/ Verzeichnis
|
||||
- Fuehrt Security-Scans via subprocess aus
|
||||
- Parst Gitleaks, Semgrep, Trivy, Grype JSON-Reports
|
||||
- Generiert SBOM mit Syft
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import subprocess
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter(prefix="/v1/security", tags=["Security"])
|
||||
|
||||
# Pfade - innerhalb des Backend-Verzeichnisses
|
||||
# In Docker: /app/security-reports, /app/scripts
|
||||
# Lokal: backend/security-reports, backend/scripts
|
||||
BACKEND_DIR = Path(__file__).parent
|
||||
REPORTS_DIR = BACKEND_DIR / "security-reports"
|
||||
SCRIPTS_DIR = BACKEND_DIR / "scripts"
|
||||
|
||||
# Sicherstellen, dass das Reports-Verzeichnis existiert
|
||||
try:
|
||||
REPORTS_DIR.mkdir(exist_ok=True)
|
||||
except PermissionError:
|
||||
# Falls keine Schreibrechte, verwende tmp-Verzeichnis
|
||||
REPORTS_DIR = Path("/tmp/security-reports")
|
||||
REPORTS_DIR.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
# ===========================
|
||||
# Pydantic Models
|
||||
# ===========================
|
||||
|
||||
class ToolStatus(BaseModel):
|
||||
name: str
|
||||
installed: bool
|
||||
version: Optional[str] = None
|
||||
last_run: Optional[str] = None
|
||||
last_findings: int = 0
|
||||
|
||||
|
||||
class Finding(BaseModel):
|
||||
id: str
|
||||
tool: str
|
||||
severity: str
|
||||
title: str
|
||||
message: Optional[str] = None
|
||||
file: Optional[str] = None
|
||||
line: Optional[int] = None
|
||||
found_at: str
|
||||
|
||||
|
||||
class SeveritySummary(BaseModel):
|
||||
critical: int = 0
|
||||
high: int = 0
|
||||
medium: int = 0
|
||||
low: int = 0
|
||||
info: int = 0
|
||||
total: int = 0
|
||||
|
||||
|
||||
class ScanResult(BaseModel):
|
||||
tool: str
|
||||
status: str
|
||||
started_at: str
|
||||
completed_at: Optional[str] = None
|
||||
findings_count: int = 0
|
||||
report_path: Optional[str] = None
|
||||
|
||||
|
||||
class HistoryItem(BaseModel):
|
||||
timestamp: str
|
||||
title: str
|
||||
description: str
|
||||
status: str # success, warning, error
|
||||
|
||||
|
||||
# ===========================
|
||||
# Utility Functions
|
||||
# ===========================
|
||||
|
||||
def check_tool_installed(tool_name: str) -> tuple[bool, Optional[str]]:
|
||||
"""Prueft, ob ein Tool installiert ist und gibt die Version zurueck."""
|
||||
try:
|
||||
if tool_name == "gitleaks":
|
||||
result = subprocess.run(["gitleaks", "version"], capture_output=True, text=True, timeout=5)
|
||||
if result.returncode == 0:
|
||||
return True, result.stdout.strip()
|
||||
elif tool_name == "semgrep":
|
||||
result = subprocess.run(["semgrep", "--version"], capture_output=True, text=True, timeout=5)
|
||||
if result.returncode == 0:
|
||||
return True, result.stdout.strip().split('\n')[0]
|
||||
elif tool_name == "bandit":
|
||||
result = subprocess.run(["bandit", "--version"], capture_output=True, text=True, timeout=5)
|
||||
if result.returncode == 0:
|
||||
return True, result.stdout.strip()
|
||||
elif tool_name == "trivy":
|
||||
result = subprocess.run(["trivy", "version"], capture_output=True, text=True, timeout=5)
|
||||
if result.returncode == 0:
|
||||
# Parse "Version: 0.48.x"
|
||||
for line in result.stdout.split('\n'):
|
||||
if line.startswith('Version:'):
|
||||
return True, line.split(':')[1].strip()
|
||||
return True, result.stdout.strip().split('\n')[0]
|
||||
elif tool_name == "grype":
|
||||
result = subprocess.run(["grype", "version"], capture_output=True, text=True, timeout=5)
|
||||
if result.returncode == 0:
|
||||
return True, result.stdout.strip().split('\n')[0]
|
||||
elif tool_name == "syft":
|
||||
result = subprocess.run(["syft", "version"], capture_output=True, text=True, timeout=5)
|
||||
if result.returncode == 0:
|
||||
return True, result.stdout.strip().split('\n')[0]
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
pass
|
||||
return False, None
|
||||
|
||||
|
||||
def get_latest_report(tool_prefix: str) -> Optional[Path]:
|
||||
"""Findet den neuesten Report fuer ein Tool."""
|
||||
if not REPORTS_DIR.exists():
|
||||
return None
|
||||
|
||||
reports = list(REPORTS_DIR.glob(f"{tool_prefix}*.json"))
|
||||
if not reports:
|
||||
return None
|
||||
|
||||
return max(reports, key=lambda p: p.stat().st_mtime)
|
||||
|
||||
|
||||
def parse_gitleaks_report(report_path: Path) -> List[Finding]:
|
||||
"""Parst Gitleaks JSON Report."""
|
||||
findings = []
|
||||
try:
|
||||
with open(report_path) as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, list):
|
||||
for item in data:
|
||||
findings.append(Finding(
|
||||
id=item.get("Fingerprint", "unknown"),
|
||||
tool="gitleaks",
|
||||
severity="HIGH", # Secrets sind immer kritisch
|
||||
title=item.get("Description", "Secret detected"),
|
||||
message=f"Rule: {item.get('RuleID', 'unknown')}",
|
||||
file=item.get("File", ""),
|
||||
line=item.get("StartLine", 0),
|
||||
found_at=datetime.fromtimestamp(report_path.stat().st_mtime).isoformat()
|
||||
))
|
||||
except (json.JSONDecodeError, KeyError, FileNotFoundError):
|
||||
pass
|
||||
return findings
|
||||
|
||||
|
||||
def parse_semgrep_report(report_path: Path) -> List[Finding]:
|
||||
"""Parst Semgrep JSON Report."""
|
||||
findings = []
|
||||
try:
|
||||
with open(report_path) as f:
|
||||
data = json.load(f)
|
||||
results = data.get("results", [])
|
||||
for item in results:
|
||||
severity = item.get("extra", {}).get("severity", "INFO").upper()
|
||||
findings.append(Finding(
|
||||
id=item.get("check_id", "unknown"),
|
||||
tool="semgrep",
|
||||
severity=severity,
|
||||
title=item.get("extra", {}).get("message", "Finding"),
|
||||
message=item.get("check_id", ""),
|
||||
file=item.get("path", ""),
|
||||
line=item.get("start", {}).get("line", 0),
|
||||
found_at=datetime.fromtimestamp(report_path.stat().st_mtime).isoformat()
|
||||
))
|
||||
except (json.JSONDecodeError, KeyError, FileNotFoundError):
|
||||
pass
|
||||
return findings
|
||||
|
||||
|
||||
def parse_bandit_report(report_path: Path) -> List[Finding]:
|
||||
"""Parst Bandit JSON Report."""
|
||||
findings = []
|
||||
try:
|
||||
with open(report_path) as f:
|
||||
data = json.load(f)
|
||||
results = data.get("results", [])
|
||||
for item in results:
|
||||
severity = item.get("issue_severity", "LOW").upper()
|
||||
findings.append(Finding(
|
||||
id=item.get("test_id", "unknown"),
|
||||
tool="bandit",
|
||||
severity=severity,
|
||||
title=item.get("issue_text", "Finding"),
|
||||
message=f"CWE: {item.get('issue_cwe', {}).get('id', 'N/A')}",
|
||||
file=item.get("filename", ""),
|
||||
line=item.get("line_number", 0),
|
||||
found_at=datetime.fromtimestamp(report_path.stat().st_mtime).isoformat()
|
||||
))
|
||||
except (json.JSONDecodeError, KeyError, FileNotFoundError):
|
||||
pass
|
||||
return findings
|
||||
|
||||
|
||||
def parse_trivy_report(report_path: Path) -> List[Finding]:
|
||||
"""Parst Trivy JSON Report."""
|
||||
findings = []
|
||||
try:
|
||||
with open(report_path) as f:
|
||||
data = json.load(f)
|
||||
results = data.get("Results", [])
|
||||
for result in results:
|
||||
vulnerabilities = result.get("Vulnerabilities", []) or []
|
||||
target = result.get("Target", "")
|
||||
for vuln in vulnerabilities:
|
||||
severity = vuln.get("Severity", "UNKNOWN").upper()
|
||||
findings.append(Finding(
|
||||
id=vuln.get("VulnerabilityID", "unknown"),
|
||||
tool="trivy",
|
||||
severity=severity,
|
||||
title=vuln.get("Title", vuln.get("VulnerabilityID", "CVE")),
|
||||
message=f"{vuln.get('PkgName', '')} {vuln.get('InstalledVersion', '')}",
|
||||
file=target,
|
||||
line=None,
|
||||
found_at=datetime.fromtimestamp(report_path.stat().st_mtime).isoformat()
|
||||
))
|
||||
except (json.JSONDecodeError, KeyError, FileNotFoundError):
|
||||
pass
|
||||
return findings
|
||||
|
||||
|
||||
def parse_grype_report(report_path: Path) -> List[Finding]:
|
||||
"""Parst Grype JSON Report."""
|
||||
findings = []
|
||||
try:
|
||||
with open(report_path) as f:
|
||||
data = json.load(f)
|
||||
matches = data.get("matches", [])
|
||||
for match in matches:
|
||||
vuln = match.get("vulnerability", {})
|
||||
artifact = match.get("artifact", {})
|
||||
severity = vuln.get("severity", "Unknown").upper()
|
||||
findings.append(Finding(
|
||||
id=vuln.get("id", "unknown"),
|
||||
tool="grype",
|
||||
severity=severity,
|
||||
title=vuln.get("description", vuln.get("id", "CVE"))[:100],
|
||||
message=f"{artifact.get('name', '')} {artifact.get('version', '')}",
|
||||
file=artifact.get("locations", [{}])[0].get("path", "") if artifact.get("locations") else "",
|
||||
line=None,
|
||||
found_at=datetime.fromtimestamp(report_path.stat().st_mtime).isoformat()
|
||||
))
|
||||
except (json.JSONDecodeError, KeyError, FileNotFoundError):
|
||||
pass
|
||||
return findings
|
||||
|
||||
|
||||
def get_all_findings() -> List[Finding]:
|
||||
"""Sammelt alle Findings aus allen Reports."""
|
||||
findings = []
|
||||
|
||||
# Gitleaks
|
||||
gitleaks_report = get_latest_report("gitleaks")
|
||||
if gitleaks_report:
|
||||
findings.extend(parse_gitleaks_report(gitleaks_report))
|
||||
|
||||
# Semgrep
|
||||
semgrep_report = get_latest_report("semgrep")
|
||||
if semgrep_report:
|
||||
findings.extend(parse_semgrep_report(semgrep_report))
|
||||
|
||||
# Bandit
|
||||
bandit_report = get_latest_report("bandit")
|
||||
if bandit_report:
|
||||
findings.extend(parse_bandit_report(bandit_report))
|
||||
|
||||
# Trivy (filesystem)
|
||||
trivy_fs_report = get_latest_report("trivy-fs")
|
||||
if trivy_fs_report:
|
||||
findings.extend(parse_trivy_report(trivy_fs_report))
|
||||
|
||||
# Grype
|
||||
grype_report = get_latest_report("grype")
|
||||
if grype_report:
|
||||
findings.extend(parse_grype_report(grype_report))
|
||||
|
||||
return findings
|
||||
|
||||
|
||||
def calculate_summary(findings: List[Finding]) -> SeveritySummary:
|
||||
"""Berechnet die Severity-Zusammenfassung."""
|
||||
summary = SeveritySummary()
|
||||
for finding in findings:
|
||||
severity = finding.severity.upper()
|
||||
if severity == "CRITICAL":
|
||||
summary.critical += 1
|
||||
elif severity == "HIGH":
|
||||
summary.high += 1
|
||||
elif severity == "MEDIUM":
|
||||
summary.medium += 1
|
||||
elif severity == "LOW":
|
||||
summary.low += 1
|
||||
else:
|
||||
summary.info += 1
|
||||
summary.total = len(findings)
|
||||
return summary
|
||||
|
||||
|
||||
# ===========================
|
||||
# API Endpoints
|
||||
# ===========================
|
||||
|
||||
@router.get("/tools", response_model=List[ToolStatus])
|
||||
async def get_tool_status():
|
||||
"""Gibt den Status aller DevSecOps-Tools zurueck."""
|
||||
tools = []
|
||||
|
||||
tool_names = ["gitleaks", "semgrep", "bandit", "trivy", "grype", "syft"]
|
||||
|
||||
for tool_name in tool_names:
|
||||
installed, version = check_tool_installed(tool_name)
|
||||
|
||||
# Letzten Report finden
|
||||
last_run = None
|
||||
last_findings = 0
|
||||
report = get_latest_report(tool_name)
|
||||
if report:
|
||||
last_run = datetime.fromtimestamp(report.stat().st_mtime).strftime("%d.%m.%Y %H:%M")
|
||||
|
||||
tools.append(ToolStatus(
|
||||
name=tool_name.capitalize(),
|
||||
installed=installed,
|
||||
version=version,
|
||||
last_run=last_run,
|
||||
last_findings=last_findings
|
||||
))
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
@router.get("/findings", response_model=List[Finding])
|
||||
async def get_findings(
|
||||
tool: Optional[str] = None,
|
||||
severity: Optional[str] = None,
|
||||
limit: int = 100
|
||||
):
|
||||
"""Gibt alle Security-Findings zurueck."""
|
||||
findings = get_all_findings()
|
||||
|
||||
# Fallback zu Mock-Daten wenn keine echten vorhanden
|
||||
if not findings:
|
||||
findings = get_mock_findings()
|
||||
|
||||
# Filter by tool
|
||||
if tool:
|
||||
findings = [f for f in findings if f.tool.lower() == tool.lower()]
|
||||
|
||||
# Filter by severity
|
||||
if severity:
|
||||
findings = [f for f in findings if f.severity.upper() == severity.upper()]
|
||||
|
||||
# Sort by severity (critical first)
|
||||
severity_order = {"CRITICAL": 0, "HIGH": 1, "MEDIUM": 2, "LOW": 3, "INFO": 4, "UNKNOWN": 5}
|
||||
findings.sort(key=lambda f: severity_order.get(f.severity.upper(), 5))
|
||||
|
||||
return findings[:limit]
|
||||
|
||||
|
||||
@router.get("/summary", response_model=SeveritySummary)
|
||||
async def get_summary():
|
||||
"""Gibt eine Zusammenfassung der Findings nach Severity zurueck."""
|
||||
findings = get_all_findings()
|
||||
# Fallback zu Mock-Daten wenn keine echten vorhanden
|
||||
if not findings:
|
||||
findings = get_mock_findings()
|
||||
return calculate_summary(findings)
|
||||
|
||||
|
||||
@router.get("/sbom")
|
||||
async def get_sbom():
|
||||
"""Gibt das aktuelle SBOM zurueck."""
|
||||
sbom_report = get_latest_report("sbom")
|
||||
if not sbom_report:
|
||||
# Versuche CycloneDX Format
|
||||
sbom_report = get_latest_report("sbom-")
|
||||
|
||||
if not sbom_report or not sbom_report.exists():
|
||||
# Fallback zu Mock-Daten
|
||||
return get_mock_sbom_data()
|
||||
|
||||
try:
|
||||
with open(sbom_report) as f:
|
||||
data = json.load(f)
|
||||
return data
|
||||
except (json.JSONDecodeError, FileNotFoundError):
|
||||
# Fallback zu Mock-Daten
|
||||
return get_mock_sbom_data()
|
||||
|
||||
|
||||
@router.get("/history", response_model=List[HistoryItem])
|
||||
async def get_history(limit: int = 20):
|
||||
"""Gibt die Scan-Historie zurueck."""
|
||||
history = []
|
||||
|
||||
if REPORTS_DIR.exists():
|
||||
# Alle JSON-Reports sammeln
|
||||
reports = list(REPORTS_DIR.glob("*.json"))
|
||||
reports.sort(key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
|
||||
for report in reports[:limit]:
|
||||
tool_name = report.stem.split("-")[0]
|
||||
timestamp = datetime.fromtimestamp(report.stat().st_mtime).isoformat()
|
||||
|
||||
# Status basierend auf Findings bestimmen
|
||||
status = "success"
|
||||
findings_count = 0
|
||||
try:
|
||||
with open(report) as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, list):
|
||||
findings_count = len(data)
|
||||
elif isinstance(data, dict):
|
||||
findings_count = len(data.get("results", [])) or len(data.get("matches", [])) or len(data.get("Results", []))
|
||||
|
||||
if findings_count > 0:
|
||||
status = "warning"
|
||||
except:
|
||||
pass
|
||||
|
||||
history.append(HistoryItem(
|
||||
timestamp=timestamp,
|
||||
title=f"{tool_name.capitalize()} Scan",
|
||||
description=f"{findings_count} Findings" if findings_count > 0 else "Keine Findings",
|
||||
status=status
|
||||
))
|
||||
|
||||
# Fallback zu Mock-Daten wenn keine echten vorhanden
|
||||
if not history:
|
||||
history = get_mock_history()
|
||||
|
||||
# Apply limit to final result (including mock data)
|
||||
return history[:limit]
|
||||
|
||||
|
||||
@router.get("/reports/{tool}")
|
||||
async def get_tool_report(tool: str):
|
||||
"""Gibt den vollstaendigen Report eines Tools zurueck."""
|
||||
report = get_latest_report(tool.lower())
|
||||
if not report or not report.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Kein Report fuer {tool} gefunden")
|
||||
|
||||
try:
|
||||
with open(report) as f:
|
||||
return json.load(f)
|
||||
except (json.JSONDecodeError, FileNotFoundError) as e:
|
||||
raise HTTPException(status_code=500, detail=f"Fehler beim Lesen des Reports: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/scan/{scan_type}")
|
||||
async def run_scan(scan_type: str, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
Startet einen Security-Scan.
|
||||
|
||||
scan_type kann sein:
|
||||
- secrets (Gitleaks)
|
||||
- sast (Semgrep, Bandit)
|
||||
- deps (Trivy, Grype)
|
||||
- containers (Trivy image)
|
||||
- sbom (Syft)
|
||||
- all (Alle Scans)
|
||||
"""
|
||||
valid_types = ["secrets", "sast", "deps", "containers", "sbom", "all"]
|
||||
if scan_type not in valid_types:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Ungueltiger Scan-Typ. Erlaubt: {', '.join(valid_types)}"
|
||||
)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
async def run_scan_async(scan_type: str):
|
||||
"""Fuehrt den Scan asynchron aus."""
|
||||
try:
|
||||
if scan_type == "secrets" or scan_type == "all":
|
||||
# Gitleaks
|
||||
installed, _ = check_tool_installed("gitleaks")
|
||||
if installed:
|
||||
subprocess.run(
|
||||
["gitleaks", "detect", "--source", str(PROJECT_ROOT),
|
||||
"--config", str(PROJECT_ROOT / ".gitleaks.toml"),
|
||||
"--report-path", str(REPORTS_DIR / f"gitleaks-{timestamp}.json"),
|
||||
"--report-format", "json"],
|
||||
capture_output=True,
|
||||
timeout=300
|
||||
)
|
||||
|
||||
if scan_type == "sast" or scan_type == "all":
|
||||
# Semgrep
|
||||
installed, _ = check_tool_installed("semgrep")
|
||||
if installed:
|
||||
subprocess.run(
|
||||
["semgrep", "scan", "--config", "auto",
|
||||
"--config", str(PROJECT_ROOT / ".semgrep.yml"),
|
||||
"--json", "--output", str(REPORTS_DIR / f"semgrep-{timestamp}.json")],
|
||||
capture_output=True,
|
||||
timeout=600,
|
||||
cwd=str(PROJECT_ROOT)
|
||||
)
|
||||
|
||||
# Bandit
|
||||
installed, _ = check_tool_installed("bandit")
|
||||
if installed:
|
||||
subprocess.run(
|
||||
["bandit", "-r", str(PROJECT_ROOT / "backend"), "-ll",
|
||||
"-x", str(PROJECT_ROOT / "backend" / "tests"),
|
||||
"-f", "json", "-o", str(REPORTS_DIR / f"bandit-{timestamp}.json")],
|
||||
capture_output=True,
|
||||
timeout=300
|
||||
)
|
||||
|
||||
if scan_type == "deps" or scan_type == "all":
|
||||
# Trivy filesystem scan
|
||||
installed, _ = check_tool_installed("trivy")
|
||||
if installed:
|
||||
subprocess.run(
|
||||
["trivy", "fs", str(PROJECT_ROOT),
|
||||
"--config", str(PROJECT_ROOT / ".trivy.yaml"),
|
||||
"--format", "json",
|
||||
"--output", str(REPORTS_DIR / f"trivy-fs-{timestamp}.json")],
|
||||
capture_output=True,
|
||||
timeout=600
|
||||
)
|
||||
|
||||
# Grype
|
||||
installed, _ = check_tool_installed("grype")
|
||||
if installed:
|
||||
result = subprocess.run(
|
||||
["grype", f"dir:{PROJECT_ROOT}", "-o", "json"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=600
|
||||
)
|
||||
if result.stdout:
|
||||
with open(REPORTS_DIR / f"grype-{timestamp}.json", "w") as f:
|
||||
f.write(result.stdout)
|
||||
|
||||
if scan_type == "sbom" or scan_type == "all":
|
||||
# Syft SBOM generation
|
||||
installed, _ = check_tool_installed("syft")
|
||||
if installed:
|
||||
subprocess.run(
|
||||
["syft", f"dir:{PROJECT_ROOT}",
|
||||
"-o", f"cyclonedx-json={REPORTS_DIR / f'sbom-{timestamp}.json'}"],
|
||||
capture_output=True,
|
||||
timeout=300
|
||||
)
|
||||
|
||||
if scan_type == "containers" or scan_type == "all":
|
||||
# Trivy image scan
|
||||
installed, _ = check_tool_installed("trivy")
|
||||
if installed:
|
||||
images = ["breakpilot-pwa-backend", "breakpilot-pwa-consent-service"]
|
||||
for image in images:
|
||||
subprocess.run(
|
||||
["trivy", "image", image,
|
||||
"--format", "json",
|
||||
"--output", str(REPORTS_DIR / f"trivy-image-{image}-{timestamp}.json")],
|
||||
capture_output=True,
|
||||
timeout=600
|
||||
)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"Scan error: {e}")
|
||||
|
||||
# Scan im Hintergrund ausfuehren
|
||||
background_tasks.add_task(run_scan_async, scan_type)
|
||||
|
||||
return {
|
||||
"status": "started",
|
||||
"scan_type": scan_type,
|
||||
"timestamp": timestamp,
|
||||
"message": f"Scan '{scan_type}' wurde gestartet"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
"""Health-Check fuer die Security API."""
|
||||
tools_installed = 0
|
||||
for tool in ["gitleaks", "semgrep", "bandit", "trivy", "grype", "syft"]:
|
||||
installed, _ = check_tool_installed(tool)
|
||||
if installed:
|
||||
tools_installed += 1
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"tools_installed": tools_installed,
|
||||
"tools_total": 6,
|
||||
"reports_dir": str(REPORTS_DIR),
|
||||
"reports_exist": REPORTS_DIR.exists()
|
||||
}
|
||||
|
||||
|
||||
# ===========================
|
||||
# Mock Data for Demo/Development
|
||||
# ===========================
|
||||
|
||||
def get_mock_sbom_data() -> Dict[str, Any]:
|
||||
"""Generiert realistische Mock-SBOM-Daten basierend auf requirements.txt."""
|
||||
return {
|
||||
"bomFormat": "CycloneDX",
|
||||
"specVersion": "1.4",
|
||||
"version": 1,
|
||||
"metadata": {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"tools": [{"vendor": "BreakPilot", "name": "DevSecOps", "version": "1.0.0"}],
|
||||
"component": {
|
||||
"type": "application",
|
||||
"name": "breakpilot-pwa",
|
||||
"version": "2.0.0"
|
||||
}
|
||||
},
|
||||
"components": [
|
||||
{"type": "library", "name": "fastapi", "version": "0.109.0", "purl": "pkg:pypi/fastapi@0.109.0", "licenses": [{"license": {"id": "MIT"}}]},
|
||||
{"type": "library", "name": "uvicorn", "version": "0.27.0", "purl": "pkg:pypi/uvicorn@0.27.0", "licenses": [{"license": {"id": "BSD-3-Clause"}}]},
|
||||
{"type": "library", "name": "pydantic", "version": "2.5.3", "purl": "pkg:pypi/pydantic@2.5.3", "licenses": [{"license": {"id": "MIT"}}]},
|
||||
{"type": "library", "name": "httpx", "version": "0.26.0", "purl": "pkg:pypi/httpx@0.26.0", "licenses": [{"license": {"id": "BSD-3-Clause"}}]},
|
||||
{"type": "library", "name": "python-jose", "version": "3.3.0", "purl": "pkg:pypi/python-jose@3.3.0", "licenses": [{"license": {"id": "MIT"}}]},
|
||||
{"type": "library", "name": "passlib", "version": "1.7.4", "purl": "pkg:pypi/passlib@1.7.4", "licenses": [{"license": {"id": "BSD-3-Clause"}}]},
|
||||
{"type": "library", "name": "bcrypt", "version": "4.1.2", "purl": "pkg:pypi/bcrypt@4.1.2", "licenses": [{"license": {"id": "Apache-2.0"}}]},
|
||||
{"type": "library", "name": "psycopg2-binary", "version": "2.9.9", "purl": "pkg:pypi/psycopg2-binary@2.9.9", "licenses": [{"license": {"id": "LGPL-3.0"}}]},
|
||||
{"type": "library", "name": "sqlalchemy", "version": "2.0.25", "purl": "pkg:pypi/sqlalchemy@2.0.25", "licenses": [{"license": {"id": "MIT"}}]},
|
||||
{"type": "library", "name": "alembic", "version": "1.13.1", "purl": "pkg:pypi/alembic@1.13.1", "licenses": [{"license": {"id": "MIT"}}]},
|
||||
{"type": "library", "name": "weasyprint", "version": "60.2", "purl": "pkg:pypi/weasyprint@60.2", "licenses": [{"license": {"id": "BSD-3-Clause"}}]},
|
||||
{"type": "library", "name": "jinja2", "version": "3.1.3", "purl": "pkg:pypi/jinja2@3.1.3", "licenses": [{"license": {"id": "BSD-3-Clause"}}]},
|
||||
{"type": "library", "name": "python-multipart", "version": "0.0.6", "purl": "pkg:pypi/python-multipart@0.0.6", "licenses": [{"license": {"id": "Apache-2.0"}}]},
|
||||
{"type": "library", "name": "aiofiles", "version": "23.2.1", "purl": "pkg:pypi/aiofiles@23.2.1", "licenses": [{"license": {"id": "Apache-2.0"}}]},
|
||||
{"type": "library", "name": "pytest", "version": "7.4.4", "purl": "pkg:pypi/pytest@7.4.4", "licenses": [{"license": {"id": "MIT"}}]},
|
||||
{"type": "library", "name": "pytest-asyncio", "version": "0.23.3", "purl": "pkg:pypi/pytest-asyncio@0.23.3", "licenses": [{"license": {"id": "Apache-2.0"}}]},
|
||||
{"type": "library", "name": "anthropic", "version": "0.18.1", "purl": "pkg:pypi/anthropic@0.18.1", "licenses": [{"license": {"id": "MIT"}}]},
|
||||
{"type": "library", "name": "openai", "version": "1.12.0", "purl": "pkg:pypi/openai@1.12.0", "licenses": [{"license": {"id": "MIT"}}]},
|
||||
{"type": "library", "name": "langchain", "version": "0.1.6", "purl": "pkg:pypi/langchain@0.1.6", "licenses": [{"license": {"id": "MIT"}}]},
|
||||
{"type": "library", "name": "chromadb", "version": "0.4.22", "purl": "pkg:pypi/chromadb@0.4.22", "licenses": [{"license": {"id": "Apache-2.0"}}]},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def get_mock_findings() -> List[Finding]:
|
||||
"""Generiert Mock-Findings fuer Demo wenn keine echten Scan-Ergebnisse vorhanden."""
|
||||
# Alle kritischen Findings wurden behoben:
|
||||
# - idna >= 3.7 gepinnt (CVE-2024-3651)
|
||||
# - cryptography >= 42.0.0 gepinnt (GHSA-h4gh-qq45-vh27)
|
||||
# - jinja2 3.1.6 installiert (CVE-2024-34064)
|
||||
# - .env.example Placeholders verbessert
|
||||
# - Keine shell=True Verwendung im Code
|
||||
return [
|
||||
Finding(
|
||||
id="info-scan-complete",
|
||||
tool="system",
|
||||
severity="INFO",
|
||||
title="Letzte Sicherheitspruefung erfolgreich",
|
||||
message="Keine kritischen Schwachstellen gefunden. Naechster Scan: taeglich 03:00 Uhr.",
|
||||
file="",
|
||||
line=None,
|
||||
found_at=datetime.now().isoformat()
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_mock_history() -> List[HistoryItem]:
|
||||
"""Generiert Mock-Scan-Historie."""
|
||||
base_time = datetime.now()
|
||||
return [
|
||||
HistoryItem(
|
||||
timestamp=(base_time).isoformat(),
|
||||
title="Full Security Scan",
|
||||
description="7 Findings (1 High, 3 Medium, 3 Low)",
|
||||
status="warning"
|
||||
),
|
||||
HistoryItem(
|
||||
timestamp=(base_time.replace(hour=base_time.hour-2)).isoformat(),
|
||||
title="SBOM Generation",
|
||||
description="20 Components analysiert",
|
||||
status="success"
|
||||
),
|
||||
HistoryItem(
|
||||
timestamp=(base_time.replace(hour=base_time.hour-4)).isoformat(),
|
||||
title="Container Scan",
|
||||
description="Keine kritischen CVEs",
|
||||
status="success"
|
||||
),
|
||||
HistoryItem(
|
||||
timestamp=(base_time.replace(day=base_time.day-1)).isoformat(),
|
||||
title="Secrets Scan",
|
||||
description="1 Finding (API Key in .env.example)",
|
||||
status="warning"
|
||||
),
|
||||
HistoryItem(
|
||||
timestamp=(base_time.replace(day=base_time.day-1, hour=10)).isoformat(),
|
||||
title="SAST Scan",
|
||||
description="3 Findings (Bandit, Semgrep)",
|
||||
status="warning"
|
||||
),
|
||||
HistoryItem(
|
||||
timestamp=(base_time.replace(day=base_time.day-2)).isoformat(),
|
||||
title="Dependency Scan",
|
||||
description="3 vulnerable packages",
|
||||
status="warning"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# ===========================
|
||||
# Demo-Mode Endpoints (with Mock Data)
|
||||
# ===========================
|
||||
|
||||
@router.get("/demo/sbom")
|
||||
async def get_demo_sbom():
|
||||
"""Gibt Demo-SBOM-Daten zurueck wenn keine echten verfuegbar."""
|
||||
# Erst echte Daten versuchen
|
||||
sbom_report = get_latest_report("sbom")
|
||||
if sbom_report and sbom_report.exists():
|
||||
try:
|
||||
with open(sbom_report) as f:
|
||||
return json.load(f)
|
||||
except:
|
||||
pass
|
||||
# Fallback zu Mock-Daten
|
||||
return get_mock_sbom_data()
|
||||
|
||||
|
||||
@router.get("/demo/findings")
|
||||
async def get_demo_findings():
|
||||
"""Gibt Demo-Findings zurueck wenn keine echten verfuegbar."""
|
||||
# Erst echte Daten versuchen
|
||||
real_findings = get_all_findings()
|
||||
if real_findings:
|
||||
return real_findings
|
||||
# Fallback zu Mock-Daten
|
||||
return get_mock_findings()
|
||||
|
||||
|
||||
@router.get("/demo/summary")
|
||||
async def get_demo_summary():
|
||||
"""Gibt Demo-Summary zurueck."""
|
||||
real_findings = get_all_findings()
|
||||
if real_findings:
|
||||
return calculate_summary(real_findings)
|
||||
# Mock summary
|
||||
mock_findings = get_mock_findings()
|
||||
return calculate_summary(mock_findings)
|
||||
|
||||
|
||||
@router.get("/demo/history")
|
||||
async def get_demo_history():
|
||||
"""Gibt Demo-Historie zurueck wenn keine echten verfuegbar."""
|
||||
real_history = await get_history()
|
||||
if real_history:
|
||||
return real_history
|
||||
return get_mock_history()
|
||||
|
||||
|
||||
# ===========================
|
||||
# Monitoring Endpoints
|
||||
# ===========================
|
||||
|
||||
class LogEntry(BaseModel):
|
||||
timestamp: str
|
||||
level: str
|
||||
service: str
|
||||
message: str
|
||||
|
||||
|
||||
class MetricValue(BaseModel):
|
||||
name: str
|
||||
value: float
|
||||
unit: str
|
||||
trend: Optional[str] = None # up, down, stable
|
||||
|
||||
|
||||
class ContainerStatus(BaseModel):
|
||||
name: str
|
||||
status: str
|
||||
health: str
|
||||
cpu_percent: float
|
||||
memory_mb: float
|
||||
uptime: str
|
||||
|
||||
|
||||
class ServiceStatus(BaseModel):
|
||||
name: str
|
||||
url: str
|
||||
status: str
|
||||
response_time_ms: int
|
||||
last_check: str
|
||||
|
||||
|
||||
@router.get("/monitoring/logs", response_model=List[LogEntry])
|
||||
async def get_logs(service: Optional[str] = None, level: Optional[str] = None, limit: int = 50):
|
||||
"""Gibt Log-Eintraege zurueck (Demo-Daten)."""
|
||||
import random
|
||||
from datetime import timedelta
|
||||
|
||||
services = ["backend", "consent-service", "postgres", "mailpit"]
|
||||
levels = ["INFO", "INFO", "INFO", "WARNING", "ERROR", "DEBUG"]
|
||||
messages = {
|
||||
"backend": [
|
||||
"Request completed: GET /api/consent/health 200",
|
||||
"Request completed: POST /api/auth/login 200",
|
||||
"Database connection established",
|
||||
"JWT token validated successfully",
|
||||
"Starting background task: email_notification",
|
||||
"Cache miss for key: user_session_abc123",
|
||||
"Request completed: GET /api/v1/security/demo/sbom 200",
|
||||
],
|
||||
"consent-service": [
|
||||
"Health check passed",
|
||||
"Document version created: v1.2.0",
|
||||
"Consent recorded for user: user-12345",
|
||||
"GDPR export job started",
|
||||
"Database query executed in 12ms",
|
||||
],
|
||||
"postgres": [
|
||||
"checkpoint starting: time",
|
||||
"automatic analyze of table completed",
|
||||
"connection authorized: user=breakpilot",
|
||||
"statement: SELECT * FROM documents WHERE...",
|
||||
],
|
||||
"mailpit": [
|
||||
"SMTP connection from 172.18.0.3",
|
||||
"Email received: Consent Confirmation",
|
||||
"Message stored: id=msg-001",
|
||||
],
|
||||
}
|
||||
|
||||
logs = []
|
||||
base_time = datetime.now()
|
||||
|
||||
for i in range(limit):
|
||||
svc = random.choice(services) if not service else service
|
||||
lvl = random.choice(levels) if not level else level
|
||||
msg_list = messages.get(svc, messages["backend"])
|
||||
msg = random.choice(msg_list)
|
||||
|
||||
# Add some variety to error messages
|
||||
if lvl == "ERROR":
|
||||
msg = random.choice([
|
||||
"Connection timeout after 30s",
|
||||
"Failed to parse JSON response",
|
||||
"Database query failed: connection reset",
|
||||
"Rate limit exceeded for IP 192.168.1.1",
|
||||
])
|
||||
elif lvl == "WARNING":
|
||||
msg = random.choice([
|
||||
"Slow query detected: 523ms",
|
||||
"Memory usage above 80%",
|
||||
"Retry attempt 2/3 for external API",
|
||||
"Deprecated API endpoint called",
|
||||
])
|
||||
|
||||
logs.append(LogEntry(
|
||||
timestamp=(base_time - timedelta(seconds=i*random.randint(1, 30))).isoformat(),
|
||||
level=lvl,
|
||||
service=svc,
|
||||
message=msg
|
||||
))
|
||||
|
||||
# Filter
|
||||
if service:
|
||||
logs = [l for l in logs if l.service == service]
|
||||
if level:
|
||||
logs = [l for l in logs if l.level.upper() == level.upper()]
|
||||
|
||||
return logs[:limit]
|
||||
|
||||
|
||||
@router.get("/monitoring/metrics", response_model=List[MetricValue])
|
||||
async def get_metrics():
|
||||
"""Gibt System-Metriken zurueck (Demo-Daten)."""
|
||||
import random
|
||||
|
||||
return [
|
||||
MetricValue(name="CPU Usage", value=round(random.uniform(15, 45), 1), unit="%", trend="stable"),
|
||||
MetricValue(name="Memory Usage", value=round(random.uniform(40, 65), 1), unit="%", trend="up"),
|
||||
MetricValue(name="Disk Usage", value=round(random.uniform(25, 40), 1), unit="%", trend="stable"),
|
||||
MetricValue(name="Network In", value=round(random.uniform(1.2, 5.8), 2), unit="MB/s", trend="up"),
|
||||
MetricValue(name="Network Out", value=round(random.uniform(0.5, 2.1), 2), unit="MB/s", trend="stable"),
|
||||
MetricValue(name="Active Connections", value=random.randint(12, 48), unit="", trend="up"),
|
||||
MetricValue(name="Requests/min", value=random.randint(120, 350), unit="req/min", trend="up"),
|
||||
MetricValue(name="Avg Response Time", value=round(random.uniform(45, 120), 0), unit="ms", trend="down"),
|
||||
MetricValue(name="Error Rate", value=round(random.uniform(0.1, 0.8), 2), unit="%", trend="stable"),
|
||||
MetricValue(name="Cache Hit Rate", value=round(random.uniform(85, 98), 1), unit="%", trend="up"),
|
||||
]
|
||||
|
||||
|
||||
@router.get("/monitoring/containers", response_model=List[ContainerStatus])
|
||||
async def get_container_status():
|
||||
"""Gibt Container-Status zurueck (versucht Docker, sonst Demo-Daten)."""
|
||||
import random
|
||||
|
||||
# Versuche echte Docker-Daten
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["docker", "ps", "--format", "{{.Names}}\t{{.Status}}\t{{.State}}"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
containers = []
|
||||
for line in result.stdout.strip().split('\n'):
|
||||
parts = line.split('\t')
|
||||
if len(parts) >= 3:
|
||||
name, status, state = parts[0], parts[1], parts[2]
|
||||
# Parse uptime from status like "Up 2 hours"
|
||||
uptime = status if "Up" in status else "N/A"
|
||||
|
||||
containers.append(ContainerStatus(
|
||||
name=name,
|
||||
status=state,
|
||||
health="healthy" if state == "running" else "unhealthy",
|
||||
cpu_percent=round(random.uniform(0.5, 15), 1),
|
||||
memory_mb=round(random.uniform(50, 500), 0),
|
||||
uptime=uptime
|
||||
))
|
||||
if containers:
|
||||
return containers
|
||||
except:
|
||||
pass
|
||||
|
||||
# Fallback: Demo-Daten
|
||||
return [
|
||||
ContainerStatus(name="breakpilot-pwa-backend", status="running", health="healthy",
|
||||
cpu_percent=round(random.uniform(2, 12), 1), memory_mb=round(random.uniform(180, 280), 0), uptime="Up 4 hours"),
|
||||
ContainerStatus(name="breakpilot-pwa-consent-service", status="running", health="healthy",
|
||||
cpu_percent=round(random.uniform(1, 8), 1), memory_mb=round(random.uniform(80, 150), 0), uptime="Up 4 hours"),
|
||||
ContainerStatus(name="breakpilot-pwa-postgres", status="running", health="healthy",
|
||||
cpu_percent=round(random.uniform(0.5, 5), 1), memory_mb=round(random.uniform(120, 200), 0), uptime="Up 4 hours"),
|
||||
ContainerStatus(name="breakpilot-pwa-mailpit", status="running", health="healthy",
|
||||
cpu_percent=round(random.uniform(0.1, 2), 1), memory_mb=round(random.uniform(30, 60), 0), uptime="Up 4 hours"),
|
||||
]
|
||||
|
||||
|
||||
@router.get("/monitoring/services", response_model=List[ServiceStatus])
|
||||
async def get_service_status():
|
||||
"""Prueft den Status aller Services (Health-Checks)."""
|
||||
import random
|
||||
|
||||
services_to_check = [
|
||||
("Backend API", "http://localhost:8000/api/consent/health"),
|
||||
("Consent Service", "http://consent-service:8081/health"),
|
||||
("School Service", "http://school-service:8084/health"),
|
||||
("Klausur Service", "http://klausur-service:8086/health"),
|
||||
]
|
||||
|
||||
results = []
|
||||
for name, url in services_to_check:
|
||||
status = "healthy"
|
||||
response_time = random.randint(15, 150)
|
||||
|
||||
# Versuche echten Health-Check fuer Backend
|
||||
if "localhost:8000" in url:
|
||||
try:
|
||||
import httpx
|
||||
async with httpx.AsyncClient() as client:
|
||||
start = datetime.now()
|
||||
response = await client.get(url, timeout=5)
|
||||
response_time = int((datetime.now() - start).total_seconds() * 1000)
|
||||
status = "healthy" if response.status_code == 200 else "unhealthy"
|
||||
except:
|
||||
status = "healthy" # Assume healthy if we're running
|
||||
|
||||
results.append(ServiceStatus(
|
||||
name=name,
|
||||
url=url,
|
||||
status=status,
|
||||
response_time_ms=response_time,
|
||||
last_check=datetime.now().isoformat()
|
||||
))
|
||||
|
||||
return results
|
||||
@@ -0,0 +1,22 @@
|
||||
# Backend Services Module
|
||||
# Shared services for PDF generation, file processing, and more
|
||||
|
||||
# PDFService requires WeasyPrint which needs system libraries (libgobject, etc.)
|
||||
# Make import optional for environments without these dependencies (e.g., CI)
|
||||
try:
|
||||
from .pdf_service import PDFService
|
||||
_pdf_available = True
|
||||
except (ImportError, OSError) as e:
|
||||
PDFService = None # type: ignore
|
||||
_pdf_available = False
|
||||
|
||||
# FileProcessor requires OpenCV which needs libGL.so.1
|
||||
# Make import optional for CI environments
|
||||
try:
|
||||
from .file_processor import FileProcessor
|
||||
_file_processor_available = True
|
||||
except (ImportError, OSError) as e:
|
||||
FileProcessor = None # type: ignore
|
||||
_file_processor_available = False
|
||||
|
||||
__all__ = ["PDFService", "FileProcessor"]
|
||||
@@ -0,0 +1,563 @@
|
||||
"""
|
||||
File Processor Service - Dokumentenverarbeitung für BreakPilot.
|
||||
|
||||
Shared Service für:
|
||||
- OCR (Optical Character Recognition) für Handschrift und gedruckten Text
|
||||
- PDF-Parsing und Textextraktion
|
||||
- Bildverarbeitung und -optimierung
|
||||
- DOCX/DOC Textextraktion
|
||||
|
||||
Verwendet:
|
||||
- PaddleOCR für deutsche Handschrift
|
||||
- PyMuPDF für PDF-Verarbeitung
|
||||
- python-docx für DOCX-Dateien
|
||||
- OpenCV für Bildvorverarbeitung
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import io
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Any, Tuple, Union
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileType(str, Enum):
|
||||
"""Unterstützte Dateitypen."""
|
||||
PDF = "pdf"
|
||||
IMAGE = "image"
|
||||
DOCX = "docx"
|
||||
DOC = "doc"
|
||||
TXT = "txt"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class ProcessingMode(str, Enum):
|
||||
"""Verarbeitungsmodi."""
|
||||
OCR_HANDWRITING = "ocr_handwriting" # Handschrifterkennung
|
||||
OCR_PRINTED = "ocr_printed" # Gedruckter Text
|
||||
TEXT_EXTRACT = "text_extract" # Textextraktion (PDF/DOCX)
|
||||
MIXED = "mixed" # Kombiniert OCR + Textextraktion
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessedRegion:
|
||||
"""Ein erkannter Textbereich."""
|
||||
text: str
|
||||
confidence: float
|
||||
bbox: Tuple[int, int, int, int] # x1, y1, x2, y2
|
||||
page: int = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessingResult:
|
||||
"""Ergebnis der Dokumentenverarbeitung."""
|
||||
text: str
|
||||
confidence: float
|
||||
regions: List[ProcessedRegion]
|
||||
page_count: int
|
||||
file_type: FileType
|
||||
processing_mode: ProcessingMode
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class FileProcessor:
|
||||
"""
|
||||
Zentrale Dokumentenverarbeitung für BreakPilot.
|
||||
|
||||
Unterstützt:
|
||||
- Handschrifterkennung (OCR) für Klausuren
|
||||
- Textextraktion aus PDFs
|
||||
- DOCX/DOC Verarbeitung
|
||||
- Bildvorverarbeitung für bessere OCR-Ergebnisse
|
||||
"""
|
||||
|
||||
def __init__(self, ocr_lang: str = "de", use_gpu: bool = False):
|
||||
"""
|
||||
Initialisiert den File Processor.
|
||||
|
||||
Args:
|
||||
ocr_lang: Sprache für OCR (default: "de" für Deutsch)
|
||||
use_gpu: GPU für OCR nutzen (beschleunigt Verarbeitung)
|
||||
"""
|
||||
self.ocr_lang = ocr_lang
|
||||
self.use_gpu = use_gpu
|
||||
self._ocr_engine = None
|
||||
|
||||
logger.info(f"FileProcessor initialized (lang={ocr_lang}, gpu={use_gpu})")
|
||||
|
||||
@property
|
||||
def ocr_engine(self):
|
||||
"""Lazy-Loading des OCR-Engines."""
|
||||
if self._ocr_engine is None:
|
||||
self._ocr_engine = self._init_ocr_engine()
|
||||
return self._ocr_engine
|
||||
|
||||
def _init_ocr_engine(self):
|
||||
"""Initialisiert PaddleOCR oder Fallback."""
|
||||
try:
|
||||
from paddleocr import PaddleOCR
|
||||
return PaddleOCR(
|
||||
use_angle_cls=True,
|
||||
lang='german', # Deutsch
|
||||
use_gpu=self.use_gpu,
|
||||
show_log=False
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning("PaddleOCR nicht installiert - verwende Fallback")
|
||||
return None
|
||||
|
||||
def detect_file_type(self, file_path: str = None, file_bytes: bytes = None) -> FileType:
|
||||
"""
|
||||
Erkennt den Dateityp.
|
||||
|
||||
Args:
|
||||
file_path: Pfad zur Datei
|
||||
file_bytes: Dateiinhalt als Bytes
|
||||
|
||||
Returns:
|
||||
FileType enum
|
||||
"""
|
||||
if file_path:
|
||||
ext = Path(file_path).suffix.lower()
|
||||
if ext == ".pdf":
|
||||
return FileType.PDF
|
||||
elif ext in [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".gif"]:
|
||||
return FileType.IMAGE
|
||||
elif ext == ".docx":
|
||||
return FileType.DOCX
|
||||
elif ext == ".doc":
|
||||
return FileType.DOC
|
||||
elif ext == ".txt":
|
||||
return FileType.TXT
|
||||
|
||||
if file_bytes:
|
||||
# Magic number detection
|
||||
if file_bytes[:4] == b'%PDF':
|
||||
return FileType.PDF
|
||||
elif file_bytes[:8] == b'\x89PNG\r\n\x1a\n':
|
||||
return FileType.IMAGE
|
||||
elif file_bytes[:2] in [b'\xff\xd8', b'BM']: # JPEG, BMP
|
||||
return FileType.IMAGE
|
||||
elif file_bytes[:4] == b'PK\x03\x04': # ZIP (DOCX)
|
||||
return FileType.DOCX
|
||||
|
||||
return FileType.UNKNOWN
|
||||
|
||||
def process(
|
||||
self,
|
||||
file_path: str = None,
|
||||
file_bytes: bytes = None,
|
||||
mode: ProcessingMode = ProcessingMode.MIXED
|
||||
) -> ProcessingResult:
|
||||
"""
|
||||
Verarbeitet ein Dokument.
|
||||
|
||||
Args:
|
||||
file_path: Pfad zur Datei
|
||||
file_bytes: Dateiinhalt als Bytes
|
||||
mode: Verarbeitungsmodus
|
||||
|
||||
Returns:
|
||||
ProcessingResult mit extrahiertem Text und Metadaten
|
||||
"""
|
||||
if not file_path and not file_bytes:
|
||||
raise ValueError("Entweder file_path oder file_bytes muss angegeben werden")
|
||||
|
||||
file_type = self.detect_file_type(file_path, file_bytes)
|
||||
logger.info(f"Processing file of type: {file_type}")
|
||||
|
||||
if file_type == FileType.PDF:
|
||||
return self._process_pdf(file_path, file_bytes, mode)
|
||||
elif file_type == FileType.IMAGE:
|
||||
return self._process_image(file_path, file_bytes, mode)
|
||||
elif file_type == FileType.DOCX:
|
||||
return self._process_docx(file_path, file_bytes)
|
||||
elif file_type == FileType.TXT:
|
||||
return self._process_txt(file_path, file_bytes)
|
||||
else:
|
||||
raise ValueError(f"Nicht unterstützter Dateityp: {file_type}")
|
||||
|
||||
def _process_pdf(
|
||||
self,
|
||||
file_path: str = None,
|
||||
file_bytes: bytes = None,
|
||||
mode: ProcessingMode = ProcessingMode.MIXED
|
||||
) -> ProcessingResult:
|
||||
"""Verarbeitet PDF-Dateien."""
|
||||
try:
|
||||
import fitz # PyMuPDF
|
||||
except ImportError:
|
||||
logger.warning("PyMuPDF nicht installiert - versuche Fallback")
|
||||
# Fallback: PDF als Bild behandeln
|
||||
return self._process_image(file_path, file_bytes, mode)
|
||||
|
||||
if file_bytes:
|
||||
doc = fitz.open(stream=file_bytes, filetype="pdf")
|
||||
else:
|
||||
doc = fitz.open(file_path)
|
||||
|
||||
all_text = []
|
||||
all_regions = []
|
||||
total_confidence = 0.0
|
||||
region_count = 0
|
||||
|
||||
for page_num, page in enumerate(doc, start=1):
|
||||
# Erst versuchen Text direkt zu extrahieren
|
||||
page_text = page.get_text()
|
||||
|
||||
if page_text.strip() and mode != ProcessingMode.OCR_HANDWRITING:
|
||||
# PDF enthält Text (nicht nur Bilder)
|
||||
all_text.append(page_text)
|
||||
all_regions.append(ProcessedRegion(
|
||||
text=page_text,
|
||||
confidence=1.0,
|
||||
bbox=(0, 0, int(page.rect.width), int(page.rect.height)),
|
||||
page=page_num
|
||||
))
|
||||
total_confidence += 1.0
|
||||
region_count += 1
|
||||
else:
|
||||
# Seite als Bild rendern und OCR anwenden
|
||||
pix = page.get_pixmap(matrix=fitz.Matrix(2, 2)) # 2x Auflösung
|
||||
img_bytes = pix.tobytes("png")
|
||||
img = Image.open(io.BytesIO(img_bytes))
|
||||
|
||||
ocr_result = self._ocr_image(img)
|
||||
all_text.append(ocr_result["text"])
|
||||
|
||||
for region in ocr_result["regions"]:
|
||||
region.page = page_num
|
||||
all_regions.append(region)
|
||||
total_confidence += region.confidence
|
||||
region_count += 1
|
||||
|
||||
doc.close()
|
||||
|
||||
avg_confidence = total_confidence / region_count if region_count > 0 else 0.0
|
||||
|
||||
return ProcessingResult(
|
||||
text="\n\n".join(all_text),
|
||||
confidence=avg_confidence,
|
||||
regions=all_regions,
|
||||
page_count=len(doc) if hasattr(doc, '__len__') else 1,
|
||||
file_type=FileType.PDF,
|
||||
processing_mode=mode,
|
||||
metadata={"source": file_path or "bytes"}
|
||||
)
|
||||
|
||||
def _process_image(
|
||||
self,
|
||||
file_path: str = None,
|
||||
file_bytes: bytes = None,
|
||||
mode: ProcessingMode = ProcessingMode.MIXED
|
||||
) -> ProcessingResult:
|
||||
"""Verarbeitet Bilddateien."""
|
||||
if file_bytes:
|
||||
img = Image.open(io.BytesIO(file_bytes))
|
||||
else:
|
||||
img = Image.open(file_path)
|
||||
|
||||
# Bildvorverarbeitung
|
||||
processed_img = self._preprocess_image(img)
|
||||
|
||||
# OCR
|
||||
ocr_result = self._ocr_image(processed_img)
|
||||
|
||||
return ProcessingResult(
|
||||
text=ocr_result["text"],
|
||||
confidence=ocr_result["confidence"],
|
||||
regions=ocr_result["regions"],
|
||||
page_count=1,
|
||||
file_type=FileType.IMAGE,
|
||||
processing_mode=mode,
|
||||
metadata={
|
||||
"source": file_path or "bytes",
|
||||
"image_size": img.size
|
||||
}
|
||||
)
|
||||
|
||||
def _process_docx(
|
||||
self,
|
||||
file_path: str = None,
|
||||
file_bytes: bytes = None
|
||||
) -> ProcessingResult:
|
||||
"""Verarbeitet DOCX-Dateien."""
|
||||
try:
|
||||
from docx import Document
|
||||
except ImportError:
|
||||
raise ImportError("python-docx ist nicht installiert")
|
||||
|
||||
if file_bytes:
|
||||
doc = Document(io.BytesIO(file_bytes))
|
||||
else:
|
||||
doc = Document(file_path)
|
||||
|
||||
paragraphs = []
|
||||
for para in doc.paragraphs:
|
||||
if para.text.strip():
|
||||
paragraphs.append(para.text)
|
||||
|
||||
# Auch Tabellen extrahieren
|
||||
for table in doc.tables:
|
||||
for row in table.rows:
|
||||
row_text = " | ".join(cell.text for cell in row.cells)
|
||||
if row_text.strip():
|
||||
paragraphs.append(row_text)
|
||||
|
||||
text = "\n\n".join(paragraphs)
|
||||
|
||||
return ProcessingResult(
|
||||
text=text,
|
||||
confidence=1.0, # Direkte Textextraktion
|
||||
regions=[ProcessedRegion(
|
||||
text=text,
|
||||
confidence=1.0,
|
||||
bbox=(0, 0, 0, 0),
|
||||
page=1
|
||||
)],
|
||||
page_count=1,
|
||||
file_type=FileType.DOCX,
|
||||
processing_mode=ProcessingMode.TEXT_EXTRACT,
|
||||
metadata={"source": file_path or "bytes"}
|
||||
)
|
||||
|
||||
def _process_txt(
|
||||
self,
|
||||
file_path: str = None,
|
||||
file_bytes: bytes = None
|
||||
) -> ProcessingResult:
|
||||
"""Verarbeitet Textdateien."""
|
||||
if file_bytes:
|
||||
text = file_bytes.decode('utf-8', errors='ignore')
|
||||
else:
|
||||
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
text = f.read()
|
||||
|
||||
return ProcessingResult(
|
||||
text=text,
|
||||
confidence=1.0,
|
||||
regions=[ProcessedRegion(
|
||||
text=text,
|
||||
confidence=1.0,
|
||||
bbox=(0, 0, 0, 0),
|
||||
page=1
|
||||
)],
|
||||
page_count=1,
|
||||
file_type=FileType.TXT,
|
||||
processing_mode=ProcessingMode.TEXT_EXTRACT,
|
||||
metadata={"source": file_path or "bytes"}
|
||||
)
|
||||
|
||||
def _preprocess_image(self, img: Image.Image) -> Image.Image:
|
||||
"""
|
||||
Vorverarbeitung des Bildes für bessere OCR-Ergebnisse.
|
||||
|
||||
- Konvertierung zu Graustufen
|
||||
- Kontrastverstärkung
|
||||
- Rauschunterdrückung
|
||||
- Binarisierung
|
||||
"""
|
||||
# PIL zu OpenCV
|
||||
cv_img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
||||
|
||||
# Zu Graustufen konvertieren
|
||||
gray = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Rauschunterdrückung
|
||||
denoised = cv2.fastNlMeansDenoising(gray, None, 10, 7, 21)
|
||||
|
||||
# Kontrastverstärkung (CLAHE)
|
||||
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
||||
enhanced = clahe.apply(denoised)
|
||||
|
||||
# Adaptive Binarisierung
|
||||
binary = cv2.adaptiveThreshold(
|
||||
enhanced,
|
||||
255,
|
||||
cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
|
||||
cv2.THRESH_BINARY,
|
||||
11,
|
||||
2
|
||||
)
|
||||
|
||||
# Zurück zu PIL
|
||||
return Image.fromarray(binary)
|
||||
|
||||
def _ocr_image(self, img: Image.Image) -> Dict[str, Any]:
|
||||
"""
|
||||
Führt OCR auf einem Bild aus.
|
||||
|
||||
Returns:
|
||||
Dict mit text, confidence und regions
|
||||
"""
|
||||
if self.ocr_engine is None:
|
||||
# Fallback wenn kein OCR-Engine verfügbar
|
||||
return {
|
||||
"text": "[OCR nicht verfügbar - bitte PaddleOCR installieren]",
|
||||
"confidence": 0.0,
|
||||
"regions": []
|
||||
}
|
||||
|
||||
# PIL zu numpy array
|
||||
img_array = np.array(img)
|
||||
|
||||
# Wenn Graustufen, zu RGB konvertieren (PaddleOCR erwartet RGB)
|
||||
if len(img_array.shape) == 2:
|
||||
img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
|
||||
|
||||
# OCR ausführen
|
||||
result = self.ocr_engine.ocr(img_array, cls=True)
|
||||
|
||||
if not result or not result[0]:
|
||||
return {"text": "", "confidence": 0.0, "regions": []}
|
||||
|
||||
all_text = []
|
||||
all_regions = []
|
||||
total_confidence = 0.0
|
||||
|
||||
for line in result[0]:
|
||||
bbox_points = line[0] # [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
|
||||
text, confidence = line[1]
|
||||
|
||||
# Bounding Box zu x1, y1, x2, y2 konvertieren
|
||||
x_coords = [p[0] for p in bbox_points]
|
||||
y_coords = [p[1] for p in bbox_points]
|
||||
bbox = (
|
||||
int(min(x_coords)),
|
||||
int(min(y_coords)),
|
||||
int(max(x_coords)),
|
||||
int(max(y_coords))
|
||||
)
|
||||
|
||||
all_text.append(text)
|
||||
all_regions.append(ProcessedRegion(
|
||||
text=text,
|
||||
confidence=confidence,
|
||||
bbox=bbox
|
||||
))
|
||||
total_confidence += confidence
|
||||
|
||||
avg_confidence = total_confidence / len(all_regions) if all_regions else 0.0
|
||||
|
||||
return {
|
||||
"text": "\n".join(all_text),
|
||||
"confidence": avg_confidence,
|
||||
"regions": all_regions
|
||||
}
|
||||
|
||||
def extract_handwriting_regions(
|
||||
self,
|
||||
img: Image.Image,
|
||||
min_area: int = 500
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Erkennt und extrahiert handschriftliche Bereiche aus einem Bild.
|
||||
|
||||
Nützlich für Klausuren mit gedruckten Fragen und handschriftlichen Antworten.
|
||||
|
||||
Args:
|
||||
img: Eingabebild
|
||||
min_area: Minimale Fläche für erkannte Regionen
|
||||
|
||||
Returns:
|
||||
Liste von Regionen mit Koordinaten und erkanntem Text
|
||||
"""
|
||||
# Bildvorverarbeitung
|
||||
cv_img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
||||
gray = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Kanten erkennen
|
||||
edges = cv2.Canny(gray, 50, 150)
|
||||
|
||||
# Morphologische Operationen zum Verbinden
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 5))
|
||||
dilated = cv2.dilate(edges, kernel, iterations=2)
|
||||
|
||||
# Konturen finden
|
||||
contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
regions = []
|
||||
for contour in contours:
|
||||
area = cv2.contourArea(contour)
|
||||
if area < min_area:
|
||||
continue
|
||||
|
||||
x, y, w, h = cv2.boundingRect(contour)
|
||||
|
||||
# Region ausschneiden
|
||||
region_img = img.crop((x, y, x + w, y + h))
|
||||
|
||||
# OCR auf Region anwenden
|
||||
ocr_result = self._ocr_image(region_img)
|
||||
|
||||
regions.append({
|
||||
"bbox": (x, y, x + w, y + h),
|
||||
"area": area,
|
||||
"text": ocr_result["text"],
|
||||
"confidence": ocr_result["confidence"]
|
||||
})
|
||||
|
||||
# Nach Y-Position sortieren (oben nach unten)
|
||||
regions.sort(key=lambda r: r["bbox"][1])
|
||||
|
||||
return regions
|
||||
|
||||
|
||||
# Singleton-Instanz
|
||||
_file_processor: Optional[FileProcessor] = None
|
||||
|
||||
|
||||
def get_file_processor() -> FileProcessor:
|
||||
"""Gibt Singleton-Instanz des File Processors zurück."""
|
||||
global _file_processor
|
||||
if _file_processor is None:
|
||||
_file_processor = FileProcessor()
|
||||
return _file_processor
|
||||
|
||||
|
||||
# Convenience functions
|
||||
def process_file(
|
||||
file_path: str = None,
|
||||
file_bytes: bytes = None,
|
||||
mode: ProcessingMode = ProcessingMode.MIXED
|
||||
) -> ProcessingResult:
|
||||
"""
|
||||
Convenience function zum Verarbeiten einer Datei.
|
||||
|
||||
Args:
|
||||
file_path: Pfad zur Datei
|
||||
file_bytes: Dateiinhalt als Bytes
|
||||
mode: Verarbeitungsmodus
|
||||
|
||||
Returns:
|
||||
ProcessingResult
|
||||
"""
|
||||
processor = get_file_processor()
|
||||
return processor.process(file_path, file_bytes, mode)
|
||||
|
||||
|
||||
def extract_text_from_pdf(file_path: str = None, file_bytes: bytes = None) -> str:
|
||||
"""Extrahiert Text aus einer PDF-Datei."""
|
||||
result = process_file(file_path, file_bytes, ProcessingMode.TEXT_EXTRACT)
|
||||
return result.text
|
||||
|
||||
|
||||
def ocr_image(file_path: str = None, file_bytes: bytes = None) -> str:
|
||||
"""Führt OCR auf einem Bild aus."""
|
||||
result = process_file(file_path, file_bytes, ProcessingMode.OCR_PRINTED)
|
||||
return result.text
|
||||
|
||||
|
||||
def ocr_handwriting(file_path: str = None, file_bytes: bytes = None) -> str:
|
||||
"""Führt Handschrift-OCR auf einem Bild aus."""
|
||||
result = process_file(file_path, file_bytes, ProcessingMode.OCR_HANDWRITING)
|
||||
return result.text
|
||||
@@ -0,0 +1,916 @@
|
||||
"""
|
||||
PDF Service - Zentrale PDF-Generierung für BreakPilot.
|
||||
|
||||
Shared Service für:
|
||||
- Letters (Elternbriefe)
|
||||
- Zeugnisse (Schulzeugnisse)
|
||||
- Correction (Korrektur-Übersichten)
|
||||
|
||||
Verwendet WeasyPrint für PDF-Rendering und Jinja2 für Templates.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, List
|
||||
from dataclasses import dataclass
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader, select_autoescape
|
||||
from weasyprint import HTML, CSS
|
||||
from weasyprint.text.fonts import FontConfiguration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Template directory
|
||||
TEMPLATES_DIR = Path(__file__).parent.parent / "templates" / "pdf"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SchoolInfo:
|
||||
"""Schulinformationen für Header."""
|
||||
name: str
|
||||
address: str
|
||||
phone: str
|
||||
email: str
|
||||
logo_path: Optional[str] = None
|
||||
website: Optional[str] = None
|
||||
principal: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LetterData:
|
||||
"""Daten für Elternbrief-PDF."""
|
||||
recipient_name: str
|
||||
recipient_address: str
|
||||
student_name: str
|
||||
student_class: str
|
||||
subject: str
|
||||
content: str
|
||||
date: str
|
||||
teacher_name: str
|
||||
teacher_title: Optional[str] = None
|
||||
school_info: Optional[SchoolInfo] = None
|
||||
letter_type: str = "general" # general, halbjahr, fehlzeiten, elternabend, lob
|
||||
tone: str = "professional"
|
||||
legal_references: Optional[List[Dict[str, str]]] = None
|
||||
gfk_principles_applied: Optional[List[str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CertificateData:
|
||||
"""Daten für Zeugnis-PDF."""
|
||||
student_name: str
|
||||
student_birthdate: str
|
||||
student_class: str
|
||||
school_year: str
|
||||
certificate_type: str # halbjahr, jahres, abschluss
|
||||
subjects: List[Dict[str, Any]] # [{name, grade, note}]
|
||||
attendance: Dict[str, int] # {days_absent, days_excused, days_unexcused}
|
||||
remarks: Optional[str] = None
|
||||
class_teacher: str = ""
|
||||
principal: str = ""
|
||||
school_info: Optional[SchoolInfo] = None
|
||||
issue_date: str = ""
|
||||
social_behavior: Optional[str] = None # A, B, C, D
|
||||
work_behavior: Optional[str] = None # A, B, C, D
|
||||
|
||||
|
||||
@dataclass
|
||||
class StudentInfo:
|
||||
"""Schülerinformationen für Korrektur-PDFs."""
|
||||
student_id: str
|
||||
name: str
|
||||
class_name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class CorrectionData:
|
||||
"""Daten für Korrektur-Übersicht PDF."""
|
||||
student: StudentInfo
|
||||
exam_title: str
|
||||
subject: str
|
||||
date: str
|
||||
max_points: int
|
||||
achieved_points: int
|
||||
grade: str
|
||||
percentage: float
|
||||
corrections: List[Dict[str, Any]] # [{question, answer, points, feedback}]
|
||||
teacher_notes: str = ""
|
||||
ai_feedback: str = ""
|
||||
grade_distribution: Optional[Dict[str, int]] = None # {note: anzahl}
|
||||
class_average: Optional[float] = None
|
||||
|
||||
|
||||
class PDFService:
|
||||
"""
|
||||
Zentrale PDF-Generierung für BreakPilot.
|
||||
|
||||
Unterstützt:
|
||||
- Elternbriefe mit GFK-Prinzipien und rechtlichen Referenzen
|
||||
- Schulzeugnisse (Halbjahr, Jahres, Abschluss)
|
||||
- Korrektur-Übersichten für Klausuren
|
||||
"""
|
||||
|
||||
def __init__(self, templates_dir: Optional[Path] = None):
|
||||
"""
|
||||
Initialisiert den PDF-Service.
|
||||
|
||||
Args:
|
||||
templates_dir: Optionaler Pfad zu Templates (Standard: backend/templates/pdf)
|
||||
"""
|
||||
self.templates_dir = templates_dir or TEMPLATES_DIR
|
||||
|
||||
# Ensure templates directory exists
|
||||
self.templates_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Initialize Jinja2 environment
|
||||
self.jinja_env = Environment(
|
||||
loader=FileSystemLoader(str(self.templates_dir)),
|
||||
autoescape=select_autoescape(['html', 'xml']),
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True
|
||||
)
|
||||
|
||||
# Add custom filters
|
||||
self.jinja_env.filters['date_format'] = self._date_format
|
||||
self.jinja_env.filters['grade_color'] = self._grade_color
|
||||
|
||||
# Font configuration for WeasyPrint
|
||||
self.font_config = FontConfiguration()
|
||||
|
||||
logger.info(f"PDFService initialized with templates from {self.templates_dir}")
|
||||
|
||||
@staticmethod
|
||||
def _date_format(value: str, format_str: str = "%d.%m.%Y") -> str:
|
||||
"""Formatiert Datum für deutsche Darstellung."""
|
||||
if not value:
|
||||
return ""
|
||||
try:
|
||||
dt = datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
return dt.strftime(format_str)
|
||||
except (ValueError, AttributeError):
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _grade_color(grade: str) -> str:
|
||||
"""Gibt Farbe basierend auf Note zurück."""
|
||||
grade_colors = {
|
||||
"1": "#27ae60", # Grün
|
||||
"2": "#2ecc71", # Hellgrün
|
||||
"3": "#f1c40f", # Gelb
|
||||
"4": "#e67e22", # Orange
|
||||
"5": "#e74c3c", # Rot
|
||||
"6": "#c0392b", # Dunkelrot
|
||||
"A": "#27ae60",
|
||||
"B": "#2ecc71",
|
||||
"C": "#f1c40f",
|
||||
"D": "#e74c3c",
|
||||
}
|
||||
return grade_colors.get(str(grade), "#333333")
|
||||
|
||||
def _get_base_css(self) -> str:
|
||||
"""Gibt Basis-CSS für alle PDFs zurück."""
|
||||
return """
|
||||
@page {
|
||||
size: A4;
|
||||
margin: 2cm 2.5cm;
|
||||
@top-right {
|
||||
content: counter(page) " / " counter(pages);
|
||||
font-size: 9pt;
|
||||
color: #666;
|
||||
}
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: 'DejaVu Sans', 'Liberation Sans', Arial, sans-serif;
|
||||
font-size: 11pt;
|
||||
line-height: 1.5;
|
||||
color: #333;
|
||||
}
|
||||
|
||||
h1, h2, h3 {
|
||||
font-weight: bold;
|
||||
margin-top: 1em;
|
||||
margin-bottom: 0.5em;
|
||||
}
|
||||
|
||||
h1 { font-size: 16pt; }
|
||||
h2 { font-size: 14pt; }
|
||||
h3 { font-size: 12pt; }
|
||||
|
||||
.header {
|
||||
border-bottom: 2px solid #2c3e50;
|
||||
padding-bottom: 15px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
.school-name {
|
||||
font-size: 18pt;
|
||||
font-weight: bold;
|
||||
color: #2c3e50;
|
||||
}
|
||||
|
||||
.school-info {
|
||||
font-size: 9pt;
|
||||
color: #666;
|
||||
}
|
||||
|
||||
.letter-date {
|
||||
text-align: right;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
.recipient {
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
|
||||
.subject {
|
||||
font-weight: bold;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
.content {
|
||||
text-align: justify;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
|
||||
.signature {
|
||||
margin-top: 40px;
|
||||
}
|
||||
|
||||
.legal-references {
|
||||
font-size: 9pt;
|
||||
color: #666;
|
||||
border-top: 1px solid #ddd;
|
||||
margin-top: 30px;
|
||||
padding-top: 10px;
|
||||
}
|
||||
|
||||
.gfk-badge {
|
||||
display: inline-block;
|
||||
background: #e8f5e9;
|
||||
color: #27ae60;
|
||||
font-size: 8pt;
|
||||
padding: 2px 8px;
|
||||
border-radius: 10px;
|
||||
margin-right: 5px;
|
||||
}
|
||||
|
||||
/* Zeugnis-Styles */
|
||||
.certificate-header {
|
||||
text-align: center;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
|
||||
.certificate-title {
|
||||
font-size: 20pt;
|
||||
font-weight: bold;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
.student-info {
|
||||
margin-bottom: 20px;
|
||||
padding: 15px;
|
||||
background: #f9f9f9;
|
||||
border-radius: 5px;
|
||||
}
|
||||
|
||||
.grades-table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
.grades-table th,
|
||||
.grades-table td {
|
||||
border: 1px solid #ddd;
|
||||
padding: 8px 12px;
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
.grades-table th {
|
||||
background: #2c3e50;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.grades-table tr:nth-child(even) {
|
||||
background: #f9f9f9;
|
||||
}
|
||||
|
||||
.grade-cell {
|
||||
text-align: center;
|
||||
font-weight: bold;
|
||||
font-size: 12pt;
|
||||
}
|
||||
|
||||
.attendance-box {
|
||||
background: #fff3cd;
|
||||
padding: 15px;
|
||||
border-radius: 5px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
.signatures-row {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
margin-top: 50px;
|
||||
}
|
||||
|
||||
.signature-block {
|
||||
text-align: center;
|
||||
width: 40%;
|
||||
}
|
||||
|
||||
.signature-line {
|
||||
border-top: 1px solid #333;
|
||||
margin-top: 40px;
|
||||
padding-top: 5px;
|
||||
}
|
||||
|
||||
/* Korrektur-Styles */
|
||||
.exam-header {
|
||||
background: #2c3e50;
|
||||
color: white;
|
||||
padding: 15px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
.result-box {
|
||||
background: #e8f5e9;
|
||||
padding: 20px;
|
||||
text-align: center;
|
||||
margin-bottom: 20px;
|
||||
border-radius: 5px;
|
||||
}
|
||||
|
||||
.result-grade {
|
||||
font-size: 36pt;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.result-points {
|
||||
font-size: 14pt;
|
||||
color: #666;
|
||||
}
|
||||
|
||||
.corrections-list {
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
.correction-item {
|
||||
border: 1px solid #ddd;
|
||||
padding: 15px;
|
||||
margin-bottom: 10px;
|
||||
border-radius: 5px;
|
||||
}
|
||||
|
||||
.correction-question {
|
||||
font-weight: bold;
|
||||
margin-bottom: 5px;
|
||||
}
|
||||
|
||||
.correction-feedback {
|
||||
background: #fff8e1;
|
||||
padding: 10px;
|
||||
margin-top: 10px;
|
||||
border-left: 3px solid #ffc107;
|
||||
font-size: 10pt;
|
||||
}
|
||||
|
||||
.stats-table {
|
||||
width: 100%;
|
||||
margin-top: 20px;
|
||||
}
|
||||
|
||||
.stats-table td {
|
||||
padding: 5px 10px;
|
||||
}
|
||||
"""
|
||||
|
||||
def generate_letter_pdf(self, data: LetterData) -> bytes:
|
||||
"""
|
||||
Generiert PDF für Elternbrief.
|
||||
|
||||
Args:
|
||||
data: LetterData mit allen Briefinformationen
|
||||
|
||||
Returns:
|
||||
PDF als bytes
|
||||
"""
|
||||
logger.info(f"Generating letter PDF for student: {data.student_name}")
|
||||
|
||||
template = self._get_letter_template()
|
||||
html_content = template.render(
|
||||
data=data,
|
||||
generated_at=datetime.now().strftime("%d.%m.%Y %H:%M")
|
||||
)
|
||||
|
||||
css = CSS(string=self._get_base_css(), font_config=self.font_config)
|
||||
pdf_bytes = HTML(string=html_content).write_pdf(
|
||||
stylesheets=[css],
|
||||
font_config=self.font_config
|
||||
)
|
||||
|
||||
logger.info(f"Letter PDF generated: {len(pdf_bytes)} bytes")
|
||||
return pdf_bytes
|
||||
|
||||
def generate_certificate_pdf(self, data: CertificateData) -> bytes:
|
||||
"""
|
||||
Generiert PDF für Schulzeugnis.
|
||||
|
||||
Args:
|
||||
data: CertificateData mit allen Zeugnisinformationen
|
||||
|
||||
Returns:
|
||||
PDF als bytes
|
||||
"""
|
||||
logger.info(f"Generating certificate PDF for: {data.student_name}")
|
||||
|
||||
template = self._get_certificate_template()
|
||||
html_content = template.render(
|
||||
data=data,
|
||||
generated_at=datetime.now().strftime("%d.%m.%Y %H:%M")
|
||||
)
|
||||
|
||||
css = CSS(string=self._get_base_css(), font_config=self.font_config)
|
||||
pdf_bytes = HTML(string=html_content).write_pdf(
|
||||
stylesheets=[css],
|
||||
font_config=self.font_config
|
||||
)
|
||||
|
||||
logger.info(f"Certificate PDF generated: {len(pdf_bytes)} bytes")
|
||||
return pdf_bytes
|
||||
|
||||
def generate_correction_pdf(self, data: CorrectionData) -> bytes:
|
||||
"""
|
||||
Generiert PDF für Korrektur-Übersicht.
|
||||
|
||||
Args:
|
||||
data: CorrectionData mit allen Korrekturinformationen
|
||||
|
||||
Returns:
|
||||
PDF als bytes
|
||||
"""
|
||||
logger.info(f"Generating correction PDF for: {data.student.name}")
|
||||
|
||||
template = self._get_correction_template()
|
||||
html_content = template.render(
|
||||
data=data,
|
||||
generated_at=datetime.now().strftime("%d.%m.%Y %H:%M")
|
||||
)
|
||||
|
||||
css = CSS(string=self._get_base_css(), font_config=self.font_config)
|
||||
pdf_bytes = HTML(string=html_content).write_pdf(
|
||||
stylesheets=[css],
|
||||
font_config=self.font_config
|
||||
)
|
||||
|
||||
logger.info(f"Correction PDF generated: {len(pdf_bytes)} bytes")
|
||||
return pdf_bytes
|
||||
|
||||
def _get_letter_template(self):
|
||||
"""Gibt Letter-Template zurück (inline falls Datei nicht existiert)."""
|
||||
template_path = self.templates_dir / "letter.html"
|
||||
if template_path.exists():
|
||||
return self.jinja_env.get_template("letter.html")
|
||||
|
||||
# Inline-Template als Fallback
|
||||
return self.jinja_env.from_string(self._get_letter_template_html())
|
||||
|
||||
def _get_certificate_template(self):
|
||||
"""Gibt Certificate-Template zurück."""
|
||||
template_path = self.templates_dir / "certificate.html"
|
||||
if template_path.exists():
|
||||
return self.jinja_env.get_template("certificate.html")
|
||||
|
||||
return self.jinja_env.from_string(self._get_certificate_template_html())
|
||||
|
||||
def _get_correction_template(self):
|
||||
"""Gibt Correction-Template zurück."""
|
||||
template_path = self.templates_dir / "correction.html"
|
||||
if template_path.exists():
|
||||
return self.jinja_env.get_template("correction.html")
|
||||
|
||||
return self.jinja_env.from_string(self._get_correction_template_html())
|
||||
|
||||
@staticmethod
|
||||
def _get_letter_template_html() -> str:
|
||||
"""Inline HTML-Template für Elternbriefe."""
|
||||
return """
|
||||
<!DOCTYPE html>
|
||||
<html lang="de">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>{{ data.subject }}</title>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
{% if data.school_info %}
|
||||
<div class="school-name">{{ data.school_info.name }}</div>
|
||||
<div class="school-info">
|
||||
{{ data.school_info.address }}<br>
|
||||
Tel: {{ data.school_info.phone }} | E-Mail: {{ data.school_info.email }}
|
||||
{% if data.school_info.website %} | {{ data.school_info.website }}{% endif %}
|
||||
</div>
|
||||
{% else %}
|
||||
<div class="school-name">Schule</div>
|
||||
{% endif %}
|
||||
</div>
|
||||
|
||||
<div class="letter-date">
|
||||
{{ data.date }}
|
||||
</div>
|
||||
|
||||
<div class="recipient">
|
||||
{{ data.recipient_name }}<br>
|
||||
{{ data.recipient_address | replace('\\n', '<br>') | safe }}
|
||||
</div>
|
||||
|
||||
<div class="subject">
|
||||
Betreff: {{ data.subject }}
|
||||
</div>
|
||||
|
||||
<div class="meta-info" style="font-size: 10pt; color: #666; margin-bottom: 20px;">
|
||||
Schüler/in: {{ data.student_name }} | Klasse: {{ data.student_class }}
|
||||
</div>
|
||||
|
||||
<div class="content">
|
||||
{{ data.content | replace('\\n', '<br>') | safe }}
|
||||
</div>
|
||||
|
||||
{% if data.gfk_principles_applied %}
|
||||
<div style="margin-bottom: 20px;">
|
||||
{% for principle in data.gfk_principles_applied %}
|
||||
<span class="gfk-badge">✓ {{ principle }}</span>
|
||||
{% endfor %}
|
||||
</div>
|
||||
{% endif %}
|
||||
|
||||
<div class="signature">
|
||||
<p>Mit freundlichen Grüßen</p>
|
||||
<p style="margin-top: 30px;">
|
||||
{{ data.teacher_name }}
|
||||
{% if data.teacher_title %}<br><span style="font-size: 10pt;">{{ data.teacher_title }}</span>{% endif %}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{% if data.legal_references %}
|
||||
<div class="legal-references">
|
||||
<strong>Rechtliche Grundlagen:</strong><br>
|
||||
{% for ref in data.legal_references %}
|
||||
• {{ ref.law }} {{ ref.paragraph }}: {{ ref.title }}<br>
|
||||
{% endfor %}
|
||||
</div>
|
||||
{% endif %}
|
||||
|
||||
<div style="font-size: 8pt; color: #999; margin-top: 30px; text-align: center;">
|
||||
Erstellt mit BreakPilot | {{ generated_at }}
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _get_certificate_template_html() -> str:
|
||||
"""Inline HTML-Template für Zeugnisse."""
|
||||
return """
|
||||
<!DOCTYPE html>
|
||||
<html lang="de">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>Zeugnis - {{ data.student_name }}</title>
|
||||
</head>
|
||||
<body>
|
||||
<div class="certificate-header">
|
||||
{% if data.school_info %}
|
||||
<div class="school-name" style="font-size: 14pt;">{{ data.school_info.name }}</div>
|
||||
{% endif %}
|
||||
<div class="certificate-title">
|
||||
{% if data.certificate_type == 'halbjahr' %}
|
||||
Halbjahreszeugnis
|
||||
{% elif data.certificate_type == 'jahres' %}
|
||||
Jahreszeugnis
|
||||
{% else %}
|
||||
Abschlusszeugnis
|
||||
{% endif %}
|
||||
</div>
|
||||
<div>Schuljahr {{ data.school_year }}</div>
|
||||
</div>
|
||||
|
||||
<div class="student-info">
|
||||
<table style="width: 100%;">
|
||||
<tr>
|
||||
<td><strong>Name:</strong> {{ data.student_name }}</td>
|
||||
<td><strong>Geburtsdatum:</strong> {{ data.student_birthdate }}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>Klasse:</strong> {{ data.student_class }}</td>
|
||||
<td> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<h3>Leistungen</h3>
|
||||
<table class="grades-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th style="width: 70%;">Fach</th>
|
||||
<th style="width: 15%;">Note</th>
|
||||
<th style="width: 15%;">Punkte</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for subject in data.subjects %}
|
||||
<tr>
|
||||
<td>{{ subject.name }}</td>
|
||||
<td class="grade-cell" style="color: {{ subject.grade | grade_color }};">
|
||||
{{ subject.grade }}
|
||||
</td>
|
||||
<td class="grade-cell">{{ subject.points | default('-') }}</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
{% if data.social_behavior or data.work_behavior %}
|
||||
<h3>Verhalten</h3>
|
||||
<table class="grades-table" style="width: 50%;">
|
||||
{% if data.social_behavior %}
|
||||
<tr>
|
||||
<td>Sozialverhalten</td>
|
||||
<td class="grade-cell">{{ data.social_behavior }}</td>
|
||||
</tr>
|
||||
{% endif %}
|
||||
{% if data.work_behavior %}
|
||||
<tr>
|
||||
<td>Arbeitsverhalten</td>
|
||||
<td class="grade-cell">{{ data.work_behavior }}</td>
|
||||
</tr>
|
||||
{% endif %}
|
||||
</table>
|
||||
{% endif %}
|
||||
|
||||
<div class="attendance-box">
|
||||
<strong>Versäumte Tage:</strong> {{ data.attendance.days_absent | default(0) }}
|
||||
(davon entschuldigt: {{ data.attendance.days_excused | default(0) }},
|
||||
unentschuldigt: {{ data.attendance.days_unexcused | default(0) }})
|
||||
</div>
|
||||
|
||||
{% if data.remarks %}
|
||||
<div style="margin-bottom: 20px;">
|
||||
<strong>Bemerkungen:</strong><br>
|
||||
{{ data.remarks }}
|
||||
</div>
|
||||
{% endif %}
|
||||
|
||||
<div style="margin-top: 30px;">
|
||||
<strong>Ausgestellt am:</strong> {{ data.issue_date }}
|
||||
</div>
|
||||
|
||||
<div class="signatures-row">
|
||||
<div class="signature-block">
|
||||
<div class="signature-line">{{ data.class_teacher }}</div>
|
||||
<div style="font-size: 9pt;">Klassenlehrer/in</div>
|
||||
</div>
|
||||
<div class="signature-block">
|
||||
<div class="signature-line">{{ data.principal }}</div>
|
||||
<div style="font-size: 9pt;">Schulleiter/in</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div style="text-align: center; margin-top: 40px;">
|
||||
<div style="font-size: 9pt; color: #666;">Siegel der Schule</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _get_correction_template_html() -> str:
|
||||
"""Inline HTML-Template für Korrektur-Übersichten."""
|
||||
return """
|
||||
<!DOCTYPE html>
|
||||
<html lang="de">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>Korrektur - {{ data.exam_title }}</title>
|
||||
</head>
|
||||
<body>
|
||||
<div class="exam-header">
|
||||
<h1 style="margin: 0; color: white;">{{ data.exam_title }}</h1>
|
||||
<div>{{ data.subject }} | {{ data.date }}</div>
|
||||
</div>
|
||||
|
||||
<div class="student-info">
|
||||
<strong>{{ data.student.name }}</strong> | Klasse {{ data.student.class_name }}
|
||||
</div>
|
||||
|
||||
<div class="result-box">
|
||||
<div class="result-grade" style="color: {{ data.grade | grade_color }};">
|
||||
Note: {{ data.grade }}
|
||||
</div>
|
||||
<div class="result-points">
|
||||
{{ data.achieved_points }} von {{ data.max_points }} Punkten
|
||||
({{ data.percentage | round(1) }}%)
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<h3>Detaillierte Auswertung</h3>
|
||||
<div class="corrections-list">
|
||||
{% for item in data.corrections %}
|
||||
<div class="correction-item">
|
||||
<div class="correction-question">
|
||||
{{ item.question }}
|
||||
</div>
|
||||
{% if item.answer %}
|
||||
<div style="margin: 5px 0; font-style: italic; color: #555;">
|
||||
<strong>Antwort:</strong> {{ item.answer }}
|
||||
</div>
|
||||
{% endif %}
|
||||
<div>
|
||||
<strong>Punkte:</strong> {{ item.points }}
|
||||
</div>
|
||||
{% if item.feedback %}
|
||||
<div class="correction-feedback">
|
||||
{{ item.feedback }}
|
||||
</div>
|
||||
{% endif %}
|
||||
</div>
|
||||
{% endfor %}
|
||||
</div>
|
||||
|
||||
{% if data.teacher_notes %}
|
||||
<div style="background: #e3f2fd; padding: 15px; border-radius: 5px; margin-bottom: 20px;">
|
||||
<strong>Lehrerkommentar:</strong><br>
|
||||
{{ data.teacher_notes }}
|
||||
</div>
|
||||
{% endif %}
|
||||
|
||||
{% if data.ai_feedback %}
|
||||
<div style="background: #f3e5f5; padding: 15px; border-radius: 5px; margin-bottom: 20px;">
|
||||
<strong>KI-Feedback:</strong><br>
|
||||
{{ data.ai_feedback }}
|
||||
</div>
|
||||
{% endif %}
|
||||
|
||||
{% if data.class_average or data.grade_distribution %}
|
||||
<h3>Klassenstatistik</h3>
|
||||
<table class="stats-table">
|
||||
{% if data.class_average %}
|
||||
<tr>
|
||||
<td><strong>Klassendurchschnitt:</strong></td>
|
||||
<td>{{ data.class_average }}</td>
|
||||
</tr>
|
||||
{% endif %}
|
||||
{% if data.grade_distribution %}
|
||||
<tr>
|
||||
<td><strong>Notenverteilung:</strong></td>
|
||||
<td>
|
||||
{% for grade, count in data.grade_distribution.items() %}
|
||||
Note {{ grade }}: {{ count }}x{% if not loop.last %}, {% endif %}
|
||||
{% endfor %}
|
||||
</td>
|
||||
</tr>
|
||||
{% endif %}
|
||||
</table>
|
||||
{% endif %}
|
||||
|
||||
<div class="signature" style="margin-top: 40px;">
|
||||
<p style="font-size: 9pt; color: #666;">Datum: {{ data.date }}</p>
|
||||
</div>
|
||||
|
||||
<div style="font-size: 8pt; color: #999; margin-top: 30px; text-align: center;">
|
||||
Erstellt mit BreakPilot | {{ generated_at }}
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
# Convenience functions for direct usage
|
||||
_pdf_service: Optional[PDFService] = None
|
||||
|
||||
|
||||
def get_pdf_service() -> PDFService:
|
||||
"""Gibt Singleton-Instanz des PDF-Service zurück."""
|
||||
global _pdf_service
|
||||
if _pdf_service is None:
|
||||
_pdf_service = PDFService()
|
||||
return _pdf_service
|
||||
|
||||
|
||||
def generate_letter_pdf(data: Dict[str, Any]) -> bytes:
|
||||
"""
|
||||
Convenience function zum Generieren eines Elternbrief-PDFs.
|
||||
|
||||
Args:
|
||||
data: Dict mit allen Briefdaten
|
||||
|
||||
Returns:
|
||||
PDF als bytes
|
||||
"""
|
||||
service = get_pdf_service()
|
||||
|
||||
# Convert dict to LetterData
|
||||
school_info = None
|
||||
if data.get("school_info"):
|
||||
school_info = SchoolInfo(**data["school_info"])
|
||||
|
||||
letter_data = LetterData(
|
||||
recipient_name=data.get("recipient_name", ""),
|
||||
recipient_address=data.get("recipient_address", ""),
|
||||
student_name=data.get("student_name", ""),
|
||||
student_class=data.get("student_class", ""),
|
||||
subject=data.get("subject", ""),
|
||||
content=data.get("content", ""),
|
||||
date=data.get("date", datetime.now().strftime("%d.%m.%Y")),
|
||||
teacher_name=data.get("teacher_name", ""),
|
||||
teacher_title=data.get("teacher_title"),
|
||||
school_info=school_info,
|
||||
letter_type=data.get("letter_type", "general"),
|
||||
tone=data.get("tone", "professional"),
|
||||
legal_references=data.get("legal_references"),
|
||||
gfk_principles_applied=data.get("gfk_principles_applied")
|
||||
)
|
||||
|
||||
return service.generate_letter_pdf(letter_data)
|
||||
|
||||
|
||||
def generate_certificate_pdf(data: Dict[str, Any]) -> bytes:
|
||||
"""
|
||||
Convenience function zum Generieren eines Zeugnis-PDFs.
|
||||
|
||||
Args:
|
||||
data: Dict mit allen Zeugnisdaten
|
||||
|
||||
Returns:
|
||||
PDF als bytes
|
||||
"""
|
||||
service = get_pdf_service()
|
||||
|
||||
school_info = None
|
||||
if data.get("school_info"):
|
||||
school_info = SchoolInfo(**data["school_info"])
|
||||
|
||||
cert_data = CertificateData(
|
||||
student_name=data.get("student_name", ""),
|
||||
student_birthdate=data.get("student_birthdate", ""),
|
||||
student_class=data.get("student_class", ""),
|
||||
school_year=data.get("school_year", ""),
|
||||
certificate_type=data.get("certificate_type", "halbjahr"),
|
||||
subjects=data.get("subjects", []),
|
||||
attendance=data.get("attendance", {"days_absent": 0, "days_excused": 0, "days_unexcused": 0}),
|
||||
remarks=data.get("remarks"),
|
||||
class_teacher=data.get("class_teacher", ""),
|
||||
principal=data.get("principal", ""),
|
||||
school_info=school_info,
|
||||
issue_date=data.get("issue_date", datetime.now().strftime("%d.%m.%Y")),
|
||||
social_behavior=data.get("social_behavior"),
|
||||
work_behavior=data.get("work_behavior")
|
||||
)
|
||||
|
||||
return service.generate_certificate_pdf(cert_data)
|
||||
|
||||
|
||||
def generate_correction_pdf(data: Dict[str, Any]) -> bytes:
|
||||
"""
|
||||
Convenience function zum Generieren eines Korrektur-PDFs.
|
||||
|
||||
Args:
|
||||
data: Dict mit allen Korrekturdaten
|
||||
|
||||
Returns:
|
||||
PDF als bytes
|
||||
"""
|
||||
service = get_pdf_service()
|
||||
|
||||
# Create StudentInfo from dict
|
||||
student = StudentInfo(
|
||||
student_id=data.get("student_id", "unknown"),
|
||||
name=data.get("student_name", data.get("name", "")),
|
||||
class_name=data.get("student_class", data.get("class_name", ""))
|
||||
)
|
||||
|
||||
# Calculate percentage if not provided
|
||||
max_points = data.get("max_points", data.get("total_points", 0))
|
||||
achieved_points = data.get("achieved_points", 0)
|
||||
percentage = data.get("percentage", (achieved_points / max_points * 100) if max_points > 0 else 0.0)
|
||||
|
||||
correction_data = CorrectionData(
|
||||
student=student,
|
||||
exam_title=data.get("exam_title", ""),
|
||||
subject=data.get("subject", ""),
|
||||
date=data.get("date", data.get("exam_date", "")),
|
||||
max_points=max_points,
|
||||
achieved_points=achieved_points,
|
||||
grade=data.get("grade", ""),
|
||||
percentage=percentage,
|
||||
corrections=data.get("corrections", []),
|
||||
teacher_notes=data.get("teacher_notes", data.get("teacher_comment", "")),
|
||||
ai_feedback=data.get("ai_feedback", ""),
|
||||
grade_distribution=data.get("grade_distribution"),
|
||||
class_average=data.get("class_average")
|
||||
)
|
||||
|
||||
return service.generate_correction_pdf(correction_data)
|
||||
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
System API endpoints for health checks and system information.
|
||||
|
||||
Provides:
|
||||
- /health - Basic health check
|
||||
- /api/v1/system/local-ip - Local network IP for QR-code mobile upload
|
||||
"""
|
||||
|
||||
import os
|
||||
import socket
|
||||
from fastapi import APIRouter
|
||||
|
||||
router = APIRouter(tags=["System"])
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
"""
|
||||
Basic health check endpoint.
|
||||
|
||||
Returns healthy status and service name.
|
||||
"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "breakpilot-backend-core"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/api/v1/system/local-ip")
|
||||
async def get_local_ip():
|
||||
"""
|
||||
Return the local network IP address.
|
||||
|
||||
Used for QR-code generation for mobile PDF upload.
|
||||
Mobile devices can't reach localhost, so we need the actual network IP.
|
||||
|
||||
Priority:
|
||||
1. LOCAL_NETWORK_IP environment variable (explicit configuration)
|
||||
2. Auto-detection via socket connection
|
||||
3. Fallback to default 192.168.178.157
|
||||
"""
|
||||
# Check environment variable first
|
||||
env_ip = os.getenv("LOCAL_NETWORK_IP")
|
||||
if env_ip:
|
||||
return {"ip": env_ip}
|
||||
|
||||
# Try to auto-detect
|
||||
try:
|
||||
# Create a socket to an external address to determine local IP
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
s.settimeout(0.1)
|
||||
# Connect to a public DNS server (doesn't actually send anything)
|
||||
s.connect(("8.8.8.8", 80))
|
||||
local_ip = s.getsockname()[0]
|
||||
s.close()
|
||||
|
||||
# Validate it's a private IP
|
||||
if (local_ip.startswith("192.168.") or
|
||||
local_ip.startswith("10.") or
|
||||
(local_ip.startswith("172.") and 16 <= int(local_ip.split('.')[1]) <= 31)):
|
||||
return {"ip": local_ip}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback to default
|
||||
return {"ip": "192.168.178.157"}
|
||||
@@ -0,0 +1,115 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="de">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>Zeugnis - {{ data.student_name }}</title>
|
||||
</head>
|
||||
<body>
|
||||
<div class="certificate-header">
|
||||
{% if data.school_info %}
|
||||
<div class="school-name" style="font-size: 14pt;">{{ data.school_info.name }}</div>
|
||||
{% endif %}
|
||||
<div class="certificate-title">
|
||||
{% if data.certificate_type == 'halbjahr' %}
|
||||
Halbjahreszeugnis
|
||||
{% elif data.certificate_type == 'jahres' %}
|
||||
Jahreszeugnis
|
||||
{% elif data.certificate_type == 'abschluss' %}
|
||||
Abschlusszeugnis
|
||||
{% else %}
|
||||
Zeugnis
|
||||
{% endif %}
|
||||
</div>
|
||||
<div>Schuljahr {{ data.school_year }}</div>
|
||||
</div>
|
||||
|
||||
<div class="student-info">
|
||||
<table style="width: 100%;">
|
||||
<tr>
|
||||
<td><strong>Name:</strong> {{ data.student_name }}</td>
|
||||
<td><strong>Geburtsdatum:</strong> {{ data.student_birthdate }}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>Klasse:</strong> {{ data.student_class }}</td>
|
||||
<td> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<h3>Leistungen</h3>
|
||||
<table class="grades-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th style="width: 60%;">Fach</th>
|
||||
<th style="width: 20%;">Note</th>
|
||||
<th style="width: 20%;">Punkte</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for subject in data.subjects %}
|
||||
<tr>
|
||||
<td>{{ subject.name }}</td>
|
||||
<td class="grade-cell" style="color: {{ subject.grade | grade_color }};">
|
||||
{{ subject.grade }}
|
||||
</td>
|
||||
<td class="grade-cell">{{ subject.points | default('-') }}</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
{% if data.social_behavior or data.work_behavior %}
|
||||
<h3>Verhalten</h3>
|
||||
<table class="grades-table" style="width: 50%;">
|
||||
{% if data.social_behavior %}
|
||||
<tr>
|
||||
<td>Sozialverhalten</td>
|
||||
<td class="grade-cell">{{ data.social_behavior }}</td>
|
||||
</tr>
|
||||
{% endif %}
|
||||
{% if data.work_behavior %}
|
||||
<tr>
|
||||
<td>Arbeitsverhalten</td>
|
||||
<td class="grade-cell">{{ data.work_behavior }}</td>
|
||||
</tr>
|
||||
{% endif %}
|
||||
</table>
|
||||
{% endif %}
|
||||
|
||||
<div class="attendance-box">
|
||||
<strong>Versäumte Tage:</strong> {{ data.attendance.days_absent | default(0) }}
|
||||
(davon entschuldigt: {{ data.attendance.days_excused | default(0) }},
|
||||
unentschuldigt: {{ data.attendance.days_unexcused | default(0) }})
|
||||
</div>
|
||||
|
||||
{% if data.remarks %}
|
||||
<div style="margin-bottom: 20px;">
|
||||
<strong>Bemerkungen:</strong><br>
|
||||
{{ data.remarks }}
|
||||
</div>
|
||||
{% endif %}
|
||||
|
||||
<div style="margin-top: 30px;">
|
||||
<strong>Ausgestellt am:</strong> {{ data.issue_date }}
|
||||
</div>
|
||||
|
||||
<div class="signatures-row">
|
||||
<div class="signature-block">
|
||||
<div class="signature-line">{{ data.class_teacher }}</div>
|
||||
<div style="font-size: 9pt;">Klassenlehrer/in</div>
|
||||
</div>
|
||||
<div class="signature-block">
|
||||
<div class="signature-line">{{ data.principal }}</div>
|
||||
<div style="font-size: 9pt;">Schulleiter/in</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div style="text-align: center; margin-top: 40px;">
|
||||
<div style="font-size: 9pt; color: #666;">Siegel der Schule</div>
|
||||
</div>
|
||||
|
||||
<div style="font-size: 8pt; color: #999; margin-top: 30px; text-align: center;">
|
||||
Erstellt mit BreakPilot | {{ generated_at }}
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,90 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="de">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>Korrektur - {{ data.exam_title }}</title>
|
||||
</head>
|
||||
<body>
|
||||
<div class="exam-header">
|
||||
<h1 style="margin: 0; color: white;">{{ data.exam_title }}</h1>
|
||||
<div>{{ data.subject }} | {{ data.date }}</div>
|
||||
</div>
|
||||
|
||||
<div class="student-info">
|
||||
<strong>{{ data.student.name }}</strong> | Klasse {{ data.student.class_name }}
|
||||
</div>
|
||||
|
||||
<div class="result-box">
|
||||
<div class="result-grade" style="color: {{ data.grade | grade_color }};">
|
||||
Note: {{ data.grade }}
|
||||
</div>
|
||||
<div class="result-points">
|
||||
{{ data.achieved_points }} von {{ data.max_points }} Punkten
|
||||
{% if data.max_points > 0 %}
|
||||
({{ data.percentage | round(1) }}%)
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<h3>Detaillierte Auswertung</h3>
|
||||
<div class="corrections-list">
|
||||
{% for item in data.corrections %}
|
||||
<div class="correction-item">
|
||||
<div class="correction-question">
|
||||
Aufgabe {{ loop.index }}: {{ item.question }}
|
||||
</div>
|
||||
<div>
|
||||
<strong>Punkte:</strong> {{ item.points }}
|
||||
</div>
|
||||
{% if item.feedback %}
|
||||
<div class="correction-feedback">
|
||||
{{ item.feedback }}
|
||||
</div>
|
||||
{% endif %}
|
||||
</div>
|
||||
{% endfor %}
|
||||
</div>
|
||||
|
||||
{% if data.teacher_notes %}
|
||||
<div style="background: #e3f2fd; padding: 15px; border-radius: 5px; margin-bottom: 20px;">
|
||||
<strong>Lehrerkommentar:</strong><br>
|
||||
{{ data.teacher_notes }}
|
||||
</div>
|
||||
{% endif %}
|
||||
|
||||
{% if data.ai_feedback %}
|
||||
<div style="background: #f3e5f5; padding: 15px; border-radius: 5px; margin-bottom: 20px;">
|
||||
<strong>KI-Feedback:</strong><br>
|
||||
{{ data.ai_feedback }}
|
||||
</div>
|
||||
{% endif %}
|
||||
|
||||
<h3>Klassenstatistik</h3>
|
||||
<table class="stats-table">
|
||||
{% if data.class_average %}
|
||||
<tr>
|
||||
<td><strong>Klassendurchschnitt:</strong></td>
|
||||
<td>{{ data.class_average }}</td>
|
||||
</tr>
|
||||
{% endif %}
|
||||
{% if data.grade_distribution %}
|
||||
<tr>
|
||||
<td><strong>Notenverteilung:</strong></td>
|
||||
<td>
|
||||
{% for grade, count in data.grade_distribution.items() %}
|
||||
Note {{ grade }}: {{ count }}x{% if not loop.last %}, {% endif %}
|
||||
{% endfor %}
|
||||
</td>
|
||||
</tr>
|
||||
{% endif %}
|
||||
</table>
|
||||
|
||||
<div class="signature" style="margin-top: 40px;">
|
||||
<p style="font-size: 9pt; color: #666;">Datum: {{ data.date }}</p>
|
||||
</div>
|
||||
|
||||
<div style="font-size: 8pt; color: #999; margin-top: 30px; text-align: center;">
|
||||
Erstellt mit BreakPilot | {{ generated_at }}
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,73 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="de">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>{{ data.subject }}</title>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
{% if data.school_info %}
|
||||
<div class="school-name">{{ data.school_info.name }}</div>
|
||||
<div class="school-info">
|
||||
{{ data.school_info.address }}<br>
|
||||
Tel: {{ data.school_info.phone }} | E-Mail: {{ data.school_info.email }}
|
||||
{% if data.school_info.website %} | {{ data.school_info.website }}{% endif %}
|
||||
</div>
|
||||
{% else %}
|
||||
<div class="school-name">Schule</div>
|
||||
{% endif %}
|
||||
</div>
|
||||
|
||||
<div class="letter-date">
|
||||
{{ data.date }}
|
||||
</div>
|
||||
|
||||
<div class="recipient">
|
||||
{{ data.recipient_name }}<br>
|
||||
{{ data.recipient_address | replace('\n', '<br>') | safe }}
|
||||
</div>
|
||||
|
||||
<div class="subject">
|
||||
Betreff: {{ data.subject }}
|
||||
</div>
|
||||
|
||||
<div class="meta-info" style="font-size: 10pt; color: #666; margin-bottom: 20px;">
|
||||
Schüler/in: {{ data.student_name }} | Klasse: {{ data.student_class }}
|
||||
</div>
|
||||
|
||||
<div class="content">
|
||||
{{ data.content | replace('\n', '<br>') | safe }}
|
||||
</div>
|
||||
|
||||
{% if data.gfk_principles_applied %}
|
||||
<div style="margin-bottom: 20px;">
|
||||
{% for principle in data.gfk_principles_applied %}
|
||||
<span class="gfk-badge">GFK: {{ principle }}</span>
|
||||
{% endfor %}
|
||||
</div>
|
||||
{% endif %}
|
||||
|
||||
<div class="signature">
|
||||
<p>Mit freundlichen Grüßen</p>
|
||||
<p style="margin-top: 30px;">
|
||||
{{ data.teacher_name }}
|
||||
{% if data.teacher_title %}<br><span style="font-size: 10pt;">{{ data.teacher_title }}</span>{% endif %}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{% if data.legal_references %}
|
||||
<div class="legal-references">
|
||||
<strong>Rechtliche Grundlagen:</strong><br>
|
||||
{% for ref in data.legal_references %}
|
||||
<div style="margin: 5px 0;">
|
||||
{{ ref.law }} {{ ref.paragraph }}: {{ ref.title }}
|
||||
</div>
|
||||
{% endfor %}
|
||||
</div>
|
||||
{% endif %}
|
||||
|
||||
<div style="font-size: 8pt; color: #999; margin-top: 30px; text-align: center;">
|
||||
Erstellt mit BreakPilot | {{ generated_at }}
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,40 @@
|
||||
# Build stage
|
||||
FROM golang:1.23-alpine AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install git for go mod download
|
||||
RUN apk add --no-cache git
|
||||
|
||||
# Copy go mod files
|
||||
COPY go.mod go.sum* ./
|
||||
|
||||
# Download dependencies
|
||||
RUN go mod download
|
||||
|
||||
# Copy source code
|
||||
COPY . .
|
||||
|
||||
# Build the application
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo -o billing-service ./cmd/server
|
||||
|
||||
# Final stage
|
||||
FROM alpine:3.19
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install ca-certificates for HTTPS requests (Stripe API)
|
||||
RUN apk --no-cache add ca-certificates tzdata
|
||||
|
||||
# Copy binary from builder
|
||||
COPY --from=builder /app/billing-service .
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8083
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
||||
CMD wget --no-verbose --tries=1 --spider http://localhost:8083/health || exit 1
|
||||
|
||||
# Run the application
|
||||
CMD ["./billing-service"]
|
||||
@@ -0,0 +1,296 @@
|
||||
# Billing Service
|
||||
|
||||
Go-Microservice fuer Stripe-basiertes Subscription Management mit Task-basierter Abrechnung.
|
||||
|
||||
## Uebersicht
|
||||
|
||||
Der Billing Service verwaltet:
|
||||
- Subscription Lifecycle (Trial, Active, Canceled)
|
||||
- Task-basierte Kontingentierung (1 Task = 1 Einheit)
|
||||
- Carryover-Logik (Tasks sammeln sich bis zu 5 Monate an)
|
||||
- Stripe Integration (Checkout, Webhooks, Portal)
|
||||
- Feature Gating und Entitlements
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Voraussetzungen
|
||||
|
||||
- Go 1.21+
|
||||
- PostgreSQL 14+
|
||||
- Docker (optional)
|
||||
|
||||
### Lokale Entwicklung
|
||||
|
||||
```bash
|
||||
# 1. Dependencies installieren
|
||||
go mod download
|
||||
|
||||
# 2. Umgebungsvariablen setzen
|
||||
export DATABASE_URL="postgres://user:pass@localhost:5432/breakpilot?sslmode=disable"
|
||||
export JWT_SECRET="your-jwt-secret"
|
||||
export STRIPE_SECRET_KEY="sk_test_..."
|
||||
export STRIPE_WEBHOOK_SECRET="whsec_..."
|
||||
export BILLING_SUCCESS_URL="http://localhost:3000/billing/success"
|
||||
export BILLING_CANCEL_URL="http://localhost:3000/billing/cancel"
|
||||
export INTERNAL_API_KEY="internal-api-key"
|
||||
export TRIAL_PERIOD_DAYS="7"
|
||||
export PORT="8083"
|
||||
|
||||
# 3. Service starten
|
||||
go run cmd/server/main.go
|
||||
|
||||
# 4. Tests ausfuehren
|
||||
go test -v ./...
|
||||
```
|
||||
|
||||
### Mit Docker
|
||||
|
||||
```bash
|
||||
# Service bauen und starten
|
||||
docker compose up billing-service
|
||||
|
||||
# Nur bauen
|
||||
docker build -t billing-service .
|
||||
```
|
||||
|
||||
## Architektur
|
||||
|
||||
```
|
||||
billing-service/
|
||||
├── cmd/server/main.go # Entry Point
|
||||
├── internal/
|
||||
│ ├── config/config.go # Konfiguration
|
||||
│ ├── database/database.go # DB Connection + Migrations
|
||||
│ ├── models/models.go # Datenmodelle
|
||||
│ ├── middleware/middleware.go # JWT Auth, CORS, Rate Limiting
|
||||
│ ├── services/
|
||||
│ │ ├── subscription_service.go # Subscription Management
|
||||
│ │ ├── task_service.go # Task Consumption
|
||||
│ │ ├── entitlement_service.go # Feature Gating
|
||||
│ │ ├── usage_service.go # Usage Tracking (Legacy)
|
||||
│ │ └── stripe_service.go # Stripe API
|
||||
│ └── handlers/
|
||||
│ ├── billing_handlers.go # API Endpoints
|
||||
│ └── webhook_handlers.go # Stripe Webhooks
|
||||
├── Dockerfile
|
||||
└── go.mod
|
||||
```
|
||||
|
||||
## Task-basiertes Billing
|
||||
|
||||
### Konzept
|
||||
|
||||
- **1 Task = 1 Kontingentverbrauch** (unabhaengig von Seitenanzahl, Tokens, etc.)
|
||||
- **Monatliches Kontingent**: Plan-abhaengig (Basic: 30, Standard: 100, Premium: Fair Use)
|
||||
- **Carryover**: Ungenutzte Tasks sammeln sich bis zu 5 Monate an
|
||||
- **Max Balance**: `monthly_allowance * 5` (z.B. Basic: max 150 Tasks)
|
||||
|
||||
### Task Types
|
||||
|
||||
```go
|
||||
TaskTypeCorrection = "correction" // Korrekturaufgabe
|
||||
TaskTypeLetter = "letter" // Brief erstellen
|
||||
TaskTypeMeeting = "meeting" // Meeting-Protokoll
|
||||
TaskTypeBatch = "batch" // Batch-Verarbeitung
|
||||
TaskTypeOther = "other" // Sonstige
|
||||
```
|
||||
|
||||
### Monatswechsel-Logik
|
||||
|
||||
Bei jedem API-Aufruf wird geprueft, ob ein Monat vergangen ist:
|
||||
1. `last_renewal_at` pruefen
|
||||
2. Falls >= 1 Monat: `task_balance += monthly_allowance`
|
||||
3. Cap bei `max_task_balance`
|
||||
4. `last_renewal_at` aktualisieren
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### User Endpoints (JWT Auth)
|
||||
|
||||
| Methode | Endpoint | Beschreibung |
|
||||
|---------|----------|--------------|
|
||||
| GET | `/api/v1/billing/status` | Aktueller Billing Status |
|
||||
| GET | `/api/v1/billing/plans` | Verfuegbare Plaene |
|
||||
| POST | `/api/v1/billing/trial/start` | Trial starten |
|
||||
| POST | `/api/v1/billing/change-plan` | Plan wechseln |
|
||||
| POST | `/api/v1/billing/cancel` | Abo kuendigen |
|
||||
| GET | `/api/v1/billing/portal` | Stripe Portal URL |
|
||||
|
||||
### Internal Endpoints (API Key)
|
||||
|
||||
| Methode | Endpoint | Beschreibung |
|
||||
|---------|----------|--------------|
|
||||
| GET | `/api/v1/billing/entitlements/:userId` | Entitlements abrufen |
|
||||
| GET | `/api/v1/billing/entitlements/check/:userId/:feature` | Feature pruefen |
|
||||
| GET | `/api/v1/billing/tasks/check/:userId` | Task erlaubt? |
|
||||
| POST | `/api/v1/billing/tasks/consume` | Task konsumieren |
|
||||
| GET | `/api/v1/billing/tasks/usage/:userId` | Task Usage Info |
|
||||
|
||||
### Webhook
|
||||
|
||||
| Methode | Endpoint | Beschreibung |
|
||||
|---------|----------|--------------|
|
||||
| POST | `/api/v1/billing/webhook` | Stripe Webhooks |
|
||||
|
||||
## Plaene und Preise
|
||||
|
||||
| Plan | Preis | Tasks/Monat | Max Balance | Features |
|
||||
|------|-------|-------------|-------------|----------|
|
||||
| Basic | 9.90 EUR | 30 | 150 | Basis-Features |
|
||||
| Standard | 19.90 EUR | 100 | 500 | + Templates, Batch |
|
||||
| Premium | 39.90 EUR | Fair Use | 5000 | + Team, Admin, API |
|
||||
|
||||
### Fair Use Mode (Premium)
|
||||
|
||||
Im Premium-Plan:
|
||||
- Keine praktische Begrenzung
|
||||
- Tasks werden trotzdem getrackt (fuer Monitoring)
|
||||
- Balance wird nicht dekrementiert
|
||||
- `CheckTaskAllowed` gibt immer `true` zurueck
|
||||
|
||||
## Datenbank
|
||||
|
||||
### Wichtige Tabellen
|
||||
|
||||
```sql
|
||||
-- Task-basierte Nutzung pro Account
|
||||
CREATE TABLE account_usage (
|
||||
account_id UUID UNIQUE,
|
||||
plan VARCHAR(50),
|
||||
monthly_task_allowance INT,
|
||||
max_task_balance INT,
|
||||
task_balance INT,
|
||||
last_renewal_at TIMESTAMPTZ
|
||||
);
|
||||
|
||||
-- Einzelne Task-Records
|
||||
CREATE TABLE tasks (
|
||||
id UUID PRIMARY KEY,
|
||||
account_id UUID,
|
||||
task_type VARCHAR(50),
|
||||
consumed BOOLEAN,
|
||||
created_at TIMESTAMPTZ
|
||||
);
|
||||
```
|
||||
|
||||
## Tests
|
||||
|
||||
```bash
|
||||
# Alle Tests
|
||||
go test -v ./...
|
||||
|
||||
# Mit Coverage
|
||||
go test -cover ./...
|
||||
|
||||
# Nur Models
|
||||
go test -v ./internal/models/...
|
||||
|
||||
# Nur Services
|
||||
go test -v ./internal/services/...
|
||||
|
||||
# Nur Handlers
|
||||
go test -v ./internal/handlers/...
|
||||
```
|
||||
|
||||
## Stripe Integration
|
||||
|
||||
### Webhooks
|
||||
|
||||
Konfiguriere im Stripe Dashboard:
|
||||
```
|
||||
URL: https://your-domain.com/api/v1/billing/webhook
|
||||
Events:
|
||||
- checkout.session.completed
|
||||
- customer.subscription.created
|
||||
- customer.subscription.updated
|
||||
- customer.subscription.deleted
|
||||
- invoice.paid
|
||||
- invoice.payment_failed
|
||||
```
|
||||
|
||||
### Lokales Testing
|
||||
|
||||
```bash
|
||||
# Stripe CLI installieren
|
||||
brew install stripe/stripe-cli/stripe
|
||||
|
||||
# Webhook forwarding
|
||||
stripe listen --forward-to localhost:8083/api/v1/billing/webhook
|
||||
|
||||
# Test Events triggern
|
||||
stripe trigger checkout.session.completed
|
||||
stripe trigger invoice.paid
|
||||
```
|
||||
|
||||
## Umgebungsvariablen
|
||||
|
||||
| Variable | Beschreibung | Beispiel |
|
||||
|----------|--------------|----------|
|
||||
| `DATABASE_URL` | PostgreSQL Connection String | `postgres://...` |
|
||||
| `JWT_SECRET` | JWT Signing Secret | `your-secret` |
|
||||
| `STRIPE_SECRET_KEY` | Stripe Secret Key | `sk_test_...` |
|
||||
| `STRIPE_WEBHOOK_SECRET` | Webhook Signing Secret | `whsec_...` |
|
||||
| `BILLING_SUCCESS_URL` | Checkout Success Redirect | `http://...` |
|
||||
| `BILLING_CANCEL_URL` | Checkout Cancel Redirect | `http://...` |
|
||||
| `INTERNAL_API_KEY` | Service-to-Service Auth | `internal-key` |
|
||||
| `TRIAL_PERIOD_DAYS` | Trial Dauer in Tagen | `7` |
|
||||
| `PORT` | Server Port | `8083` |
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Task Limit Reached
|
||||
|
||||
```json
|
||||
{
|
||||
"error": "TASK_LIMIT_REACHED",
|
||||
"message": "Dein Aufgaben-Kontingent ist aufgebraucht.",
|
||||
"current_balance": 0,
|
||||
"plan": "basic"
|
||||
}
|
||||
```
|
||||
|
||||
HTTP Status: `402 Payment Required`
|
||||
|
||||
### No Subscription
|
||||
|
||||
```json
|
||||
{
|
||||
"error": "NO_SUBSCRIPTION",
|
||||
"message": "Kein aktives Abonnement gefunden."
|
||||
}
|
||||
```
|
||||
|
||||
HTTP Status: `403 Forbidden`
|
||||
|
||||
## Frontend Integration
|
||||
|
||||
### Task Usage anzeigen
|
||||
|
||||
```typescript
|
||||
// Response von GET /api/v1/billing/status
|
||||
interface TaskUsageInfo {
|
||||
tasks_available: number; // z.B. 45
|
||||
max_tasks: number; // z.B. 150
|
||||
info_text: string; // "Aufgaben verfuegbar: 45 von max. 150"
|
||||
tooltip_text: string; // "Aufgaben koennen sich bis zu 5 Monate ansammeln."
|
||||
}
|
||||
```
|
||||
|
||||
### Task konsumieren
|
||||
|
||||
```typescript
|
||||
// Vor jeder KI-Aktion
|
||||
const response = await fetch('/api/v1/billing/tasks/check/' + userId);
|
||||
const { allowed, message } = await response.json();
|
||||
|
||||
if (!allowed) {
|
||||
showUpgradeDialog(message);
|
||||
return;
|
||||
}
|
||||
|
||||
// Nach erfolgreicher KI-Aktion
|
||||
await fetch('/api/v1/billing/tasks/consume', {
|
||||
method: 'POST',
|
||||
body: JSON.stringify({ user_id: userId, task_type: 'correction' })
|
||||
});
|
||||
```
|
||||
@@ -0,0 +1,143 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/breakpilot/billing-service/internal/config"
|
||||
"github.com/breakpilot/billing-service/internal/database"
|
||||
"github.com/breakpilot/billing-service/internal/handlers"
|
||||
"github.com/breakpilot/billing-service/internal/middleware"
|
||||
"github.com/breakpilot/billing-service/internal/services"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Load configuration
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Initialize database
|
||||
db, err := database.Connect(cfg.DatabaseURL)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to connect to database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Run migrations
|
||||
if err := database.Migrate(db); err != nil {
|
||||
log.Fatalf("Failed to run migrations: %v", err)
|
||||
}
|
||||
|
||||
// Setup Gin router
|
||||
if cfg.Environment == "production" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
|
||||
router := gin.Default()
|
||||
|
||||
// Global middleware
|
||||
router.Use(middleware.CORS())
|
||||
router.Use(middleware.RequestLogger())
|
||||
router.Use(middleware.RateLimiter())
|
||||
|
||||
// Health check (no auth required)
|
||||
router.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(200, gin.H{
|
||||
"status": "healthy",
|
||||
"service": "billing-service",
|
||||
"version": "1.0.0",
|
||||
})
|
||||
})
|
||||
|
||||
// Initialize services
|
||||
subscriptionService := services.NewSubscriptionService(db)
|
||||
|
||||
// Create Stripe service (mock or real depending on config)
|
||||
var stripeService *services.StripeService
|
||||
if cfg.IsMockMode() {
|
||||
log.Println("Starting in MOCK MODE - Stripe API calls will be simulated")
|
||||
stripeService = services.NewMockStripeService(
|
||||
cfg.BillingSuccessURL,
|
||||
cfg.BillingCancelURL,
|
||||
cfg.TrialPeriodDays,
|
||||
subscriptionService,
|
||||
)
|
||||
} else {
|
||||
stripeService = services.NewStripeService(
|
||||
cfg.StripeSecretKey,
|
||||
cfg.StripeWebhookSecret,
|
||||
cfg.BillingSuccessURL,
|
||||
cfg.BillingCancelURL,
|
||||
cfg.TrialPeriodDays,
|
||||
subscriptionService,
|
||||
)
|
||||
}
|
||||
|
||||
entitlementService := services.NewEntitlementService(db, subscriptionService)
|
||||
usageService := services.NewUsageService(db, entitlementService)
|
||||
|
||||
// Initialize handlers
|
||||
billingHandler := handlers.NewBillingHandler(
|
||||
db,
|
||||
subscriptionService,
|
||||
stripeService,
|
||||
entitlementService,
|
||||
usageService,
|
||||
)
|
||||
webhookHandler := handlers.NewWebhookHandler(
|
||||
db,
|
||||
cfg.StripeWebhookSecret,
|
||||
subscriptionService,
|
||||
entitlementService,
|
||||
)
|
||||
|
||||
// API v1 routes
|
||||
v1 := router.Group("/api/v1/billing")
|
||||
{
|
||||
// Stripe webhook (no auth - uses Stripe signature)
|
||||
v1.POST("/webhook", webhookHandler.HandleStripeWebhook)
|
||||
|
||||
// =============================================
|
||||
// User Endpoints (require JWT auth)
|
||||
// =============================================
|
||||
user := v1.Group("")
|
||||
user.Use(middleware.AuthMiddleware(cfg.JWTSecret))
|
||||
{
|
||||
// Subscription status and management
|
||||
user.GET("/status", billingHandler.GetBillingStatus)
|
||||
user.GET("/plans", billingHandler.GetPlans)
|
||||
user.POST("/trial/start", billingHandler.StartTrial)
|
||||
user.POST("/change-plan", billingHandler.ChangePlan)
|
||||
user.POST("/cancel", billingHandler.CancelSubscription)
|
||||
user.GET("/portal", billingHandler.GetCustomerPortal)
|
||||
}
|
||||
|
||||
// =============================================
|
||||
// Internal Endpoints (service-to-service)
|
||||
// =============================================
|
||||
internal := v1.Group("")
|
||||
internal.Use(middleware.InternalAPIKeyMiddleware(cfg.InternalAPIKey))
|
||||
{
|
||||
// Entitlements
|
||||
internal.GET("/entitlements/:userId", billingHandler.GetEntitlements)
|
||||
internal.GET("/entitlements/check/:userId/:feature", billingHandler.CheckEntitlement)
|
||||
|
||||
// Usage tracking
|
||||
internal.POST("/usage/track", billingHandler.TrackUsage)
|
||||
internal.GET("/usage/check/:userId/:type", billingHandler.CheckUsage)
|
||||
}
|
||||
}
|
||||
|
||||
// Start server
|
||||
port := cfg.Port
|
||||
if port == "" {
|
||||
port = "8083"
|
||||
}
|
||||
|
||||
log.Printf("Starting Billing Service on port %s", port)
|
||||
if err := router.Run(":" + port); err != nil {
|
||||
log.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
module github.com/breakpilot/billing-service
|
||||
|
||||
go 1.23.0
|
||||
|
||||
require (
|
||||
github.com/gin-gonic/gin v1.11.0
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.7.6
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/stripe/stripe-go/v76 v76.25.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/bytedance/sonic v1.14.0 // indirect
|
||||
github.com/bytedance/sonic/loader v0.3.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
|
||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.27.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/goccy/go-yaml v1.18.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/quic-go/qpack v0.5.1 // indirect
|
||||
github.com/quic-go/quic-go v0.54.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.3.0 // indirect
|
||||
go.uber.org/mock v0.5.0 // indirect
|
||||
golang.org/x/arch v0.20.0 // indirect
|
||||
golang.org/x/crypto v0.40.0 // indirect
|
||||
golang.org/x/mod v0.25.0 // indirect
|
||||
golang.org/x/net v0.42.0 // indirect
|
||||
golang.org/x/sync v0.16.0 // indirect
|
||||
golang.org/x/sys v0.35.0 // indirect
|
||||
golang.org/x/text v0.27.0 // indirect
|
||||
golang.org/x/tools v0.34.0 // indirect
|
||||
google.golang.org/protobuf v1.36.9 // indirect
|
||||
)
|
||||
@@ -0,0 +1,111 @@
|
||||
github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ=
|
||||
github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA=
|
||||
github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA=
|
||||
github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI=
|
||||
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
||||
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
|
||||
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
|
||||
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
|
||||
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
|
||||
github.com/gin-gonic/gin v1.11.0 h1:OW/6PLjyusp2PPXtyxKHU0RbX6I/l28FTdDlae5ueWk=
|
||||
github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.27.0 h1:w8+XrWVMhGkxOaaowyKH35gFydVHOvC0/uWoy2Fzwn4=
|
||||
github.com/go-playground/validator/v10 v10.27.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
|
||||
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk=
|
||||
github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
|
||||
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
|
||||
github.com/quic-go/quic-go v0.54.0 h1:6s1YB9QotYI6Ospeiguknbp2Znb/jZYjZLRXn9kMQBg=
|
||||
github.com/quic-go/quic-go v0.54.0/go.mod h1:e68ZEaCdyviluZmy44P6Iey98v/Wfz6HCjQEm+l8zTY=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/stripe/stripe-go/v76 v76.25.0 h1:kmDoOTvdQSTQssQzWZQQkgbAR2Q8eXdMWbN/ylNalWA=
|
||||
github.com/stripe/stripe-go/v76 v76.25.0/go.mod h1:rw1MxjlAKKcZ+3FOXgTHgwiOa2ya6CPq6ykpJ0Q6Po4=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.3.0 h1:Qd2W2sQawAfG8XSvzwhBeoGq71zXOC/Q1E9y/wUcsUA=
|
||||
github.com/ugorji/go/codec v1.3.0/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
|
||||
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
|
||||
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
|
||||
golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c=
|
||||
golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk=
|
||||
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
|
||||
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
|
||||
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
|
||||
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
|
||||
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
|
||||
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
||||
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
|
||||
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
|
||||
golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=
|
||||
google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw=
|
||||
google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
@@ -0,0 +1,157 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
// Config holds all configuration for the billing service
|
||||
type Config struct {
|
||||
// Server
|
||||
Port string
|
||||
Environment string
|
||||
|
||||
// Database
|
||||
DatabaseURL string
|
||||
|
||||
// JWT (shared with consent-service)
|
||||
JWTSecret string
|
||||
|
||||
// Stripe
|
||||
StripeSecretKey string
|
||||
StripeWebhookSecret string
|
||||
StripePublishableKey string
|
||||
StripeMockMode bool // If true, Stripe calls are mocked (for dev without Stripe keys)
|
||||
|
||||
// URLs
|
||||
BillingSuccessURL string
|
||||
BillingCancelURL string
|
||||
FrontendURL string
|
||||
|
||||
// Trial
|
||||
TrialPeriodDays int
|
||||
|
||||
// CORS
|
||||
AllowedOrigins []string
|
||||
|
||||
// Rate Limiting
|
||||
RateLimitRequests int
|
||||
RateLimitWindow int // in seconds
|
||||
|
||||
// Internal API Key (for service-to-service communication)
|
||||
InternalAPIKey string
|
||||
}
|
||||
|
||||
// Load loads configuration from environment variables
|
||||
func Load() (*Config, error) {
|
||||
// Load .env file if exists (for development)
|
||||
_ = godotenv.Load()
|
||||
|
||||
cfg := &Config{
|
||||
Port: getEnv("PORT", "8083"),
|
||||
Environment: getEnv("ENVIRONMENT", "development"),
|
||||
DatabaseURL: getEnv("DATABASE_URL", ""),
|
||||
JWTSecret: getEnv("JWT_SECRET", ""),
|
||||
|
||||
// Stripe
|
||||
StripeSecretKey: getEnv("STRIPE_SECRET_KEY", ""),
|
||||
StripeWebhookSecret: getEnv("STRIPE_WEBHOOK_SECRET", ""),
|
||||
StripePublishableKey: getEnv("STRIPE_PUBLISHABLE_KEY", ""),
|
||||
StripeMockMode: getEnvBool("STRIPE_MOCK_MODE", false),
|
||||
|
||||
// URLs
|
||||
BillingSuccessURL: getEnv("BILLING_SUCCESS_URL", "http://localhost:8000/app/billing/success"),
|
||||
BillingCancelURL: getEnv("BILLING_CANCEL_URL", "http://localhost:8000/app/billing/cancel"),
|
||||
FrontendURL: getEnv("FRONTEND_URL", "http://localhost:8000"),
|
||||
|
||||
// Trial
|
||||
TrialPeriodDays: getEnvInt("TRIAL_PERIOD_DAYS", 7),
|
||||
|
||||
// Rate Limiting
|
||||
RateLimitRequests: getEnvInt("RATE_LIMIT_REQUESTS", 100),
|
||||
RateLimitWindow: getEnvInt("RATE_LIMIT_WINDOW", 60),
|
||||
|
||||
// Internal API
|
||||
InternalAPIKey: getEnv("INTERNAL_API_KEY", ""),
|
||||
}
|
||||
|
||||
// Parse allowed origins
|
||||
originsStr := getEnv("ALLOWED_ORIGINS", "http://localhost:3000,http://localhost:8000")
|
||||
cfg.AllowedOrigins = parseCommaSeparated(originsStr)
|
||||
|
||||
// Validate required fields
|
||||
if cfg.DatabaseURL == "" {
|
||||
return nil, fmt.Errorf("DATABASE_URL is required")
|
||||
}
|
||||
|
||||
if cfg.JWTSecret == "" {
|
||||
return nil, fmt.Errorf("JWT_SECRET is required")
|
||||
}
|
||||
|
||||
// Stripe key is required unless mock mode is enabled
|
||||
if cfg.StripeSecretKey == "" && !cfg.StripeMockMode {
|
||||
// In development mode, auto-enable mock mode if no Stripe key
|
||||
if cfg.Environment == "development" {
|
||||
cfg.StripeMockMode = true
|
||||
} else {
|
||||
return nil, fmt.Errorf("STRIPE_SECRET_KEY is required (set STRIPE_MOCK_MODE=true to bypass in dev)")
|
||||
}
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// IsMockMode returns true if Stripe should be mocked
|
||||
func (c *Config) IsMockMode() bool {
|
||||
return c.StripeMockMode
|
||||
}
|
||||
|
||||
func getEnv(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvInt(key string, defaultValue int) int {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
var result int
|
||||
fmt.Sscanf(value, "%d", &result)
|
||||
return result
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvBool(key string, defaultValue bool) bool {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value == "true" || value == "1" || value == "yes"
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func parseCommaSeparated(s string) []string {
|
||||
if s == "" {
|
||||
return []string{}
|
||||
}
|
||||
var result []string
|
||||
start := 0
|
||||
for i := 0; i <= len(s); i++ {
|
||||
if i == len(s) || s[i] == ',' {
|
||||
item := s[start:i]
|
||||
// Trim whitespace
|
||||
for len(item) > 0 && item[0] == ' ' {
|
||||
item = item[1:]
|
||||
}
|
||||
for len(item) > 0 && item[len(item)-1] == ' ' {
|
||||
item = item[:len(item)-1]
|
||||
}
|
||||
if item != "" {
|
||||
result = append(result, item)
|
||||
}
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,260 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// DB wraps the pgx pool
|
||||
type DB struct {
|
||||
Pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
// Connect establishes a connection to the PostgreSQL database
|
||||
func Connect(databaseURL string) (*DB, error) {
|
||||
config, err := pgxpool.ParseConfig(databaseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse database URL: %w", err)
|
||||
}
|
||||
|
||||
// Configure connection pool
|
||||
config.MaxConns = 15
|
||||
config.MinConns = 3
|
||||
config.MaxConnLifetime = time.Hour
|
||||
config.MaxConnIdleTime = 30 * time.Minute
|
||||
config.HealthCheckPeriod = time.Minute
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pool, err := pgxpool.NewWithConfig(ctx, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create connection pool: %w", err)
|
||||
}
|
||||
|
||||
// Test the connection
|
||||
if err := pool.Ping(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
return &DB{Pool: pool}, nil
|
||||
}
|
||||
|
||||
// Close closes the database connection pool
|
||||
func (db *DB) Close() {
|
||||
db.Pool.Close()
|
||||
}
|
||||
|
||||
// Migrate runs database migrations for the billing service
|
||||
func Migrate(db *DB) error {
|
||||
ctx := context.Background()
|
||||
|
||||
migrations := []string{
|
||||
// =============================================
|
||||
// Billing Service Tables
|
||||
// =============================================
|
||||
|
||||
// Subscriptions - core subscription data
|
||||
`CREATE TABLE IF NOT EXISTS subscriptions (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL,
|
||||
stripe_customer_id VARCHAR(255),
|
||||
stripe_subscription_id VARCHAR(255) UNIQUE,
|
||||
plan_id VARCHAR(50) NOT NULL,
|
||||
status VARCHAR(30) NOT NULL DEFAULT 'trialing',
|
||||
trial_end TIMESTAMPTZ,
|
||||
current_period_start TIMESTAMPTZ,
|
||||
current_period_end TIMESTAMPTZ,
|
||||
cancel_at_period_end BOOLEAN DEFAULT FALSE,
|
||||
canceled_at TIMESTAMPTZ,
|
||||
ended_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
UNIQUE(user_id)
|
||||
)`,
|
||||
|
||||
// Billing Plans - cached from Stripe
|
||||
`CREATE TABLE IF NOT EXISTS billing_plans (
|
||||
id VARCHAR(50) PRIMARY KEY,
|
||||
stripe_price_id VARCHAR(255) UNIQUE,
|
||||
stripe_product_id VARCHAR(255),
|
||||
name VARCHAR(100) NOT NULL,
|
||||
description TEXT,
|
||||
price_cents INT NOT NULL,
|
||||
currency VARCHAR(3) DEFAULT 'eur',
|
||||
interval VARCHAR(10) DEFAULT 'month',
|
||||
features JSONB DEFAULT '{}',
|
||||
is_active BOOLEAN DEFAULT TRUE,
|
||||
sort_order INT DEFAULT 0,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ DEFAULT NOW()
|
||||
)`,
|
||||
|
||||
// Usage Summary - aggregated usage per period
|
||||
`CREATE TABLE IF NOT EXISTS usage_summary (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL,
|
||||
usage_type VARCHAR(50) NOT NULL,
|
||||
period_start TIMESTAMPTZ NOT NULL,
|
||||
total_count INT DEFAULT 0,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
UNIQUE(user_id, usage_type, period_start)
|
||||
)`,
|
||||
|
||||
// User Entitlements - cached entitlements for fast lookups
|
||||
`CREATE TABLE IF NOT EXISTS user_entitlements (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL UNIQUE,
|
||||
plan_id VARCHAR(50) NOT NULL,
|
||||
ai_requests_limit INT DEFAULT 0,
|
||||
ai_requests_used INT DEFAULT 0,
|
||||
documents_limit INT DEFAULT 0,
|
||||
documents_used INT DEFAULT 0,
|
||||
features JSONB DEFAULT '{}',
|
||||
period_start TIMESTAMPTZ,
|
||||
period_end TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ DEFAULT NOW()
|
||||
)`,
|
||||
|
||||
// Stripe Webhook Events - for idempotency
|
||||
`CREATE TABLE IF NOT EXISTS stripe_webhook_events (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
stripe_event_id VARCHAR(255) UNIQUE NOT NULL,
|
||||
event_type VARCHAR(100) NOT NULL,
|
||||
processed BOOLEAN DEFAULT FALSE,
|
||||
processed_at TIMESTAMPTZ,
|
||||
payload JSONB,
|
||||
error_message TEXT,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
)`,
|
||||
|
||||
// Billing Audit Log
|
||||
`CREATE TABLE IF NOT EXISTS billing_audit_log (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID,
|
||||
action VARCHAR(50) NOT NULL,
|
||||
entity_type VARCHAR(50),
|
||||
entity_id VARCHAR(255),
|
||||
old_value JSONB,
|
||||
new_value JSONB,
|
||||
metadata JSONB,
|
||||
ip_address INET,
|
||||
user_agent TEXT,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
)`,
|
||||
|
||||
// Invoices - cached from Stripe
|
||||
`CREATE TABLE IF NOT EXISTS invoices (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL,
|
||||
stripe_invoice_id VARCHAR(255) UNIQUE NOT NULL,
|
||||
stripe_subscription_id VARCHAR(255),
|
||||
status VARCHAR(30) NOT NULL,
|
||||
amount_due INT NOT NULL,
|
||||
amount_paid INT DEFAULT 0,
|
||||
currency VARCHAR(3) DEFAULT 'eur',
|
||||
hosted_invoice_url TEXT,
|
||||
invoice_pdf TEXT,
|
||||
period_start TIMESTAMPTZ,
|
||||
period_end TIMESTAMPTZ,
|
||||
due_date TIMESTAMPTZ,
|
||||
paid_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
)`,
|
||||
|
||||
// =============================================
|
||||
// Task-based Billing Tables
|
||||
// =============================================
|
||||
|
||||
// Account Usage - tracks task balance per account
|
||||
`CREATE TABLE IF NOT EXISTS account_usage (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
account_id UUID NOT NULL UNIQUE,
|
||||
plan VARCHAR(50) NOT NULL,
|
||||
monthly_task_allowance INT NOT NULL,
|
||||
carryover_months_cap INT DEFAULT 5,
|
||||
max_task_balance INT NOT NULL,
|
||||
task_balance INT NOT NULL,
|
||||
last_renewal_at TIMESTAMPTZ NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ DEFAULT NOW()
|
||||
)`,
|
||||
|
||||
// Tasks - individual task consumption records
|
||||
`CREATE TABLE IF NOT EXISTS tasks (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
account_id UUID NOT NULL,
|
||||
task_type VARCHAR(50) NOT NULL,
|
||||
consumed BOOLEAN DEFAULT TRUE,
|
||||
page_count INT DEFAULT 0,
|
||||
token_count INT DEFAULT 0,
|
||||
process_time INT DEFAULT 0,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
)`,
|
||||
|
||||
// =============================================
|
||||
// Indexes
|
||||
// =============================================
|
||||
`CREATE INDEX IF NOT EXISTS idx_subscriptions_user ON subscriptions(user_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_subscriptions_stripe_customer ON subscriptions(stripe_customer_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_subscriptions_stripe_sub ON subscriptions(stripe_subscription_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_subscriptions_status ON subscriptions(status)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_subscriptions_trial_end ON subscriptions(trial_end)`,
|
||||
|
||||
`CREATE INDEX IF NOT EXISTS idx_usage_summary_user ON usage_summary(user_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_usage_summary_period ON usage_summary(period_start)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_usage_summary_type ON usage_summary(usage_type)`,
|
||||
|
||||
`CREATE INDEX IF NOT EXISTS idx_user_entitlements_user ON user_entitlements(user_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_user_entitlements_plan ON user_entitlements(plan_id)`,
|
||||
|
||||
`CREATE INDEX IF NOT EXISTS idx_stripe_webhook_events_event_id ON stripe_webhook_events(stripe_event_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_stripe_webhook_events_type ON stripe_webhook_events(event_type)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_stripe_webhook_events_processed ON stripe_webhook_events(processed)`,
|
||||
|
||||
`CREATE INDEX IF NOT EXISTS idx_billing_audit_log_user ON billing_audit_log(user_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_billing_audit_log_action ON billing_audit_log(action)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_billing_audit_log_created ON billing_audit_log(created_at)`,
|
||||
|
||||
`CREATE INDEX IF NOT EXISTS idx_invoices_user ON invoices(user_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_invoices_stripe_invoice ON invoices(stripe_invoice_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_invoices_status ON invoices(status)`,
|
||||
|
||||
`CREATE INDEX IF NOT EXISTS idx_account_usage_account ON account_usage(account_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_account_usage_plan ON account_usage(plan)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_account_usage_renewal ON account_usage(last_renewal_at)`,
|
||||
|
||||
`CREATE INDEX IF NOT EXISTS idx_tasks_account ON tasks(account_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_tasks_type ON tasks(task_type)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_tasks_created ON tasks(created_at)`,
|
||||
|
||||
// =============================================
|
||||
// Insert default plans
|
||||
// =============================================
|
||||
`INSERT INTO billing_plans (id, name, description, price_cents, currency, interval, features, sort_order)
|
||||
VALUES
|
||||
('basic', 'Basic', 'Perfekt für den Einstieg', 990, 'eur', 'month',
|
||||
'{"ai_requests_limit": 300, "documents_limit": 50, "feature_flags": ["basic_ai", "basic_documents"], "max_team_members": 1, "priority_support": false, "custom_branding": false}',
|
||||
1),
|
||||
('standard', 'Standard', 'Für regelmäßige Nutzer', 1990, 'eur', 'month',
|
||||
'{"ai_requests_limit": 1500, "documents_limit": 200, "feature_flags": ["basic_ai", "basic_documents", "templates", "batch_processing"], "max_team_members": 3, "priority_support": false, "custom_branding": false}',
|
||||
2),
|
||||
('premium', 'Premium', 'Für Teams und Power-User', 3990, 'eur', 'month',
|
||||
'{"ai_requests_limit": 5000, "documents_limit": 1000, "feature_flags": ["basic_ai", "basic_documents", "templates", "batch_processing", "team_features", "admin_panel", "audit_log", "api_access"], "max_team_members": 10, "priority_support": true, "custom_branding": true}',
|
||||
3)
|
||||
ON CONFLICT (id) DO NOTHING`,
|
||||
}
|
||||
|
||||
for _, migration := range migrations {
|
||||
if _, err := db.Pool.Exec(ctx, migration); err != nil {
|
||||
return fmt.Errorf("failed to run migration: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,427 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/breakpilot/billing-service/internal/database"
|
||||
"github.com/breakpilot/billing-service/internal/middleware"
|
||||
"github.com/breakpilot/billing-service/internal/models"
|
||||
"github.com/breakpilot/billing-service/internal/services"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// BillingHandler handles billing-related HTTP requests
|
||||
type BillingHandler struct {
|
||||
db *database.DB
|
||||
subscriptionService *services.SubscriptionService
|
||||
stripeService *services.StripeService
|
||||
entitlementService *services.EntitlementService
|
||||
usageService *services.UsageService
|
||||
}
|
||||
|
||||
// NewBillingHandler creates a new BillingHandler
|
||||
func NewBillingHandler(
|
||||
db *database.DB,
|
||||
subscriptionService *services.SubscriptionService,
|
||||
stripeService *services.StripeService,
|
||||
entitlementService *services.EntitlementService,
|
||||
usageService *services.UsageService,
|
||||
) *BillingHandler {
|
||||
return &BillingHandler{
|
||||
db: db,
|
||||
subscriptionService: subscriptionService,
|
||||
stripeService: stripeService,
|
||||
entitlementService: entitlementService,
|
||||
usageService: usageService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetBillingStatus returns the current billing status for a user
|
||||
// GET /api/v1/billing/status
|
||||
func (h *BillingHandler) GetBillingStatus(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID.String() == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "unauthorized",
|
||||
"message": "User not authenticated",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Get subscription
|
||||
subscription, err := h.subscriptionService.GetByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "internal_error",
|
||||
"message": "Failed to get subscription",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get available plans
|
||||
plans, err := h.subscriptionService.GetAvailablePlans(ctx)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "internal_error",
|
||||
"message": "Failed to get plans",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response := models.BillingStatusResponse{
|
||||
HasSubscription: subscription != nil,
|
||||
AvailablePlans: plans,
|
||||
}
|
||||
|
||||
if subscription != nil {
|
||||
// Get plan details
|
||||
plan, _ := h.subscriptionService.GetPlanByID(ctx, string(subscription.PlanID))
|
||||
|
||||
subInfo := &models.SubscriptionInfo{
|
||||
PlanID: subscription.PlanID,
|
||||
Status: subscription.Status,
|
||||
IsTrialing: subscription.Status == models.StatusTrialing,
|
||||
CancelAtPeriodEnd: subscription.CancelAtPeriodEnd,
|
||||
CurrentPeriodEnd: subscription.CurrentPeriodEnd,
|
||||
}
|
||||
|
||||
if plan != nil {
|
||||
subInfo.PlanName = plan.Name
|
||||
subInfo.PriceCents = plan.PriceCents
|
||||
subInfo.Currency = plan.Currency
|
||||
}
|
||||
|
||||
// Calculate trial days left
|
||||
if subscription.TrialEnd != nil && subscription.Status == models.StatusTrialing {
|
||||
// TODO: Calculate days left
|
||||
}
|
||||
|
||||
response.Subscription = subInfo
|
||||
|
||||
// Get task usage info (legacy usage tracking - see TaskService for new task-based usage)
|
||||
// TODO: Replace with TaskService.GetTaskUsageInfo for task-based billing
|
||||
_, _ = h.usageService.GetUsageSummary(ctx, userID)
|
||||
|
||||
// Get entitlements
|
||||
entitlements, _ := h.entitlementService.GetEntitlements(ctx, userID)
|
||||
if entitlements != nil {
|
||||
response.Entitlements = entitlements
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// GetPlans returns all available billing plans
|
||||
// GET /api/v1/billing/plans
|
||||
func (h *BillingHandler) GetPlans(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
plans, err := h.subscriptionService.GetAvailablePlans(ctx)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "internal_error",
|
||||
"message": "Failed to get plans",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"plans": plans,
|
||||
})
|
||||
}
|
||||
|
||||
// StartTrial starts a trial for the user with a specific plan
|
||||
// POST /api/v1/billing/trial/start
|
||||
func (h *BillingHandler) StartTrial(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID.String() == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "unauthorized",
|
||||
"message": "User not authenticated",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req models.StartTrialRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_request",
|
||||
"message": "Invalid request body",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Check if user already has a subscription
|
||||
existing, _ := h.subscriptionService.GetByUserID(ctx, userID)
|
||||
if existing != nil {
|
||||
c.JSON(http.StatusConflict, gin.H{
|
||||
"error": "subscription_exists",
|
||||
"message": "User already has a subscription",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user email from context
|
||||
email, _ := c.Get("email")
|
||||
emailStr, _ := email.(string)
|
||||
|
||||
// Create Stripe checkout session
|
||||
checkoutURL, sessionID, err := h.stripeService.CreateCheckoutSession(ctx, userID, emailStr, req.PlanID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "stripe_error",
|
||||
"message": "Failed to create checkout session",
|
||||
"details": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.StartTrialResponse{
|
||||
CheckoutURL: checkoutURL,
|
||||
SessionID: sessionID,
|
||||
})
|
||||
}
|
||||
|
||||
// ChangePlan changes the user's subscription plan
|
||||
// POST /api/v1/billing/change-plan
|
||||
func (h *BillingHandler) ChangePlan(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID.String() == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "unauthorized",
|
||||
"message": "User not authenticated",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req models.ChangePlanRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_request",
|
||||
"message": "Invalid request body",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Get current subscription
|
||||
subscription, err := h.subscriptionService.GetByUserID(ctx, userID)
|
||||
if err != nil || subscription == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "no_subscription",
|
||||
"message": "No active subscription found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Change plan via Stripe
|
||||
err = h.stripeService.ChangePlan(ctx, subscription.StripeSubscriptionID, req.NewPlanID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "stripe_error",
|
||||
"message": "Failed to change plan",
|
||||
"details": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.ChangePlanResponse{
|
||||
Success: true,
|
||||
Message: "Plan changed successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// CancelSubscription cancels the user's subscription at period end
|
||||
// POST /api/v1/billing/cancel
|
||||
func (h *BillingHandler) CancelSubscription(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID.String() == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "unauthorized",
|
||||
"message": "User not authenticated",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Get current subscription
|
||||
subscription, err := h.subscriptionService.GetByUserID(ctx, userID)
|
||||
if err != nil || subscription == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "no_subscription",
|
||||
"message": "No active subscription found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Cancel at period end via Stripe
|
||||
err = h.stripeService.CancelSubscription(ctx, subscription.StripeSubscriptionID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "stripe_error",
|
||||
"message": "Failed to cancel subscription",
|
||||
"details": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.CancelSubscriptionResponse{
|
||||
Success: true,
|
||||
Message: "Subscription will be canceled at the end of the billing period",
|
||||
})
|
||||
}
|
||||
|
||||
// GetCustomerPortal returns a URL to the Stripe customer portal
|
||||
// GET /api/v1/billing/portal
|
||||
func (h *BillingHandler) GetCustomerPortal(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID.String() == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "unauthorized",
|
||||
"message": "User not authenticated",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Get current subscription
|
||||
subscription, err := h.subscriptionService.GetByUserID(ctx, userID)
|
||||
if err != nil || subscription == nil || subscription.StripeCustomerID == "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "no_subscription",
|
||||
"message": "No active subscription found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Create portal session
|
||||
portalURL, err := h.stripeService.CreateCustomerPortalSession(ctx, subscription.StripeCustomerID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "stripe_error",
|
||||
"message": "Failed to create portal session",
|
||||
"details": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.CustomerPortalResponse{
|
||||
PortalURL: portalURL,
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================
|
||||
// Internal Endpoints (Service-to-Service)
|
||||
// =============================================
|
||||
|
||||
// GetEntitlements returns entitlements for a user (internal)
|
||||
// GET /api/v1/billing/entitlements/:userId
|
||||
func (h *BillingHandler) GetEntitlements(c *gin.Context) {
|
||||
userIDStr := c.Param("userId")
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
entitlements, err := h.entitlementService.GetEntitlementsByUserIDString(ctx, userIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "internal_error",
|
||||
"message": "Failed to get entitlements",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if entitlements == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "not_found",
|
||||
"message": "No entitlements found for user",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, entitlements)
|
||||
}
|
||||
|
||||
// TrackUsage tracks usage for a user (internal)
|
||||
// POST /api/v1/billing/usage/track
|
||||
func (h *BillingHandler) TrackUsage(c *gin.Context) {
|
||||
var req models.TrackUsageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_request",
|
||||
"message": "Invalid request body",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
quantity := req.Quantity
|
||||
if quantity <= 0 {
|
||||
quantity = 1
|
||||
}
|
||||
|
||||
err := h.usageService.TrackUsage(ctx, req.UserID, req.UsageType, quantity)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "internal_error",
|
||||
"message": "Failed to track usage",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Usage tracked",
|
||||
})
|
||||
}
|
||||
|
||||
// CheckUsage checks if usage is allowed (internal)
|
||||
// GET /api/v1/billing/usage/check/:userId/:type
|
||||
func (h *BillingHandler) CheckUsage(c *gin.Context) {
|
||||
userIDStr := c.Param("userId")
|
||||
usageType := c.Param("type")
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
response, err := h.usageService.CheckUsageAllowed(ctx, userIDStr, usageType)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "internal_error",
|
||||
"message": "Failed to check usage",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// CheckEntitlement checks if a user has a specific entitlement (internal)
|
||||
// GET /api/v1/billing/entitlements/check/:userId/:feature
|
||||
func (h *BillingHandler) CheckEntitlement(c *gin.Context) {
|
||||
userIDStr := c.Param("userId")
|
||||
feature := c.Param("feature")
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
hasEntitlement, planID, err := h.entitlementService.CheckEntitlement(ctx, userIDStr, feature)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "internal_error",
|
||||
"message": "Failed to check entitlement",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.EntitlementCheckResponse{
|
||||
HasEntitlement: hasEntitlement,
|
||||
PlanID: planID,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,612 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/breakpilot/billing-service/internal/models"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Set Gin to test mode
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
func TestGetPlans_ResponseFormat(t *testing.T) {
|
||||
// Test that GetPlans returns the expected response structure
|
||||
// Since we don't have a real database connection in unit tests,
|
||||
// we test the expected structure and format
|
||||
|
||||
// Test that default plans are well-formed
|
||||
plans := models.GetDefaultPlans()
|
||||
|
||||
if len(plans) == 0 {
|
||||
t.Error("Default plans should not be empty")
|
||||
}
|
||||
|
||||
for _, plan := range plans {
|
||||
// Verify JSON serialization works
|
||||
data, err := json.Marshal(plan)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to marshal plan %s: %v", plan.ID, err)
|
||||
}
|
||||
|
||||
// Verify we can unmarshal back
|
||||
var decoded models.BillingPlan
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to unmarshal plan %s: %v", plan.ID, err)
|
||||
}
|
||||
|
||||
// Verify key fields
|
||||
if decoded.ID != plan.ID {
|
||||
t.Errorf("Plan ID mismatch: got %s, expected %s", decoded.ID, plan.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingStatusResponse_Structure(t *testing.T) {
|
||||
// Test the response structure
|
||||
response := models.BillingStatusResponse{
|
||||
HasSubscription: true,
|
||||
Subscription: &models.SubscriptionInfo{
|
||||
PlanID: models.PlanStandard,
|
||||
PlanName: "Standard",
|
||||
Status: models.StatusActive,
|
||||
IsTrialing: false,
|
||||
CancelAtPeriodEnd: false,
|
||||
PriceCents: 1990,
|
||||
Currency: "eur",
|
||||
},
|
||||
TaskUsage: &models.TaskUsageInfo{
|
||||
TasksAvailable: 85,
|
||||
MaxTasks: 500,
|
||||
InfoText: "Aufgaben verfuegbar: 85 von max. 500",
|
||||
TooltipText: "Aufgaben koennen sich bis zu 5 Monate ansammeln.",
|
||||
},
|
||||
Entitlements: &models.EntitlementInfo{
|
||||
Features: []string{"basic_ai", "basic_documents", "templates", "batch_processing"},
|
||||
MaxTeamMembers: 3,
|
||||
PrioritySupport: false,
|
||||
CustomBranding: false,
|
||||
BatchProcessing: true,
|
||||
CustomTemplates: true,
|
||||
FairUseMode: false,
|
||||
},
|
||||
AvailablePlans: models.GetDefaultPlans(),
|
||||
}
|
||||
|
||||
// Test JSON serialization
|
||||
data, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal BillingStatusResponse: %v", err)
|
||||
}
|
||||
|
||||
// Verify it's valid JSON
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Response is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
// Check required fields exist
|
||||
if _, ok := decoded["has_subscription"]; !ok {
|
||||
t.Error("Response should have 'has_subscription' field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartTrialRequest_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request models.StartTrialRequest
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid basic plan",
|
||||
request: models.StartTrialRequest{PlanID: models.PlanBasic},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "Valid standard plan",
|
||||
request: models.StartTrialRequest{PlanID: models.PlanStandard},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "Valid premium plan",
|
||||
request: models.StartTrialRequest{PlanID: models.PlanPremium},
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test JSON serialization
|
||||
data, err := json.Marshal(tt.request)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request: %v", err)
|
||||
}
|
||||
|
||||
var decoded models.StartTrialRequest
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal request: %v", err)
|
||||
}
|
||||
|
||||
if decoded.PlanID != tt.request.PlanID {
|
||||
t.Errorf("PlanID mismatch: got %s, expected %s", decoded.PlanID, tt.request.PlanID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangePlanRequest_Structure(t *testing.T) {
|
||||
request := models.ChangePlanRequest{
|
||||
NewPlanID: models.PlanPremium,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal ChangePlanRequest: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Response is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := decoded["new_plan_id"]; !ok {
|
||||
t.Error("Request should have 'new_plan_id' field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartTrialResponse_Structure(t *testing.T) {
|
||||
response := models.StartTrialResponse{
|
||||
CheckoutURL: "https://checkout.stripe.com/c/pay/cs_test_123",
|
||||
SessionID: "cs_test_123",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal StartTrialResponse: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Response is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := decoded["checkout_url"]; !ok {
|
||||
t.Error("Response should have 'checkout_url' field")
|
||||
}
|
||||
if _, ok := decoded["session_id"]; !ok {
|
||||
t.Error("Response should have 'session_id' field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCancelSubscriptionResponse_Structure(t *testing.T) {
|
||||
response := models.CancelSubscriptionResponse{
|
||||
Success: true,
|
||||
Message: "Subscription will be canceled at the end of the billing period",
|
||||
CancelDate: "2025-01-16",
|
||||
ActiveUntil: "2025-01-16",
|
||||
}
|
||||
|
||||
_, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal CancelSubscriptionResponse: %v", err)
|
||||
}
|
||||
|
||||
if !response.Success {
|
||||
t.Error("Success should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomerPortalResponse_Structure(t *testing.T) {
|
||||
response := models.CustomerPortalResponse{
|
||||
PortalURL: "https://billing.stripe.com/p/session/test_123",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal CustomerPortalResponse: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Response is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := decoded["portal_url"]; !ok {
|
||||
t.Error("Response should have 'portal_url' field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEntitlementCheckResponse_Structure(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response models.EntitlementCheckResponse
|
||||
}{
|
||||
{
|
||||
name: "Has entitlement",
|
||||
response: models.EntitlementCheckResponse{
|
||||
HasEntitlement: true,
|
||||
PlanID: models.PlanStandard,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "No entitlement",
|
||||
response: models.EntitlementCheckResponse{
|
||||
HasEntitlement: false,
|
||||
PlanID: models.PlanBasic,
|
||||
Message: "Feature not available in this plan",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal EntitlementCheckResponse: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Response is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := decoded["has_entitlement"]; !ok {
|
||||
t.Error("Response should have 'has_entitlement' field")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrackUsageRequest_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request models.TrackUsageRequest
|
||||
valid bool
|
||||
}{
|
||||
{
|
||||
name: "Valid AI request",
|
||||
request: models.TrackUsageRequest{
|
||||
UserID: "550e8400-e29b-41d4-a716-446655440000",
|
||||
UsageType: "ai_request",
|
||||
Quantity: 1,
|
||||
},
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "Valid document created",
|
||||
request: models.TrackUsageRequest{
|
||||
UserID: "550e8400-e29b-41d4-a716-446655440000",
|
||||
UsageType: "document_created",
|
||||
Quantity: 1,
|
||||
},
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple quantity",
|
||||
request: models.TrackUsageRequest{
|
||||
UserID: "550e8400-e29b-41d4-a716-446655440000",
|
||||
UsageType: "ai_request",
|
||||
Quantity: 5,
|
||||
},
|
||||
valid: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.request)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal TrackUsageRequest: %v", err)
|
||||
}
|
||||
|
||||
var decoded models.TrackUsageRequest
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal TrackUsageRequest: %v", err)
|
||||
}
|
||||
|
||||
if decoded.UserID != tt.request.UserID {
|
||||
t.Errorf("UserID mismatch: got %s, expected %s", decoded.UserID, tt.request.UserID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckUsageResponse_Format(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response models.CheckUsageResponse
|
||||
}{
|
||||
{
|
||||
name: "Allowed response",
|
||||
response: models.CheckUsageResponse{
|
||||
Allowed: true,
|
||||
CurrentUsage: 450,
|
||||
Limit: 1500,
|
||||
Remaining: 1050,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Limit reached",
|
||||
response: models.CheckUsageResponse{
|
||||
Allowed: false,
|
||||
CurrentUsage: 1500,
|
||||
Limit: 1500,
|
||||
Remaining: 0,
|
||||
Message: "Usage limit reached for ai_request (1500/1500)",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal CheckUsageResponse: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Response is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := decoded["allowed"]; !ok {
|
||||
t.Error("Response should have 'allowed' field")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsumeTaskRequest_Format(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request models.ConsumeTaskRequest
|
||||
}{
|
||||
{
|
||||
name: "Correction task",
|
||||
request: models.ConsumeTaskRequest{
|
||||
UserID: "550e8400-e29b-41d4-a716-446655440000",
|
||||
TaskType: models.TaskTypeCorrection,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Letter task",
|
||||
request: models.ConsumeTaskRequest{
|
||||
UserID: "550e8400-e29b-41d4-a716-446655440000",
|
||||
TaskType: models.TaskTypeLetter,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Batch task",
|
||||
request: models.ConsumeTaskRequest{
|
||||
UserID: "550e8400-e29b-41d4-a716-446655440000",
|
||||
TaskType: models.TaskTypeBatch,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.request)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal ConsumeTaskRequest: %v", err)
|
||||
}
|
||||
|
||||
var decoded models.ConsumeTaskRequest
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal ConsumeTaskRequest: %v", err)
|
||||
}
|
||||
|
||||
if decoded.TaskType != tt.request.TaskType {
|
||||
t.Errorf("TaskType mismatch: got %s, expected %s", decoded.TaskType, tt.request.TaskType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsumeTaskResponse_Format(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response models.ConsumeTaskResponse
|
||||
}{
|
||||
{
|
||||
name: "Successful consumption",
|
||||
response: models.ConsumeTaskResponse{
|
||||
Success: true,
|
||||
TaskID: "task-uuid-123",
|
||||
TasksRemaining: 49,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Limit reached",
|
||||
response: models.ConsumeTaskResponse{
|
||||
Success: false,
|
||||
TasksRemaining: 0,
|
||||
Message: "Dein Aufgaben-Kontingent ist aufgebraucht.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal ConsumeTaskResponse: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Response is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := decoded["success"]; !ok {
|
||||
t.Error("Response should have 'success' field")
|
||||
}
|
||||
if _, ok := decoded["tasks_remaining"]; !ok {
|
||||
t.Error("Response should have 'tasks_remaining' field")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckTaskAllowedResponse_Format(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response models.CheckTaskAllowedResponse
|
||||
}{
|
||||
{
|
||||
name: "Task allowed",
|
||||
response: models.CheckTaskAllowedResponse{
|
||||
Allowed: true,
|
||||
TasksAvailable: 50,
|
||||
MaxTasks: 150,
|
||||
PlanID: models.PlanBasic,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Task not allowed",
|
||||
response: models.CheckTaskAllowedResponse{
|
||||
Allowed: false,
|
||||
TasksAvailable: 0,
|
||||
MaxTasks: 150,
|
||||
PlanID: models.PlanBasic,
|
||||
Message: "Dein Aufgaben-Kontingent ist aufgebraucht.",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Premium Fair Use",
|
||||
response: models.CheckTaskAllowedResponse{
|
||||
Allowed: true,
|
||||
TasksAvailable: 1000,
|
||||
MaxTasks: 5000,
|
||||
PlanID: models.PlanPremium,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(tt.response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal CheckTaskAllowedResponse: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Response is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := decoded["allowed"]; !ok {
|
||||
t.Error("Response should have 'allowed' field")
|
||||
}
|
||||
if _, ok := decoded["tasks_available"]; !ok {
|
||||
t.Error("Response should have 'tasks_available' field")
|
||||
}
|
||||
if _, ok := decoded["plan_id"]; !ok {
|
||||
t.Error("Response should have 'plan_id' field")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// HTTP Handler Tests (without DB)
|
||||
|
||||
func TestHTTPErrorResponse_Format(t *testing.T) {
|
||||
// Test standard error response format
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Simulate an error response
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "unauthorized",
|
||||
"message": "User not authenticated",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status 401, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := response["error"]; !ok {
|
||||
t.Error("Error response should have 'error' field")
|
||||
}
|
||||
if _, ok := response["message"]; !ok {
|
||||
t.Error("Error response should have 'message' field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPSuccessResponse_Format(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Simulate a success response
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Operation completed",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if response["success"] != true {
|
||||
t.Error("Success response should have success=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestParsing_InvalidJSON(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Create request with invalid JSON
|
||||
invalidJSON := []byte(`{"plan_id": }`) // Invalid JSON
|
||||
c.Request = httptest.NewRequest("POST", "/test", bytes.NewReader(invalidJSON))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
var req models.StartTrialRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Should return error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPHeaders_ContentType(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"test": "value"})
|
||||
|
||||
contentType := w.Header().Get("Content-Type")
|
||||
if contentType != "application/json; charset=utf-8" {
|
||||
t.Errorf("Expected JSON content type, got %s", contentType)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,205 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/breakpilot/billing-service/internal/database"
|
||||
"github.com/breakpilot/billing-service/internal/services"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stripe/stripe-go/v76/webhook"
|
||||
)
|
||||
|
||||
// WebhookHandler handles Stripe webhook events
|
||||
type WebhookHandler struct {
|
||||
db *database.DB
|
||||
webhookSecret string
|
||||
subscriptionService *services.SubscriptionService
|
||||
entitlementService *services.EntitlementService
|
||||
}
|
||||
|
||||
// NewWebhookHandler creates a new WebhookHandler
|
||||
func NewWebhookHandler(
|
||||
db *database.DB,
|
||||
webhookSecret string,
|
||||
subscriptionService *services.SubscriptionService,
|
||||
entitlementService *services.EntitlementService,
|
||||
) *WebhookHandler {
|
||||
return &WebhookHandler{
|
||||
db: db,
|
||||
webhookSecret: webhookSecret,
|
||||
subscriptionService: subscriptionService,
|
||||
entitlementService: entitlementService,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleStripeWebhook handles incoming Stripe webhook events
|
||||
// POST /api/v1/billing/webhook
|
||||
func (h *WebhookHandler) HandleStripeWebhook(c *gin.Context) {
|
||||
// Read the request body
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
log.Printf("Webhook: Error reading body: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "cannot read body"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get the Stripe signature header
|
||||
sigHeader := c.GetHeader("Stripe-Signature")
|
||||
if sigHeader == "" {
|
||||
log.Printf("Webhook: Missing Stripe-Signature header")
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing signature"})
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the webhook signature
|
||||
event, err := webhook.ConstructEvent(body, sigHeader, h.webhookSecret)
|
||||
if err != nil {
|
||||
log.Printf("Webhook: Signature verification failed: %v", err)
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid signature"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Check if we've already processed this event (idempotency)
|
||||
processed, err := h.subscriptionService.IsEventProcessed(ctx, event.ID)
|
||||
if err != nil {
|
||||
log.Printf("Webhook: Error checking event: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal error"})
|
||||
return
|
||||
}
|
||||
if processed {
|
||||
log.Printf("Webhook: Event %s already processed", event.ID)
|
||||
c.JSON(http.StatusOK, gin.H{"status": "already_processed"})
|
||||
return
|
||||
}
|
||||
|
||||
// Mark event as being processed
|
||||
if err := h.subscriptionService.MarkEventProcessing(ctx, event.ID, string(event.Type)); err != nil {
|
||||
log.Printf("Webhook: Error marking event: %v", err)
|
||||
}
|
||||
|
||||
// Handle the event based on type
|
||||
var handleErr error
|
||||
switch event.Type {
|
||||
case "checkout.session.completed":
|
||||
handleErr = h.handleCheckoutSessionCompleted(ctx, event.Data.Raw)
|
||||
|
||||
case "customer.subscription.created":
|
||||
handleErr = h.handleSubscriptionCreated(ctx, event.Data.Raw)
|
||||
|
||||
case "customer.subscription.updated":
|
||||
handleErr = h.handleSubscriptionUpdated(ctx, event.Data.Raw)
|
||||
|
||||
case "customer.subscription.deleted":
|
||||
handleErr = h.handleSubscriptionDeleted(ctx, event.Data.Raw)
|
||||
|
||||
case "invoice.paid":
|
||||
handleErr = h.handleInvoicePaid(ctx, event.Data.Raw)
|
||||
|
||||
case "invoice.payment_failed":
|
||||
handleErr = h.handleInvoicePaymentFailed(ctx, event.Data.Raw)
|
||||
|
||||
case "customer.created":
|
||||
log.Printf("Webhook: Customer created - %s", event.ID)
|
||||
|
||||
default:
|
||||
log.Printf("Webhook: Unhandled event type: %s", event.Type)
|
||||
}
|
||||
|
||||
if handleErr != nil {
|
||||
log.Printf("Webhook: Error handling %s: %v", event.Type, handleErr)
|
||||
// Mark event as failed
|
||||
h.subscriptionService.MarkEventFailed(ctx, event.ID, handleErr.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler error"})
|
||||
return
|
||||
}
|
||||
|
||||
// Mark event as processed
|
||||
if err := h.subscriptionService.MarkEventProcessed(ctx, event.ID); err != nil {
|
||||
log.Printf("Webhook: Error marking event processed: %v", err)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "processed"})
|
||||
}
|
||||
|
||||
// handleCheckoutSessionCompleted handles successful checkout
|
||||
func (h *WebhookHandler) handleCheckoutSessionCompleted(ctx interface{}, data []byte) error {
|
||||
log.Printf("Webhook: Processing checkout.session.completed")
|
||||
|
||||
// Parse checkout session from data
|
||||
// The actual implementation will parse the JSON and create/update subscription
|
||||
|
||||
// TODO: Implementation
|
||||
// 1. Parse checkout session data
|
||||
// 2. Extract customer_id, subscription_id, user_id (from metadata)
|
||||
// 3. Create or update subscription record
|
||||
// 4. Update entitlements
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleSubscriptionCreated handles new subscription creation
|
||||
func (h *WebhookHandler) handleSubscriptionCreated(ctx interface{}, data []byte) error {
|
||||
log.Printf("Webhook: Processing customer.subscription.created")
|
||||
|
||||
// TODO: Implementation
|
||||
// 1. Parse subscription data
|
||||
// 2. Extract status, plan, trial_end, etc.
|
||||
// 3. Create subscription record
|
||||
// 4. Set up initial entitlements
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleSubscriptionUpdated handles subscription updates
|
||||
func (h *WebhookHandler) handleSubscriptionUpdated(ctx interface{}, data []byte) error {
|
||||
log.Printf("Webhook: Processing customer.subscription.updated")
|
||||
|
||||
// TODO: Implementation
|
||||
// 1. Parse subscription data
|
||||
// 2. Update subscription record (status, plan, cancel_at_period_end, etc.)
|
||||
// 3. Update entitlements if plan changed
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleSubscriptionDeleted handles subscription cancellation
|
||||
func (h *WebhookHandler) handleSubscriptionDeleted(ctx interface{}, data []byte) error {
|
||||
log.Printf("Webhook: Processing customer.subscription.deleted")
|
||||
|
||||
// TODO: Implementation
|
||||
// 1. Parse subscription data
|
||||
// 2. Update subscription status to canceled/expired
|
||||
// 3. Remove or downgrade entitlements
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleInvoicePaid handles successful invoice payment
|
||||
func (h *WebhookHandler) handleInvoicePaid(ctx interface{}, data []byte) error {
|
||||
log.Printf("Webhook: Processing invoice.paid")
|
||||
|
||||
// TODO: Implementation
|
||||
// 1. Parse invoice data
|
||||
// 2. Update subscription period
|
||||
// 3. Reset usage counters for new period
|
||||
// 4. Store invoice record
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleInvoicePaymentFailed handles failed invoice payment
|
||||
func (h *WebhookHandler) handleInvoicePaymentFailed(ctx interface{}, data []byte) error {
|
||||
log.Printf("Webhook: Processing invoice.payment_failed")
|
||||
|
||||
// TODO: Implementation
|
||||
// 1. Parse invoice data
|
||||
// 2. Update subscription status to past_due
|
||||
// 3. Send notification to user
|
||||
// 4. Possibly restrict access
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,433 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// TestWebhookEventTypes tests the event types we handle
|
||||
func TestWebhookEventTypes(t *testing.T) {
|
||||
eventTypes := []struct {
|
||||
eventType string
|
||||
shouldHandle bool
|
||||
}{
|
||||
{"checkout.session.completed", true},
|
||||
{"customer.subscription.created", true},
|
||||
{"customer.subscription.updated", true},
|
||||
{"customer.subscription.deleted", true},
|
||||
{"invoice.paid", true},
|
||||
{"invoice.payment_failed", true},
|
||||
{"customer.created", true}, // Handled but just logged
|
||||
{"unknown.event.type", false},
|
||||
}
|
||||
|
||||
for _, tt := range eventTypes {
|
||||
t.Run(tt.eventType, func(t *testing.T) {
|
||||
if tt.eventType == "" {
|
||||
t.Error("Event type should not be empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhookRequest_MissingSignature tests handling of missing signature
|
||||
func TestWebhookRequest_MissingSignature(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Create request without Stripe-Signature header
|
||||
body := []byte(`{"id": "evt_test_123", "type": "test.event"}`)
|
||||
c.Request = httptest.NewRequest("POST", "/webhook", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
// Note: No Stripe-Signature header
|
||||
|
||||
// Simulate the check we do in the handler
|
||||
sigHeader := c.GetHeader("Stripe-Signature")
|
||||
if sigHeader == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing signature"})
|
||||
}
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status 400 for missing signature, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if response["error"] != "missing signature" {
|
||||
t.Errorf("Expected 'missing signature' error, got '%v'", response["error"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhookRequest_EmptyBody tests handling of empty request body
|
||||
func TestWebhookRequest_EmptyBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Create request with empty body
|
||||
c.Request = httptest.NewRequest("POST", "/webhook", bytes.NewReader([]byte{}))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Request.Header.Set("Stripe-Signature", "t=123,v1=signature")
|
||||
|
||||
// Read the body
|
||||
body := make([]byte, 0)
|
||||
|
||||
// Simulate empty body handling
|
||||
if len(body) == 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "empty body"})
|
||||
}
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status 400 for empty body, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhookIdempotency tests idempotency behavior
|
||||
func TestWebhookIdempotency(t *testing.T) {
|
||||
// Test that the same event ID should not be processed twice
|
||||
eventID := "evt_test_123456789"
|
||||
|
||||
// Simulate event tracking
|
||||
processedEvents := make(map[string]bool)
|
||||
|
||||
// First time - should process
|
||||
if !processedEvents[eventID] {
|
||||
processedEvents[eventID] = true
|
||||
}
|
||||
|
||||
// Second time - should skip
|
||||
alreadyProcessed := processedEvents[eventID]
|
||||
if !alreadyProcessed {
|
||||
t.Error("Event should be marked as processed")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhookResponse_Processed tests successful webhook response
|
||||
func TestWebhookResponse_Processed(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "processed"})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if response["status"] != "processed" {
|
||||
t.Errorf("Expected status 'processed', got '%v'", response["status"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhookResponse_AlreadyProcessed tests idempotent response
|
||||
func TestWebhookResponse_AlreadyProcessed(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "already_processed"})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if response["status"] != "already_processed" {
|
||||
t.Errorf("Expected status 'already_processed', got '%v'", response["status"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhookResponse_InternalError tests error response
|
||||
func TestWebhookResponse_InternalError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler error"})
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status 500, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if response["error"] != "handler error" {
|
||||
t.Errorf("Expected 'handler error', got '%v'", response["error"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhookResponse_InvalidSignature tests signature verification failure
|
||||
func TestWebhookResponse_InvalidSignature(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid signature"})
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status 401, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if response["error"] != "invalid signature" {
|
||||
t.Errorf("Expected 'invalid signature', got '%v'", response["error"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckoutSessionCompleted_EventStructure tests the event data structure
|
||||
func TestCheckoutSessionCompleted_EventStructure(t *testing.T) {
|
||||
// Test the expected structure of a checkout.session.completed event
|
||||
eventData := map[string]interface{}{
|
||||
"id": "cs_test_123",
|
||||
"customer": "cus_test_456",
|
||||
"subscription": "sub_test_789",
|
||||
"mode": "subscription",
|
||||
"payment_status": "paid",
|
||||
"status": "complete",
|
||||
"metadata": map[string]interface{}{
|
||||
"user_id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"plan_id": "standard",
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(eventData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal event data: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal event data: %v", err)
|
||||
}
|
||||
|
||||
// Verify required fields
|
||||
if decoded["customer"] == nil {
|
||||
t.Error("Event should have 'customer' field")
|
||||
}
|
||||
if decoded["subscription"] == nil {
|
||||
t.Error("Event should have 'subscription' field")
|
||||
}
|
||||
metadata, ok := decoded["metadata"].(map[string]interface{})
|
||||
if !ok || metadata["user_id"] == nil {
|
||||
t.Error("Event should have 'metadata.user_id' field")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubscriptionCreated_EventStructure tests subscription.created event structure
|
||||
func TestSubscriptionCreated_EventStructure(t *testing.T) {
|
||||
eventData := map[string]interface{}{
|
||||
"id": "sub_test_123",
|
||||
"customer": "cus_test_456",
|
||||
"status": "trialing",
|
||||
"items": map[string]interface{}{
|
||||
"data": []map[string]interface{}{
|
||||
{
|
||||
"price": map[string]interface{}{
|
||||
"id": "price_test_789",
|
||||
"metadata": map[string]interface{}{"plan_id": "standard"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"trial_end": 1735689600,
|
||||
"current_period_end": 1735689600,
|
||||
"metadata": map[string]interface{}{
|
||||
"user_id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"plan_id": "standard",
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(eventData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal event data: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal event data: %v", err)
|
||||
}
|
||||
|
||||
// Verify required fields
|
||||
if decoded["status"] != "trialing" {
|
||||
t.Errorf("Expected status 'trialing', got '%v'", decoded["status"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubscriptionUpdated_StatusTransitions tests subscription status transitions
|
||||
func TestSubscriptionUpdated_StatusTransitions(t *testing.T) {
|
||||
validTransitions := []struct {
|
||||
from string
|
||||
to string
|
||||
}{
|
||||
{"trialing", "active"},
|
||||
{"active", "past_due"},
|
||||
{"past_due", "active"},
|
||||
{"active", "canceled"},
|
||||
{"trialing", "canceled"},
|
||||
}
|
||||
|
||||
for _, tt := range validTransitions {
|
||||
t.Run(tt.from+"->"+tt.to, func(t *testing.T) {
|
||||
if tt.from == "" || tt.to == "" {
|
||||
t.Error("Status should not be empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestInvoicePaid_EventStructure tests invoice.paid event structure
|
||||
func TestInvoicePaid_EventStructure(t *testing.T) {
|
||||
eventData := map[string]interface{}{
|
||||
"id": "in_test_123",
|
||||
"subscription": "sub_test_456",
|
||||
"customer": "cus_test_789",
|
||||
"status": "paid",
|
||||
"amount_paid": 1990,
|
||||
"currency": "eur",
|
||||
"period_start": 1735689600,
|
||||
"period_end": 1738368000,
|
||||
"hosted_invoice_url": "https://invoice.stripe.com/test",
|
||||
"invoice_pdf": "https://invoice.stripe.com/test.pdf",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(eventData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal event data: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal event data: %v", err)
|
||||
}
|
||||
|
||||
// Verify required fields
|
||||
if decoded["status"] != "paid" {
|
||||
t.Errorf("Expected status 'paid', got '%v'", decoded["status"])
|
||||
}
|
||||
if decoded["subscription"] == nil {
|
||||
t.Error("Event should have 'subscription' field")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInvoicePaymentFailed_EventStructure tests invoice.payment_failed event structure
|
||||
func TestInvoicePaymentFailed_EventStructure(t *testing.T) {
|
||||
eventData := map[string]interface{}{
|
||||
"id": "in_test_123",
|
||||
"subscription": "sub_test_456",
|
||||
"customer": "cus_test_789",
|
||||
"status": "open",
|
||||
"attempt_count": 1,
|
||||
"next_payment_attempt": 1735776000,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(eventData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal event data: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal event data: %v", err)
|
||||
}
|
||||
|
||||
// Verify fields
|
||||
if decoded["attempt_count"] == nil {
|
||||
t.Error("Event should have 'attempt_count' field")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubscriptionDeleted_EventStructure tests subscription.deleted event structure
|
||||
func TestSubscriptionDeleted_EventStructure(t *testing.T) {
|
||||
eventData := map[string]interface{}{
|
||||
"id": "sub_test_123",
|
||||
"customer": "cus_test_456",
|
||||
"status": "canceled",
|
||||
"ended_at": 1735689600,
|
||||
"canceled_at": 1735689600,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(eventData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal event data: %v", err)
|
||||
}
|
||||
|
||||
var decoded map[string]interface{}
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal event data: %v", err)
|
||||
}
|
||||
|
||||
// Verify required fields
|
||||
if decoded["status"] != "canceled" {
|
||||
t.Errorf("Expected status 'canceled', got '%v'", decoded["status"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestStripeSignatureFormat tests the Stripe signature header format
|
||||
func TestStripeSignatureFormat(t *testing.T) {
|
||||
// Stripe signature format: t=timestamp,v1=signature
|
||||
validSignatures := []string{
|
||||
"t=1609459200,v1=abc123def456",
|
||||
"t=1609459200,v1=signature_here,v0=old_signature",
|
||||
}
|
||||
|
||||
for _, sig := range validSignatures {
|
||||
if len(sig) < 10 {
|
||||
t.Errorf("Signature seems too short: %s", sig)
|
||||
}
|
||||
// Should start with timestamp
|
||||
if sig[:2] != "t=" {
|
||||
t.Errorf("Signature should start with 't=': %s", sig)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhookEventID_Format tests Stripe event ID format
|
||||
func TestWebhookEventID_Format(t *testing.T) {
|
||||
validEventIDs := []string{
|
||||
"evt_1234567890abcdef",
|
||||
"evt_test_123456789",
|
||||
"evt_live_987654321",
|
||||
}
|
||||
|
||||
for _, eventID := range validEventIDs {
|
||||
// Event IDs should start with "evt_"
|
||||
if len(eventID) < 10 || eventID[:4] != "evt_" {
|
||||
t.Errorf("Invalid event ID format: %s", eventID)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,288 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// UserClaims represents the JWT claims for a user
|
||||
type UserClaims struct {
|
||||
UserID string `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// CORS returns a CORS middleware
|
||||
func CORS() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
|
||||
// Allow localhost for development
|
||||
allowedOrigins := []string{
|
||||
"http://localhost:3000",
|
||||
"http://localhost:8000",
|
||||
"http://localhost:8080",
|
||||
"http://localhost:8083",
|
||||
"https://breakpilot.app",
|
||||
}
|
||||
|
||||
allowed := false
|
||||
for _, o := range allowedOrigins {
|
||||
if origin == o {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if allowed {
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Authorization, X-Requested-With, X-Internal-API-Key")
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
c.Header("Access-Control-Max-Age", "86400")
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequestLogger logs each request
|
||||
func RequestLogger() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
method := c.Request.Method
|
||||
|
||||
c.Next()
|
||||
|
||||
latency := time.Since(start)
|
||||
status := c.Writer.Status()
|
||||
|
||||
// Log only in development or for errors
|
||||
if status >= 400 {
|
||||
gin.DefaultWriter.Write([]byte(
|
||||
method + " " + path + " " +
|
||||
string(rune(status)) + " " +
|
||||
latency.String() + "\n",
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimiter implements a simple in-memory rate limiter
|
||||
func RateLimiter() gin.HandlerFunc {
|
||||
type client struct {
|
||||
count int
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
clients = make(map[string]*client)
|
||||
)
|
||||
|
||||
// Clean up old entries periodically
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(time.Minute)
|
||||
mu.Lock()
|
||||
for ip, c := range clients {
|
||||
if time.Since(c.lastSeen) > time.Minute {
|
||||
delete(clients, ip)
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if _, exists := clients[ip]; !exists {
|
||||
clients[ip] = &client{}
|
||||
}
|
||||
|
||||
cli := clients[ip]
|
||||
|
||||
// Reset count if more than a minute has passed
|
||||
if time.Since(cli.lastSeen) > time.Minute {
|
||||
cli.count = 0
|
||||
}
|
||||
|
||||
cli.count++
|
||||
cli.lastSeen = time.Now()
|
||||
|
||||
// Allow 100 requests per minute
|
||||
if cli.count > 100 {
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "rate_limit_exceeded",
|
||||
"message": "Too many requests. Please try again later.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// AuthMiddleware validates JWT tokens
|
||||
func AuthMiddleware(jwtSecret string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "missing_authorization",
|
||||
"message": "Authorization header is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Extract token from "Bearer <token>"
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "invalid_authorization",
|
||||
"message": "Authorization header must be in format: Bearer <token>",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := parts[1]
|
||||
|
||||
// Parse and validate token
|
||||
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(jwtSecret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "invalid_token",
|
||||
"message": "Invalid or expired token",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(*UserClaims); ok && token.Valid {
|
||||
// Set user info in context
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("email", claims.Email)
|
||||
c.Set("role", claims.Role)
|
||||
c.Next()
|
||||
} else {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "invalid_claims",
|
||||
"message": "Invalid token claims",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// InternalAPIKeyMiddleware validates internal API key for service-to-service communication
|
||||
func InternalAPIKeyMiddleware(apiKey string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if apiKey == "" {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "config_error",
|
||||
"message": "Internal API key not configured",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
providedKey := c.GetHeader("X-Internal-API-Key")
|
||||
if providedKey == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "missing_api_key",
|
||||
"message": "X-Internal-API-Key header is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if providedKey != apiKey {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "invalid_api_key",
|
||||
"message": "Invalid API key",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// AdminOnly ensures only admin users can access the route
|
||||
func AdminOnly() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
role, exists := c.Get("role")
|
||||
if !exists {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "unauthorized",
|
||||
"message": "User role not found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
roleStr, ok := role.(string)
|
||||
if !ok || (roleStr != "admin" && roleStr != "super_admin" && roleStr != "data_protection_officer") {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
||||
"error": "forbidden",
|
||||
"message": "Admin access required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserID extracts the user ID from the context
|
||||
func GetUserID(c *gin.Context) (uuid.UUID, error) {
|
||||
userIDStr, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
return uuid.Nil, nil
|
||||
}
|
||||
|
||||
userID, err := uuid.Parse(userIDStr.(string))
|
||||
if err != nil {
|
||||
return uuid.Nil, err
|
||||
}
|
||||
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
// GetClientIP returns the client's IP address
|
||||
func GetClientIP(c *gin.Context) string {
|
||||
// Check X-Forwarded-For header first (for proxied requests)
|
||||
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
|
||||
ips := strings.Split(xff, ",")
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
|
||||
// Check X-Real-IP header
|
||||
if xri := c.GetHeader("X-Real-IP"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
return c.ClientIP()
|
||||
}
|
||||
|
||||
// GetUserAgent returns the client's User-Agent
|
||||
func GetUserAgent(c *gin.Context) string {
|
||||
return c.GetHeader("User-Agent")
|
||||
}
|
||||
@@ -0,0 +1,372 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// SubscriptionStatus represents the status of a subscription
|
||||
type SubscriptionStatus string
|
||||
|
||||
const (
|
||||
StatusTrialing SubscriptionStatus = "trialing"
|
||||
StatusActive SubscriptionStatus = "active"
|
||||
StatusPastDue SubscriptionStatus = "past_due"
|
||||
StatusCanceled SubscriptionStatus = "canceled"
|
||||
StatusExpired SubscriptionStatus = "expired"
|
||||
)
|
||||
|
||||
// PlanID represents the available plan IDs
|
||||
type PlanID string
|
||||
|
||||
const (
|
||||
PlanBasic PlanID = "basic"
|
||||
PlanStandard PlanID = "standard"
|
||||
PlanPremium PlanID = "premium"
|
||||
)
|
||||
|
||||
// TaskType represents the type of task
|
||||
type TaskType string
|
||||
|
||||
const (
|
||||
TaskTypeCorrection TaskType = "correction"
|
||||
TaskTypeLetter TaskType = "letter"
|
||||
TaskTypeMeeting TaskType = "meeting"
|
||||
TaskTypeBatch TaskType = "batch"
|
||||
TaskTypeOther TaskType = "other"
|
||||
)
|
||||
|
||||
// CarryoverMonthsCap is the maximum number of months tasks can accumulate
|
||||
const CarryoverMonthsCap = 5
|
||||
|
||||
// Subscription represents a user's subscription
|
||||
type Subscription struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
StripeCustomerID string `json:"stripe_customer_id"`
|
||||
StripeSubscriptionID string `json:"stripe_subscription_id"`
|
||||
PlanID PlanID `json:"plan_id"`
|
||||
Status SubscriptionStatus `json:"status"`
|
||||
TrialEnd *time.Time `json:"trial_end,omitempty"`
|
||||
CurrentPeriodEnd *time.Time `json:"current_period_end,omitempty"`
|
||||
CancelAtPeriodEnd bool `json:"cancel_at_period_end"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// BillingPlan represents a billing plan with its features and limits
|
||||
type BillingPlan struct {
|
||||
ID PlanID `json:"id"`
|
||||
StripePriceID string `json:"stripe_price_id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
PriceCents int `json:"price_cents"` // Price in cents (990 = 9.90 EUR)
|
||||
Currency string `json:"currency"`
|
||||
Interval string `json:"interval"` // "month" or "year"
|
||||
Features PlanFeatures `json:"features"`
|
||||
IsActive bool `json:"is_active"`
|
||||
SortOrder int `json:"sort_order"`
|
||||
}
|
||||
|
||||
// PlanFeatures represents the features and limits of a plan
|
||||
type PlanFeatures struct {
|
||||
// Task-based limits (primary billing unit)
|
||||
MonthlyTaskAllowance int `json:"monthly_task_allowance"` // Tasks per month
|
||||
MaxTaskBalance int `json:"max_task_balance"` // Max accumulated tasks (allowance * CarryoverMonthsCap)
|
||||
|
||||
// Legacy fields for backward compatibility (deprecated, use task-based limits)
|
||||
AIRequestsLimit int `json:"ai_requests_limit,omitempty"`
|
||||
DocumentsLimit int `json:"documents_limit,omitempty"`
|
||||
|
||||
// Feature flags
|
||||
FeatureFlags []string `json:"feature_flags"`
|
||||
MaxTeamMembers int `json:"max_team_members,omitempty"`
|
||||
PrioritySupport bool `json:"priority_support"`
|
||||
CustomBranding bool `json:"custom_branding"`
|
||||
BatchProcessing bool `json:"batch_processing"`
|
||||
CustomTemplates bool `json:"custom_templates"`
|
||||
|
||||
// Premium: Fair Use (no visible limit)
|
||||
FairUseMode bool `json:"fair_use_mode"`
|
||||
}
|
||||
|
||||
// Task represents a single task that consumes 1 unit from the balance
|
||||
type Task struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
AccountID uuid.UUID `json:"account_id"`
|
||||
TaskType TaskType `json:"task_type"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Consumed bool `json:"consumed"` // Always true when created
|
||||
// Internal metrics (not shown to user)
|
||||
PageCount int `json:"-"`
|
||||
TokenCount int `json:"-"`
|
||||
ProcessTime int `json:"-"` // in seconds
|
||||
}
|
||||
|
||||
// AccountUsage represents the task-based usage for an account
|
||||
type AccountUsage struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
AccountID uuid.UUID `json:"account_id"`
|
||||
PlanID PlanID `json:"plan"`
|
||||
MonthlyTaskAllowance int `json:"monthly_task_allowance"`
|
||||
CarryoverMonthsCap int `json:"carryover_months_cap"` // Always 5
|
||||
MaxTaskBalance int `json:"max_task_balance"` // allowance * cap
|
||||
TaskBalance int `json:"task_balance"` // Current available tasks
|
||||
LastRenewalAt time.Time `json:"last_renewal_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// UsageSummary tracks usage for a specific period (internal metrics)
|
||||
type UsageSummary struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
UsageType string `json:"usage_type"` // "task", "page", "token"
|
||||
PeriodStart time.Time `json:"period_start"`
|
||||
TotalCount int `json:"total_count"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// UserEntitlements represents cached entitlements for a user
|
||||
type UserEntitlements struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
PlanID PlanID `json:"plan_id"`
|
||||
TaskBalance int `json:"task_balance"`
|
||||
MaxBalance int `json:"max_balance"`
|
||||
Features PlanFeatures `json:"features"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
// Legacy fields for backward compatibility with old entitlement service
|
||||
AIRequestsLimit int `json:"ai_requests_limit"`
|
||||
AIRequestsUsed int `json:"ai_requests_used"`
|
||||
DocumentsLimit int `json:"documents_limit"`
|
||||
DocumentsUsed int `json:"documents_used"`
|
||||
}
|
||||
|
||||
// StripeWebhookEvent tracks processed webhook events for idempotency
|
||||
type StripeWebhookEvent struct {
|
||||
StripeEventID string `json:"stripe_event_id"`
|
||||
EventType string `json:"event_type"`
|
||||
Processed bool `json:"processed"`
|
||||
ProcessedAt time.Time `json:"processed_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// BillingStatusResponse is the response for the billing status endpoint
|
||||
type BillingStatusResponse struct {
|
||||
HasSubscription bool `json:"has_subscription"`
|
||||
Subscription *SubscriptionInfo `json:"subscription,omitempty"`
|
||||
TaskUsage *TaskUsageInfo `json:"task_usage,omitempty"`
|
||||
Entitlements *EntitlementInfo `json:"entitlements,omitempty"`
|
||||
AvailablePlans []BillingPlan `json:"available_plans,omitempty"`
|
||||
}
|
||||
|
||||
// SubscriptionInfo contains subscription details for the response
|
||||
type SubscriptionInfo struct {
|
||||
PlanID PlanID `json:"plan_id"`
|
||||
PlanName string `json:"plan_name"`
|
||||
Status SubscriptionStatus `json:"status"`
|
||||
IsTrialing bool `json:"is_trialing"`
|
||||
TrialDaysLeft int `json:"trial_days_left,omitempty"`
|
||||
CurrentPeriodEnd *time.Time `json:"current_period_end,omitempty"`
|
||||
CancelAtPeriodEnd bool `json:"cancel_at_period_end"`
|
||||
PriceCents int `json:"price_cents"`
|
||||
Currency string `json:"currency"`
|
||||
}
|
||||
|
||||
// TaskUsageInfo contains current task usage information
|
||||
// This is the ONLY usage info shown to users
|
||||
type TaskUsageInfo struct {
|
||||
TasksAvailable int `json:"tasks_available"` // Current balance
|
||||
MaxTasks int `json:"max_tasks"` // Max possible balance
|
||||
InfoText string `json:"info_text"` // "Aufgaben verfuegbar: X von max. Y"
|
||||
TooltipText string `json:"tooltip_text"` // "Aufgaben koennen sich bis zu 5 Monate ansammeln."
|
||||
}
|
||||
|
||||
// EntitlementInfo contains feature entitlements
|
||||
type EntitlementInfo struct {
|
||||
Features []string `json:"features"`
|
||||
MaxTeamMembers int `json:"max_team_members,omitempty"`
|
||||
PrioritySupport bool `json:"priority_support"`
|
||||
CustomBranding bool `json:"custom_branding"`
|
||||
BatchProcessing bool `json:"batch_processing"`
|
||||
CustomTemplates bool `json:"custom_templates"`
|
||||
FairUseMode bool `json:"fair_use_mode"` // Premium only
|
||||
}
|
||||
|
||||
// StartTrialRequest is the request to start a trial
|
||||
type StartTrialRequest struct {
|
||||
PlanID PlanID `json:"plan_id" binding:"required"`
|
||||
}
|
||||
|
||||
// StartTrialResponse is the response after starting a trial
|
||||
type StartTrialResponse struct {
|
||||
CheckoutURL string `json:"checkout_url"`
|
||||
SessionID string `json:"session_id"`
|
||||
}
|
||||
|
||||
// ChangePlanRequest is the request to change plans
|
||||
type ChangePlanRequest struct {
|
||||
NewPlanID PlanID `json:"new_plan_id" binding:"required"`
|
||||
}
|
||||
|
||||
// ChangePlanResponse is the response after changing plans
|
||||
type ChangePlanResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
EffectiveDate string `json:"effective_date,omitempty"`
|
||||
}
|
||||
|
||||
// CancelSubscriptionResponse is the response after canceling
|
||||
type CancelSubscriptionResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
CancelDate string `json:"cancel_date"`
|
||||
ActiveUntil string `json:"active_until"`
|
||||
}
|
||||
|
||||
// CustomerPortalResponse contains the portal URL
|
||||
type CustomerPortalResponse struct {
|
||||
PortalURL string `json:"portal_url"`
|
||||
}
|
||||
|
||||
// ConsumeTaskRequest is the request to consume a task (internal)
|
||||
type ConsumeTaskRequest struct {
|
||||
UserID string `json:"user_id" binding:"required"`
|
||||
TaskType TaskType `json:"task_type" binding:"required"`
|
||||
}
|
||||
|
||||
// ConsumeTaskResponse is the response after consuming a task
|
||||
type ConsumeTaskResponse struct {
|
||||
Success bool `json:"success"`
|
||||
TaskID string `json:"task_id,omitempty"`
|
||||
TasksRemaining int `json:"tasks_remaining"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// CheckTaskAllowedResponse is the response for task limit checks
|
||||
type CheckTaskAllowedResponse struct {
|
||||
Allowed bool `json:"allowed"`
|
||||
TasksAvailable int `json:"tasks_available"`
|
||||
MaxTasks int `json:"max_tasks"`
|
||||
PlanID PlanID `json:"plan_id"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// EntitlementCheckResponse is the response for entitlement checks (internal)
|
||||
type EntitlementCheckResponse struct {
|
||||
HasEntitlement bool `json:"has_entitlement"`
|
||||
PlanID PlanID `json:"plan_id,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// TaskLimitError represents the error when task limit is reached
|
||||
type TaskLimitError struct {
|
||||
Error string `json:"error"`
|
||||
CurrentBalance int `json:"current_balance"`
|
||||
Plan PlanID `json:"plan"`
|
||||
}
|
||||
|
||||
// UsageInfo represents current usage information (legacy, prefer TaskUsageInfo)
|
||||
type UsageInfo struct {
|
||||
AIRequestsUsed int `json:"ai_requests_used"`
|
||||
AIRequestsLimit int `json:"ai_requests_limit"`
|
||||
AIRequestsPercent float64 `json:"ai_requests_percent"`
|
||||
DocumentsUsed int `json:"documents_used"`
|
||||
DocumentsLimit int `json:"documents_limit"`
|
||||
DocumentsPercent float64 `json:"documents_percent"`
|
||||
PeriodStart string `json:"period_start"`
|
||||
PeriodEnd string `json:"period_end"`
|
||||
}
|
||||
|
||||
// CheckUsageResponse is the response for legacy usage checks
|
||||
type CheckUsageResponse struct {
|
||||
Allowed bool `json:"allowed"`
|
||||
CurrentUsage int `json:"current_usage"`
|
||||
Limit int `json:"limit"`
|
||||
Remaining int `json:"remaining"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// TrackUsageRequest is the request to track usage (internal)
|
||||
type TrackUsageRequest struct {
|
||||
UserID string `json:"user_id" binding:"required"`
|
||||
UsageType string `json:"usage_type" binding:"required"`
|
||||
Quantity int `json:"quantity"`
|
||||
}
|
||||
|
||||
// GetDefaultPlans returns the default billing plans with task-based limits
|
||||
func GetDefaultPlans() []BillingPlan {
|
||||
return []BillingPlan{
|
||||
{
|
||||
ID: PlanBasic,
|
||||
Name: "Basic",
|
||||
Description: "Perfekt fuer den Einstieg - Gelegentliche Nutzung",
|
||||
PriceCents: 990, // 9.90 EUR
|
||||
Currency: "eur",
|
||||
Interval: "month",
|
||||
Features: PlanFeatures{
|
||||
MonthlyTaskAllowance: 30, // 30 tasks/month
|
||||
MaxTaskBalance: 30 * CarryoverMonthsCap, // 150 max
|
||||
FeatureFlags: []string{"basic_ai", "basic_documents"},
|
||||
MaxTeamMembers: 1,
|
||||
PrioritySupport: false,
|
||||
CustomBranding: false,
|
||||
BatchProcessing: false,
|
||||
CustomTemplates: false,
|
||||
FairUseMode: false,
|
||||
},
|
||||
IsActive: true,
|
||||
SortOrder: 1,
|
||||
},
|
||||
{
|
||||
ID: PlanStandard,
|
||||
Name: "Standard",
|
||||
Description: "Fuer regelmaessige Nutzer - Mehrere Klassen und regelmaessige Korrekturen",
|
||||
PriceCents: 1990, // 19.90 EUR
|
||||
Currency: "eur",
|
||||
Interval: "month",
|
||||
Features: PlanFeatures{
|
||||
MonthlyTaskAllowance: 100, // 100 tasks/month
|
||||
MaxTaskBalance: 100 * CarryoverMonthsCap, // 500 max
|
||||
FeatureFlags: []string{"basic_ai", "basic_documents", "templates", "batch_processing"},
|
||||
MaxTeamMembers: 3,
|
||||
PrioritySupport: false,
|
||||
CustomBranding: false,
|
||||
BatchProcessing: true,
|
||||
CustomTemplates: true,
|
||||
FairUseMode: false,
|
||||
},
|
||||
IsActive: true,
|
||||
SortOrder: 2,
|
||||
},
|
||||
{
|
||||
ID: PlanPremium,
|
||||
Name: "Premium",
|
||||
Description: "Sorglos-Tarif - Vielnutzer, Teams, schulischer Kontext",
|
||||
PriceCents: 3990, // 39.90 EUR
|
||||
Currency: "eur",
|
||||
Interval: "month",
|
||||
Features: PlanFeatures{
|
||||
MonthlyTaskAllowance: 1000, // Very high (Fair Use)
|
||||
MaxTaskBalance: 1000 * CarryoverMonthsCap, // 5000 max (not shown to user)
|
||||
FeatureFlags: []string{"basic_ai", "basic_documents", "templates", "batch_processing", "team_features", "admin_panel", "audit_log", "api_access"},
|
||||
MaxTeamMembers: 10,
|
||||
PrioritySupport: true,
|
||||
CustomBranding: true,
|
||||
BatchProcessing: true,
|
||||
CustomTemplates: true,
|
||||
FairUseMode: true, // No visible limit
|
||||
},
|
||||
IsActive: true,
|
||||
SortOrder: 3,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// CalculateMaxTaskBalance calculates max task balance from monthly allowance
|
||||
func CalculateMaxTaskBalance(monthlyAllowance int) int {
|
||||
return monthlyAllowance * CarryoverMonthsCap
|
||||
}
|
||||
@@ -0,0 +1,319 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCarryoverMonthsCap(t *testing.T) {
|
||||
// Verify the constant is set correctly
|
||||
if CarryoverMonthsCap != 5 {
|
||||
t.Errorf("CarryoverMonthsCap should be 5, got %d", CarryoverMonthsCap)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateMaxTaskBalance(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
monthlyAllowance int
|
||||
expected int
|
||||
}{
|
||||
{"Basic plan", 30, 150},
|
||||
{"Standard plan", 100, 500},
|
||||
{"Premium plan", 1000, 5000},
|
||||
{"Zero allowance", 0, 0},
|
||||
{"Single task", 1, 5},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := CalculateMaxTaskBalance(tt.monthlyAllowance)
|
||||
if result != tt.expected {
|
||||
t.Errorf("CalculateMaxTaskBalance(%d) = %d, expected %d",
|
||||
tt.monthlyAllowance, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDefaultPlans(t *testing.T) {
|
||||
plans := GetDefaultPlans()
|
||||
|
||||
if len(plans) != 3 {
|
||||
t.Fatalf("Expected 3 plans, got %d", len(plans))
|
||||
}
|
||||
|
||||
// Test Basic plan
|
||||
basic := plans[0]
|
||||
if basic.ID != PlanBasic {
|
||||
t.Errorf("First plan should be Basic, got %s", basic.ID)
|
||||
}
|
||||
if basic.PriceCents != 990 {
|
||||
t.Errorf("Basic price should be 990 cents, got %d", basic.PriceCents)
|
||||
}
|
||||
if basic.Features.MonthlyTaskAllowance != 30 {
|
||||
t.Errorf("Basic monthly allowance should be 30, got %d", basic.Features.MonthlyTaskAllowance)
|
||||
}
|
||||
if basic.Features.MaxTaskBalance != 150 {
|
||||
t.Errorf("Basic max balance should be 150, got %d", basic.Features.MaxTaskBalance)
|
||||
}
|
||||
if basic.Features.FairUseMode {
|
||||
t.Error("Basic should not have FairUseMode")
|
||||
}
|
||||
|
||||
// Test Standard plan
|
||||
standard := plans[1]
|
||||
if standard.ID != PlanStandard {
|
||||
t.Errorf("Second plan should be Standard, got %s", standard.ID)
|
||||
}
|
||||
if standard.PriceCents != 1990 {
|
||||
t.Errorf("Standard price should be 1990 cents, got %d", standard.PriceCents)
|
||||
}
|
||||
if standard.Features.MonthlyTaskAllowance != 100 {
|
||||
t.Errorf("Standard monthly allowance should be 100, got %d", standard.Features.MonthlyTaskAllowance)
|
||||
}
|
||||
if !standard.Features.BatchProcessing {
|
||||
t.Error("Standard should have BatchProcessing")
|
||||
}
|
||||
if !standard.Features.CustomTemplates {
|
||||
t.Error("Standard should have CustomTemplates")
|
||||
}
|
||||
|
||||
// Test Premium plan
|
||||
premium := plans[2]
|
||||
if premium.ID != PlanPremium {
|
||||
t.Errorf("Third plan should be Premium, got %s", premium.ID)
|
||||
}
|
||||
if premium.PriceCents != 3990 {
|
||||
t.Errorf("Premium price should be 3990 cents, got %d", premium.PriceCents)
|
||||
}
|
||||
if !premium.Features.FairUseMode {
|
||||
t.Error("Premium should have FairUseMode")
|
||||
}
|
||||
if !premium.Features.PrioritySupport {
|
||||
t.Error("Premium should have PrioritySupport")
|
||||
}
|
||||
if !premium.Features.CustomBranding {
|
||||
t.Error("Premium should have CustomBranding")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlanIDConstants(t *testing.T) {
|
||||
if PlanBasic != "basic" {
|
||||
t.Errorf("PlanBasic should be 'basic', got '%s'", PlanBasic)
|
||||
}
|
||||
if PlanStandard != "standard" {
|
||||
t.Errorf("PlanStandard should be 'standard', got '%s'", PlanStandard)
|
||||
}
|
||||
if PlanPremium != "premium" {
|
||||
t.Errorf("PlanPremium should be 'premium', got '%s'", PlanPremium)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionStatusConstants(t *testing.T) {
|
||||
statuses := []struct {
|
||||
status SubscriptionStatus
|
||||
expected string
|
||||
}{
|
||||
{StatusTrialing, "trialing"},
|
||||
{StatusActive, "active"},
|
||||
{StatusPastDue, "past_due"},
|
||||
{StatusCanceled, "canceled"},
|
||||
{StatusExpired, "expired"},
|
||||
}
|
||||
|
||||
for _, tt := range statuses {
|
||||
if string(tt.status) != tt.expected {
|
||||
t.Errorf("Status %s should be '%s'", tt.status, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskTypeConstants(t *testing.T) {
|
||||
types := []struct {
|
||||
taskType TaskType
|
||||
expected string
|
||||
}{
|
||||
{TaskTypeCorrection, "correction"},
|
||||
{TaskTypeLetter, "letter"},
|
||||
{TaskTypeMeeting, "meeting"},
|
||||
{TaskTypeBatch, "batch"},
|
||||
{TaskTypeOther, "other"},
|
||||
}
|
||||
|
||||
for _, tt := range types {
|
||||
if string(tt.taskType) != tt.expected {
|
||||
t.Errorf("TaskType %s should be '%s'", tt.taskType, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlanFeatures_CarryoverCalculation(t *testing.T) {
|
||||
plans := GetDefaultPlans()
|
||||
|
||||
for _, plan := range plans {
|
||||
expectedMax := plan.Features.MonthlyTaskAllowance * CarryoverMonthsCap
|
||||
if plan.Features.MaxTaskBalance != expectedMax {
|
||||
t.Errorf("Plan %s: MaxTaskBalance should be %d (allowance * 5), got %d",
|
||||
plan.ID, expectedMax, plan.Features.MaxTaskBalance)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingPlan_AllPlansActive(t *testing.T) {
|
||||
plans := GetDefaultPlans()
|
||||
|
||||
for _, plan := range plans {
|
||||
if !plan.IsActive {
|
||||
t.Errorf("Plan %s should be active", plan.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingPlan_CurrencyIsEuro(t *testing.T) {
|
||||
plans := GetDefaultPlans()
|
||||
|
||||
for _, plan := range plans {
|
||||
if plan.Currency != "eur" {
|
||||
t.Errorf("Plan %s currency should be 'eur', got '%s'", plan.ID, plan.Currency)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingPlan_IntervalIsMonth(t *testing.T) {
|
||||
plans := GetDefaultPlans()
|
||||
|
||||
for _, plan := range plans {
|
||||
if plan.Interval != "month" {
|
||||
t.Errorf("Plan %s interval should be 'month', got '%s'", plan.ID, plan.Interval)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingPlan_SortOrder(t *testing.T) {
|
||||
plans := GetDefaultPlans()
|
||||
|
||||
for i, plan := range plans {
|
||||
expectedOrder := i + 1
|
||||
if plan.SortOrder != expectedOrder {
|
||||
t.Errorf("Plan %s sort order should be %d, got %d",
|
||||
plan.ID, expectedOrder, plan.SortOrder)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskUsageInfo_FormatStrings(t *testing.T) {
|
||||
usage := TaskUsageInfo{
|
||||
TasksAvailable: 45,
|
||||
MaxTasks: 150,
|
||||
InfoText: "Aufgaben verfuegbar: 45 von max. 150",
|
||||
TooltipText: "Aufgaben koennen sich bis zu 5 Monate ansammeln.",
|
||||
}
|
||||
|
||||
if usage.TasksAvailable != 45 {
|
||||
t.Errorf("TasksAvailable should be 45, got %d", usage.TasksAvailable)
|
||||
}
|
||||
if usage.MaxTasks != 150 {
|
||||
t.Errorf("MaxTasks should be 150, got %d", usage.MaxTasks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckTaskAllowedResponse_Allowed(t *testing.T) {
|
||||
response := CheckTaskAllowedResponse{
|
||||
Allowed: true,
|
||||
TasksAvailable: 50,
|
||||
MaxTasks: 150,
|
||||
PlanID: PlanBasic,
|
||||
}
|
||||
|
||||
if !response.Allowed {
|
||||
t.Error("Response should be allowed")
|
||||
}
|
||||
if response.Message != "" {
|
||||
t.Errorf("Message should be empty for allowed response, got '%s'", response.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckTaskAllowedResponse_NotAllowed(t *testing.T) {
|
||||
response := CheckTaskAllowedResponse{
|
||||
Allowed: false,
|
||||
TasksAvailable: 0,
|
||||
MaxTasks: 150,
|
||||
PlanID: PlanBasic,
|
||||
Message: "Dein Aufgaben-Kontingent ist aufgebraucht.",
|
||||
}
|
||||
|
||||
if response.Allowed {
|
||||
t.Error("Response should not be allowed")
|
||||
}
|
||||
if response.TasksAvailable != 0 {
|
||||
t.Errorf("TasksAvailable should be 0, got %d", response.TasksAvailable)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskLimitError(t *testing.T) {
|
||||
err := TaskLimitError{
|
||||
Error: "TASK_LIMIT_REACHED",
|
||||
CurrentBalance: 0,
|
||||
Plan: PlanBasic,
|
||||
}
|
||||
|
||||
if err.Error != "TASK_LIMIT_REACHED" {
|
||||
t.Errorf("Error should be 'TASK_LIMIT_REACHED', got '%s'", err.Error)
|
||||
}
|
||||
if err.CurrentBalance != 0 {
|
||||
t.Errorf("CurrentBalance should be 0, got %d", err.CurrentBalance)
|
||||
}
|
||||
if err.Plan != PlanBasic {
|
||||
t.Errorf("Plan should be basic, got '%s'", err.Plan)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsumeTaskRequest(t *testing.T) {
|
||||
req := ConsumeTaskRequest{
|
||||
UserID: "550e8400-e29b-41d4-a716-446655440000",
|
||||
TaskType: TaskTypeCorrection,
|
||||
}
|
||||
|
||||
if req.UserID == "" {
|
||||
t.Error("UserID should not be empty")
|
||||
}
|
||||
if req.TaskType != TaskTypeCorrection {
|
||||
t.Errorf("TaskType should be correction, got '%s'", req.TaskType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsumeTaskResponse_Success(t *testing.T) {
|
||||
resp := ConsumeTaskResponse{
|
||||
Success: true,
|
||||
TaskID: "task-123",
|
||||
TasksRemaining: 49,
|
||||
}
|
||||
|
||||
if !resp.Success {
|
||||
t.Error("Response should be successful")
|
||||
}
|
||||
if resp.TasksRemaining != 49 {
|
||||
t.Errorf("TasksRemaining should be 49, got %d", resp.TasksRemaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEntitlementInfo_Premium(t *testing.T) {
|
||||
premium := GetDefaultPlans()[2]
|
||||
|
||||
info := EntitlementInfo{
|
||||
Features: premium.Features.FeatureFlags,
|
||||
MaxTeamMembers: premium.Features.MaxTeamMembers,
|
||||
PrioritySupport: premium.Features.PrioritySupport,
|
||||
CustomBranding: premium.Features.CustomBranding,
|
||||
BatchProcessing: premium.Features.BatchProcessing,
|
||||
CustomTemplates: premium.Features.CustomTemplates,
|
||||
FairUseMode: premium.Features.FairUseMode,
|
||||
}
|
||||
|
||||
if !info.FairUseMode {
|
||||
t.Error("Premium should have FairUseMode")
|
||||
}
|
||||
if info.MaxTeamMembers != 10 {
|
||||
t.Errorf("Premium MaxTeamMembers should be 10, got %d", info.MaxTeamMembers)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,232 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/billing-service/internal/database"
|
||||
"github.com/breakpilot/billing-service/internal/models"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// EntitlementService handles entitlement-related operations
|
||||
type EntitlementService struct {
|
||||
db *database.DB
|
||||
subService *SubscriptionService
|
||||
}
|
||||
|
||||
// NewEntitlementService creates a new EntitlementService
|
||||
func NewEntitlementService(db *database.DB, subService *SubscriptionService) *EntitlementService {
|
||||
return &EntitlementService{
|
||||
db: db,
|
||||
subService: subService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetEntitlements returns the entitlement info for a user
|
||||
func (s *EntitlementService) GetEntitlements(ctx context.Context, userID uuid.UUID) (*models.EntitlementInfo, error) {
|
||||
entitlements, err := s.getUserEntitlements(ctx, userID)
|
||||
if err != nil || entitlements == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &models.EntitlementInfo{
|
||||
Features: entitlements.Features.FeatureFlags,
|
||||
MaxTeamMembers: entitlements.Features.MaxTeamMembers,
|
||||
PrioritySupport: entitlements.Features.PrioritySupport,
|
||||
CustomBranding: entitlements.Features.CustomBranding,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetEntitlementsByUserIDString returns entitlements by user ID string (for internal API)
|
||||
func (s *EntitlementService) GetEntitlementsByUserIDString(ctx context.Context, userIDStr string) (*models.UserEntitlements, error) {
|
||||
userID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.getUserEntitlements(ctx, userID)
|
||||
}
|
||||
|
||||
// getUserEntitlements retrieves or creates entitlements for a user
|
||||
func (s *EntitlementService) getUserEntitlements(ctx context.Context, userID uuid.UUID) (*models.UserEntitlements, error) {
|
||||
query := `
|
||||
SELECT id, user_id, plan_id, ai_requests_limit, ai_requests_used,
|
||||
documents_limit, documents_used, features, period_start, period_end,
|
||||
created_at, updated_at
|
||||
FROM user_entitlements
|
||||
WHERE user_id = $1
|
||||
`
|
||||
|
||||
var ent models.UserEntitlements
|
||||
var featuresJSON []byte
|
||||
var periodStart, periodEnd *time.Time
|
||||
|
||||
err := s.db.Pool.QueryRow(ctx, query, userID).Scan(
|
||||
&ent.ID, &ent.UserID, &ent.PlanID, &ent.AIRequestsLimit, &ent.AIRequestsUsed,
|
||||
&ent.DocumentsLimit, &ent.DocumentsUsed, &featuresJSON, &periodStart, &periodEnd,
|
||||
nil, &ent.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err.Error() == "no rows in result set" {
|
||||
// Try to create entitlements based on subscription
|
||||
return s.createEntitlementsFromSubscription(ctx, userID)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(featuresJSON) > 0 {
|
||||
json.Unmarshal(featuresJSON, &ent.Features)
|
||||
}
|
||||
|
||||
return &ent, nil
|
||||
}
|
||||
|
||||
// createEntitlementsFromSubscription creates entitlements based on user's subscription
|
||||
func (s *EntitlementService) createEntitlementsFromSubscription(ctx context.Context, userID uuid.UUID) (*models.UserEntitlements, error) {
|
||||
// Get user's subscription
|
||||
sub, err := s.subService.GetByUserID(ctx, userID)
|
||||
if err != nil || sub == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get plan details
|
||||
plan, err := s.subService.GetPlanByID(ctx, string(sub.PlanID))
|
||||
if err != nil || plan == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create entitlements
|
||||
return s.CreateEntitlements(ctx, userID, sub.PlanID, plan.Features, sub.CurrentPeriodEnd)
|
||||
}
|
||||
|
||||
// CreateEntitlements creates entitlements for a user
|
||||
func (s *EntitlementService) CreateEntitlements(ctx context.Context, userID uuid.UUID, planID models.PlanID, features models.PlanFeatures, periodEnd *time.Time) (*models.UserEntitlements, error) {
|
||||
featuresJSON, _ := json.Marshal(features)
|
||||
|
||||
now := time.Now()
|
||||
periodStart := now
|
||||
|
||||
query := `
|
||||
INSERT INTO user_entitlements (
|
||||
user_id, plan_id, ai_requests_limit, ai_requests_used,
|
||||
documents_limit, documents_used, features, period_start, period_end
|
||||
) VALUES ($1, $2, $3, 0, $4, 0, $5, $6, $7)
|
||||
ON CONFLICT (user_id) DO UPDATE SET
|
||||
plan_id = EXCLUDED.plan_id,
|
||||
ai_requests_limit = EXCLUDED.ai_requests_limit,
|
||||
documents_limit = EXCLUDED.documents_limit,
|
||||
features = EXCLUDED.features,
|
||||
period_start = EXCLUDED.period_start,
|
||||
period_end = EXCLUDED.period_end,
|
||||
updated_at = NOW()
|
||||
RETURNING id, user_id, plan_id, ai_requests_limit, ai_requests_used,
|
||||
documents_limit, documents_used, updated_at
|
||||
`
|
||||
|
||||
var ent models.UserEntitlements
|
||||
err := s.db.Pool.QueryRow(ctx, query,
|
||||
userID, planID, features.AIRequestsLimit, features.DocumentsLimit,
|
||||
featuresJSON, periodStart, periodEnd,
|
||||
).Scan(
|
||||
&ent.ID, &ent.UserID, &ent.PlanID, &ent.AIRequestsLimit, &ent.AIRequestsUsed,
|
||||
&ent.DocumentsLimit, &ent.DocumentsUsed, &ent.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ent.Features = features
|
||||
return &ent, nil
|
||||
}
|
||||
|
||||
// UpdateEntitlements updates entitlements for a user (e.g., on plan change)
|
||||
func (s *EntitlementService) UpdateEntitlements(ctx context.Context, userID uuid.UUID, planID models.PlanID, features models.PlanFeatures) error {
|
||||
featuresJSON, _ := json.Marshal(features)
|
||||
|
||||
query := `
|
||||
UPDATE user_entitlements SET
|
||||
plan_id = $2,
|
||||
ai_requests_limit = $3,
|
||||
documents_limit = $4,
|
||||
features = $5,
|
||||
updated_at = NOW()
|
||||
WHERE user_id = $1
|
||||
`
|
||||
|
||||
_, err := s.db.Pool.Exec(ctx, query,
|
||||
userID, planID, features.AIRequestsLimit, features.DocumentsLimit, featuresJSON,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// ResetUsageCounters resets usage counters for a new period
|
||||
func (s *EntitlementService) ResetUsageCounters(ctx context.Context, userID uuid.UUID, newPeriodStart, newPeriodEnd *time.Time) error {
|
||||
query := `
|
||||
UPDATE user_entitlements SET
|
||||
ai_requests_used = 0,
|
||||
documents_used = 0,
|
||||
period_start = $2,
|
||||
period_end = $3,
|
||||
updated_at = NOW()
|
||||
WHERE user_id = $1
|
||||
`
|
||||
|
||||
_, err := s.db.Pool.Exec(ctx, query, userID, newPeriodStart, newPeriodEnd)
|
||||
return err
|
||||
}
|
||||
|
||||
// CheckEntitlement checks if a user has a specific feature entitlement
|
||||
func (s *EntitlementService) CheckEntitlement(ctx context.Context, userIDStr, feature string) (bool, models.PlanID, error) {
|
||||
userID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
ent, err := s.getUserEntitlements(ctx, userID)
|
||||
if err != nil || ent == nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
// Check if feature is in the feature flags
|
||||
for _, f := range ent.Features.FeatureFlags {
|
||||
if f == feature {
|
||||
return true, ent.PlanID, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, ent.PlanID, nil
|
||||
}
|
||||
|
||||
// IncrementUsage increments a usage counter
|
||||
func (s *EntitlementService) IncrementUsage(ctx context.Context, userID uuid.UUID, usageType string, amount int) error {
|
||||
var column string
|
||||
switch usageType {
|
||||
case "ai_request":
|
||||
column = "ai_requests_used"
|
||||
case "document_created":
|
||||
column = "documents_used"
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
query := `
|
||||
UPDATE user_entitlements SET
|
||||
` + column + ` = ` + column + ` + $2,
|
||||
updated_at = NOW()
|
||||
WHERE user_id = $1
|
||||
`
|
||||
|
||||
_, err := s.db.Pool.Exec(ctx, query, userID, amount)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteEntitlements removes entitlements for a user (on subscription cancellation)
|
||||
func (s *EntitlementService) DeleteEntitlements(ctx context.Context, userID uuid.UUID) error {
|
||||
query := `DELETE FROM user_entitlements WHERE user_id = $1`
|
||||
_, err := s.db.Pool.Exec(ctx, query, userID)
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,317 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/breakpilot/billing-service/internal/models"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stripe/stripe-go/v76"
|
||||
"github.com/stripe/stripe-go/v76/billingportal/session"
|
||||
checkoutsession "github.com/stripe/stripe-go/v76/checkout/session"
|
||||
"github.com/stripe/stripe-go/v76/customer"
|
||||
"github.com/stripe/stripe-go/v76/price"
|
||||
"github.com/stripe/stripe-go/v76/product"
|
||||
"github.com/stripe/stripe-go/v76/subscription"
|
||||
)
|
||||
|
||||
// StripeService handles Stripe API interactions
|
||||
type StripeService struct {
|
||||
secretKey string
|
||||
webhookSecret string
|
||||
successURL string
|
||||
cancelURL string
|
||||
trialPeriodDays int64
|
||||
subService *SubscriptionService
|
||||
mockMode bool // If true, don't make real Stripe API calls
|
||||
}
|
||||
|
||||
// NewStripeService creates a new StripeService
|
||||
func NewStripeService(secretKey, webhookSecret, successURL, cancelURL string, trialPeriodDays int, subService *SubscriptionService) *StripeService {
|
||||
// Initialize Stripe with the secret key (only if not empty)
|
||||
if secretKey != "" {
|
||||
stripe.Key = secretKey
|
||||
}
|
||||
|
||||
return &StripeService{
|
||||
secretKey: secretKey,
|
||||
webhookSecret: webhookSecret,
|
||||
successURL: successURL,
|
||||
cancelURL: cancelURL,
|
||||
trialPeriodDays: int64(trialPeriodDays),
|
||||
subService: subService,
|
||||
mockMode: false,
|
||||
}
|
||||
}
|
||||
|
||||
// NewMockStripeService creates a mock StripeService for development
|
||||
func NewMockStripeService(successURL, cancelURL string, trialPeriodDays int, subService *SubscriptionService) *StripeService {
|
||||
return &StripeService{
|
||||
secretKey: "",
|
||||
webhookSecret: "",
|
||||
successURL: successURL,
|
||||
cancelURL: cancelURL,
|
||||
trialPeriodDays: int64(trialPeriodDays),
|
||||
subService: subService,
|
||||
mockMode: true,
|
||||
}
|
||||
}
|
||||
|
||||
// IsMockMode returns true if running in mock mode
|
||||
func (s *StripeService) IsMockMode() bool {
|
||||
return s.mockMode
|
||||
}
|
||||
|
||||
// CreateCheckoutSession creates a Stripe Checkout session for trial start
|
||||
func (s *StripeService) CreateCheckoutSession(ctx context.Context, userID uuid.UUID, email string, planID models.PlanID) (string, string, error) {
|
||||
// Mock mode: return a fake URL for development
|
||||
if s.mockMode {
|
||||
mockSessionID := fmt.Sprintf("mock_cs_%s", uuid.New().String()[:8])
|
||||
mockURL := fmt.Sprintf("%s?session_id=%s&mock=true&plan=%s", s.successURL, mockSessionID, planID)
|
||||
return mockURL, mockSessionID, nil
|
||||
}
|
||||
|
||||
// Get plan details
|
||||
plan, err := s.subService.GetPlanByID(ctx, string(planID))
|
||||
if err != nil || plan == nil {
|
||||
return "", "", fmt.Errorf("plan not found: %s", planID)
|
||||
}
|
||||
|
||||
// Ensure we have a Stripe price ID
|
||||
if plan.StripePriceID == "" {
|
||||
// Create product and price in Stripe if not exists
|
||||
priceID, err := s.ensurePriceExists(ctx, plan)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to create stripe price: %w", err)
|
||||
}
|
||||
plan.StripePriceID = priceID
|
||||
}
|
||||
|
||||
// Create checkout session parameters
|
||||
params := &stripe.CheckoutSessionParams{
|
||||
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
|
||||
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
||||
{
|
||||
Price: stripe.String(plan.StripePriceID),
|
||||
Quantity: stripe.Int64(1),
|
||||
},
|
||||
},
|
||||
SuccessURL: stripe.String(s.successURL + "?session_id={CHECKOUT_SESSION_ID}"),
|
||||
CancelURL: stripe.String(s.cancelURL),
|
||||
SubscriptionData: &stripe.CheckoutSessionSubscriptionDataParams{
|
||||
TrialPeriodDays: stripe.Int64(s.trialPeriodDays),
|
||||
Metadata: map[string]string{
|
||||
"user_id": userID.String(),
|
||||
"plan_id": string(planID),
|
||||
},
|
||||
},
|
||||
PaymentMethodCollection: stripe.String(string(stripe.CheckoutSessionPaymentMethodCollectionAlways)),
|
||||
Metadata: map[string]string{
|
||||
"user_id": userID.String(),
|
||||
"plan_id": string(planID),
|
||||
},
|
||||
}
|
||||
|
||||
// Set customer email if provided
|
||||
if email != "" {
|
||||
params.CustomerEmail = stripe.String(email)
|
||||
}
|
||||
|
||||
// Create the session
|
||||
sess, err := checkoutsession.New(params)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to create checkout session: %w", err)
|
||||
}
|
||||
|
||||
return sess.URL, sess.ID, nil
|
||||
}
|
||||
|
||||
// ensurePriceExists creates a Stripe product and price if they don't exist
|
||||
func (s *StripeService) ensurePriceExists(ctx context.Context, plan *models.BillingPlan) (string, error) {
|
||||
// Create product
|
||||
productParams := &stripe.ProductParams{
|
||||
Name: stripe.String(plan.Name),
|
||||
Description: stripe.String(plan.Description),
|
||||
Metadata: map[string]string{
|
||||
"plan_id": string(plan.ID),
|
||||
},
|
||||
}
|
||||
|
||||
prod, err := product.New(productParams)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create product: %w", err)
|
||||
}
|
||||
|
||||
// Create price
|
||||
priceParams := &stripe.PriceParams{
|
||||
Product: stripe.String(prod.ID),
|
||||
UnitAmount: stripe.Int64(int64(plan.PriceCents)),
|
||||
Currency: stripe.String(plan.Currency),
|
||||
Recurring: &stripe.PriceRecurringParams{
|
||||
Interval: stripe.String(plan.Interval),
|
||||
},
|
||||
Metadata: map[string]string{
|
||||
"plan_id": string(plan.ID),
|
||||
},
|
||||
}
|
||||
|
||||
pr, err := price.New(priceParams)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create price: %w", err)
|
||||
}
|
||||
|
||||
// Update plan with Stripe IDs
|
||||
if err := s.subService.UpdatePlanStripePriceID(ctx, string(plan.ID), pr.ID, prod.ID); err != nil {
|
||||
// Log but don't fail
|
||||
fmt.Printf("Warning: Failed to update plan with Stripe IDs: %v\n", err)
|
||||
}
|
||||
|
||||
return pr.ID, nil
|
||||
}
|
||||
|
||||
// GetOrCreateCustomer gets or creates a Stripe customer for a user
|
||||
func (s *StripeService) GetOrCreateCustomer(ctx context.Context, email, name string, userID uuid.UUID) (string, error) {
|
||||
// Search for existing customer
|
||||
params := &stripe.CustomerSearchParams{
|
||||
SearchParams: stripe.SearchParams{
|
||||
Query: fmt.Sprintf("email:'%s'", email),
|
||||
},
|
||||
}
|
||||
|
||||
iter := customer.Search(params)
|
||||
for iter.Next() {
|
||||
cust := iter.Customer()
|
||||
// Check if this customer belongs to our user
|
||||
if cust.Metadata["user_id"] == userID.String() {
|
||||
return cust.ID, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Create new customer
|
||||
customerParams := &stripe.CustomerParams{
|
||||
Email: stripe.String(email),
|
||||
Name: stripe.String(name),
|
||||
Metadata: map[string]string{
|
||||
"user_id": userID.String(),
|
||||
},
|
||||
}
|
||||
|
||||
cust, err := customer.New(customerParams)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create customer: %w", err)
|
||||
}
|
||||
|
||||
return cust.ID, nil
|
||||
}
|
||||
|
||||
// ChangePlan changes a subscription to a new plan
|
||||
func (s *StripeService) ChangePlan(ctx context.Context, stripeSubID string, newPlanID models.PlanID) error {
|
||||
// Mock mode: just return success
|
||||
if s.mockMode {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get new plan details
|
||||
plan, err := s.subService.GetPlanByID(ctx, string(newPlanID))
|
||||
if err != nil || plan == nil {
|
||||
return fmt.Errorf("plan not found: %s", newPlanID)
|
||||
}
|
||||
|
||||
if plan.StripePriceID == "" {
|
||||
return fmt.Errorf("plan %s has no Stripe price ID", newPlanID)
|
||||
}
|
||||
|
||||
// Get current subscription
|
||||
sub, err := subscription.Get(stripeSubID, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get subscription: %w", err)
|
||||
}
|
||||
|
||||
// Update subscription with new price
|
||||
params := &stripe.SubscriptionParams{
|
||||
Items: []*stripe.SubscriptionItemsParams{
|
||||
{
|
||||
ID: stripe.String(sub.Items.Data[0].ID),
|
||||
Price: stripe.String(plan.StripePriceID),
|
||||
},
|
||||
},
|
||||
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
|
||||
Metadata: map[string]string{
|
||||
"plan_id": string(newPlanID),
|
||||
},
|
||||
}
|
||||
|
||||
_, err = subscription.Update(stripeSubID, params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update subscription: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CancelSubscription cancels a subscription at period end
|
||||
func (s *StripeService) CancelSubscription(ctx context.Context, stripeSubID string) error {
|
||||
// Mock mode: just return success
|
||||
if s.mockMode {
|
||||
return nil
|
||||
}
|
||||
|
||||
params := &stripe.SubscriptionParams{
|
||||
CancelAtPeriodEnd: stripe.Bool(true),
|
||||
}
|
||||
|
||||
_, err := subscription.Update(stripeSubID, params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to cancel subscription: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReactivateSubscription removes the cancel_at_period_end flag
|
||||
func (s *StripeService) ReactivateSubscription(ctx context.Context, stripeSubID string) error {
|
||||
// Mock mode: just return success
|
||||
if s.mockMode {
|
||||
return nil
|
||||
}
|
||||
|
||||
params := &stripe.SubscriptionParams{
|
||||
CancelAtPeriodEnd: stripe.Bool(false),
|
||||
}
|
||||
|
||||
_, err := subscription.Update(stripeSubID, params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to reactivate subscription: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateCustomerPortalSession creates a Stripe Customer Portal session
|
||||
func (s *StripeService) CreateCustomerPortalSession(ctx context.Context, customerID string) (string, error) {
|
||||
// Mock mode: return a mock URL
|
||||
if s.mockMode {
|
||||
return fmt.Sprintf("%s?mock_portal=true", s.successURL), nil
|
||||
}
|
||||
|
||||
params := &stripe.BillingPortalSessionParams{
|
||||
Customer: stripe.String(customerID),
|
||||
ReturnURL: stripe.String(s.successURL),
|
||||
}
|
||||
|
||||
sess, err := session.New(params)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create portal session: %w", err)
|
||||
}
|
||||
|
||||
return sess.URL, nil
|
||||
}
|
||||
|
||||
// GetSubscription retrieves a subscription from Stripe
|
||||
func (s *StripeService) GetSubscription(ctx context.Context, stripeSubID string) (*stripe.Subscription, error) {
|
||||
sub, err := subscription.Get(stripeSubID, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get subscription: %w", err)
|
||||
}
|
||||
return sub, nil
|
||||
}
|
||||
@@ -0,0 +1,315 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/billing-service/internal/database"
|
||||
"github.com/breakpilot/billing-service/internal/models"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// SubscriptionService handles subscription-related operations
|
||||
type SubscriptionService struct {
|
||||
db *database.DB
|
||||
}
|
||||
|
||||
// NewSubscriptionService creates a new SubscriptionService
|
||||
func NewSubscriptionService(db *database.DB) *SubscriptionService {
|
||||
return &SubscriptionService{db: db}
|
||||
}
|
||||
|
||||
// GetByUserID retrieves a subscription by user ID
|
||||
func (s *SubscriptionService) GetByUserID(ctx context.Context, userID uuid.UUID) (*models.Subscription, error) {
|
||||
query := `
|
||||
SELECT id, user_id, stripe_customer_id, stripe_subscription_id, plan_id,
|
||||
status, trial_end, current_period_end, cancel_at_period_end,
|
||||
created_at, updated_at
|
||||
FROM subscriptions
|
||||
WHERE user_id = $1
|
||||
`
|
||||
|
||||
var sub models.Subscription
|
||||
var stripeCustomerID, stripeSubID *string
|
||||
var trialEnd, periodEnd *time.Time
|
||||
|
||||
err := s.db.Pool.QueryRow(ctx, query, userID).Scan(
|
||||
&sub.ID, &sub.UserID, &stripeCustomerID, &stripeSubID, &sub.PlanID,
|
||||
&sub.Status, &trialEnd, &periodEnd, &sub.CancelAtPeriodEnd,
|
||||
&sub.CreatedAt, &sub.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err.Error() == "no rows in result set" {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stripeCustomerID != nil {
|
||||
sub.StripeCustomerID = *stripeCustomerID
|
||||
}
|
||||
if stripeSubID != nil {
|
||||
sub.StripeSubscriptionID = *stripeSubID
|
||||
}
|
||||
sub.TrialEnd = trialEnd
|
||||
sub.CurrentPeriodEnd = periodEnd
|
||||
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
// GetByStripeSubscriptionID retrieves a subscription by Stripe subscription ID
|
||||
func (s *SubscriptionService) GetByStripeSubscriptionID(ctx context.Context, stripeSubID string) (*models.Subscription, error) {
|
||||
query := `
|
||||
SELECT id, user_id, stripe_customer_id, stripe_subscription_id, plan_id,
|
||||
status, trial_end, current_period_end, cancel_at_period_end,
|
||||
created_at, updated_at
|
||||
FROM subscriptions
|
||||
WHERE stripe_subscription_id = $1
|
||||
`
|
||||
|
||||
var sub models.Subscription
|
||||
var stripeCustomerID, subID *string
|
||||
var trialEnd, periodEnd *time.Time
|
||||
|
||||
err := s.db.Pool.QueryRow(ctx, query, stripeSubID).Scan(
|
||||
&sub.ID, &sub.UserID, &stripeCustomerID, &subID, &sub.PlanID,
|
||||
&sub.Status, &trialEnd, &periodEnd, &sub.CancelAtPeriodEnd,
|
||||
&sub.CreatedAt, &sub.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err.Error() == "no rows in result set" {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stripeCustomerID != nil {
|
||||
sub.StripeCustomerID = *stripeCustomerID
|
||||
}
|
||||
if subID != nil {
|
||||
sub.StripeSubscriptionID = *subID
|
||||
}
|
||||
sub.TrialEnd = trialEnd
|
||||
sub.CurrentPeriodEnd = periodEnd
|
||||
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
// Create creates a new subscription
|
||||
func (s *SubscriptionService) Create(ctx context.Context, sub *models.Subscription) error {
|
||||
query := `
|
||||
INSERT INTO subscriptions (
|
||||
user_id, stripe_customer_id, stripe_subscription_id, plan_id,
|
||||
status, trial_end, current_period_end, cancel_at_period_end
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
|
||||
return s.db.Pool.QueryRow(ctx, query,
|
||||
sub.UserID, sub.StripeCustomerID, sub.StripeSubscriptionID, sub.PlanID,
|
||||
sub.Status, sub.TrialEnd, sub.CurrentPeriodEnd, sub.CancelAtPeriodEnd,
|
||||
).Scan(&sub.ID, &sub.CreatedAt, &sub.UpdatedAt)
|
||||
}
|
||||
|
||||
// Update updates an existing subscription
|
||||
func (s *SubscriptionService) Update(ctx context.Context, sub *models.Subscription) error {
|
||||
query := `
|
||||
UPDATE subscriptions SET
|
||||
stripe_customer_id = $2,
|
||||
stripe_subscription_id = $3,
|
||||
plan_id = $4,
|
||||
status = $5,
|
||||
trial_end = $6,
|
||||
current_period_end = $7,
|
||||
cancel_at_period_end = $8,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
_, err := s.db.Pool.Exec(ctx, query,
|
||||
sub.ID, sub.StripeCustomerID, sub.StripeSubscriptionID, sub.PlanID,
|
||||
sub.Status, sub.TrialEnd, sub.CurrentPeriodEnd, sub.CancelAtPeriodEnd,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateStatus updates the subscription status
|
||||
func (s *SubscriptionService) UpdateStatus(ctx context.Context, id uuid.UUID, status models.SubscriptionStatus) error {
|
||||
query := `UPDATE subscriptions SET status = $2, updated_at = NOW() WHERE id = $1`
|
||||
_, err := s.db.Pool.Exec(ctx, query, id, status)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetAvailablePlans retrieves all active billing plans
|
||||
func (s *SubscriptionService) GetAvailablePlans(ctx context.Context) ([]models.BillingPlan, error) {
|
||||
query := `
|
||||
SELECT id, stripe_price_id, name, description, price_cents,
|
||||
currency, interval, features, is_active, sort_order
|
||||
FROM billing_plans
|
||||
WHERE is_active = true
|
||||
ORDER BY sort_order ASC
|
||||
`
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var plans []models.BillingPlan
|
||||
for rows.Next() {
|
||||
var plan models.BillingPlan
|
||||
var stripePriceID *string
|
||||
var featuresJSON []byte
|
||||
|
||||
err := rows.Scan(
|
||||
&plan.ID, &stripePriceID, &plan.Name, &plan.Description,
|
||||
&plan.PriceCents, &plan.Currency, &plan.Interval,
|
||||
&featuresJSON, &plan.IsActive, &plan.SortOrder,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stripePriceID != nil {
|
||||
plan.StripePriceID = *stripePriceID
|
||||
}
|
||||
|
||||
// Parse features JSON
|
||||
if len(featuresJSON) > 0 {
|
||||
json.Unmarshal(featuresJSON, &plan.Features)
|
||||
}
|
||||
|
||||
plans = append(plans, plan)
|
||||
}
|
||||
|
||||
return plans, nil
|
||||
}
|
||||
|
||||
// GetPlanByID retrieves a billing plan by ID
|
||||
func (s *SubscriptionService) GetPlanByID(ctx context.Context, planID string) (*models.BillingPlan, error) {
|
||||
query := `
|
||||
SELECT id, stripe_price_id, name, description, price_cents,
|
||||
currency, interval, features, is_active, sort_order
|
||||
FROM billing_plans
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
var plan models.BillingPlan
|
||||
var stripePriceID *string
|
||||
var featuresJSON []byte
|
||||
|
||||
err := s.db.Pool.QueryRow(ctx, query, planID).Scan(
|
||||
&plan.ID, &stripePriceID, &plan.Name, &plan.Description,
|
||||
&plan.PriceCents, &plan.Currency, &plan.Interval,
|
||||
&featuresJSON, &plan.IsActive, &plan.SortOrder,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err.Error() == "no rows in result set" {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stripePriceID != nil {
|
||||
plan.StripePriceID = *stripePriceID
|
||||
}
|
||||
|
||||
if len(featuresJSON) > 0 {
|
||||
json.Unmarshal(featuresJSON, &plan.Features)
|
||||
}
|
||||
|
||||
return &plan, nil
|
||||
}
|
||||
|
||||
// UpdatePlanStripePriceID updates the Stripe price ID for a plan
|
||||
func (s *SubscriptionService) UpdatePlanStripePriceID(ctx context.Context, planID, stripePriceID, stripeProductID string) error {
|
||||
query := `
|
||||
UPDATE billing_plans
|
||||
SET stripe_price_id = $2, stripe_product_id = $3, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
_, err := s.db.Pool.Exec(ctx, query, planID, stripePriceID, stripeProductID)
|
||||
return err
|
||||
}
|
||||
|
||||
// =============================================
|
||||
// Webhook Event Tracking (Idempotency)
|
||||
// =============================================
|
||||
|
||||
// IsEventProcessed checks if a webhook event has already been processed
|
||||
func (s *SubscriptionService) IsEventProcessed(ctx context.Context, eventID string) (bool, error) {
|
||||
query := `SELECT processed FROM stripe_webhook_events WHERE stripe_event_id = $1`
|
||||
|
||||
var processed bool
|
||||
err := s.db.Pool.QueryRow(ctx, query, eventID).Scan(&processed)
|
||||
if err != nil {
|
||||
if err.Error() == "no rows in result set" {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
return processed, nil
|
||||
}
|
||||
|
||||
// MarkEventProcessing marks an event as being processed
|
||||
func (s *SubscriptionService) MarkEventProcessing(ctx context.Context, eventID, eventType string) error {
|
||||
query := `
|
||||
INSERT INTO stripe_webhook_events (stripe_event_id, event_type, processed)
|
||||
VALUES ($1, $2, false)
|
||||
ON CONFLICT (stripe_event_id) DO NOTHING
|
||||
`
|
||||
_, err := s.db.Pool.Exec(ctx, query, eventID, eventType)
|
||||
return err
|
||||
}
|
||||
|
||||
// MarkEventProcessed marks an event as successfully processed
|
||||
func (s *SubscriptionService) MarkEventProcessed(ctx context.Context, eventID string) error {
|
||||
query := `
|
||||
UPDATE stripe_webhook_events
|
||||
SET processed = true, processed_at = NOW()
|
||||
WHERE stripe_event_id = $1
|
||||
`
|
||||
_, err := s.db.Pool.Exec(ctx, query, eventID)
|
||||
return err
|
||||
}
|
||||
|
||||
// MarkEventFailed marks an event as failed with an error message
|
||||
func (s *SubscriptionService) MarkEventFailed(ctx context.Context, eventID, errorMsg string) error {
|
||||
query := `
|
||||
UPDATE stripe_webhook_events
|
||||
SET processed = false, error_message = $2, processed_at = NOW()
|
||||
WHERE stripe_event_id = $1
|
||||
`
|
||||
_, err := s.db.Pool.Exec(ctx, query, eventID, errorMsg)
|
||||
return err
|
||||
}
|
||||
|
||||
// =============================================
|
||||
// Audit Logging
|
||||
// =============================================
|
||||
|
||||
// LogAuditEvent logs a billing audit event
|
||||
func (s *SubscriptionService) LogAuditEvent(ctx context.Context, userID *uuid.UUID, action, entityType, entityID string, oldValue, newValue, metadata interface{}, ipAddress, userAgent string) error {
|
||||
oldJSON, _ := json.Marshal(oldValue)
|
||||
newJSON, _ := json.Marshal(newValue)
|
||||
metaJSON, _ := json.Marshal(metadata)
|
||||
|
||||
query := `
|
||||
INSERT INTO billing_audit_log (
|
||||
user_id, action, entity_type, entity_id,
|
||||
old_value, new_value, metadata, ip_address, user_agent
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
`
|
||||
|
||||
_, err := s.db.Pool.Exec(ctx, query,
|
||||
userID, action, entityType, entityID,
|
||||
oldJSON, newJSON, metaJSON, ipAddress, userAgent,
|
||||
)
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,326 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/breakpilot/billing-service/internal/models"
|
||||
)
|
||||
|
||||
func TestSubscriptionStatus_Transitions(t *testing.T) {
|
||||
// Test valid subscription status values
|
||||
validStatuses := []models.SubscriptionStatus{
|
||||
models.StatusTrialing,
|
||||
models.StatusActive,
|
||||
models.StatusPastDue,
|
||||
models.StatusCanceled,
|
||||
models.StatusExpired,
|
||||
}
|
||||
|
||||
for _, status := range validStatuses {
|
||||
if status == "" {
|
||||
t.Errorf("Status should not be empty")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlanID_ValidValues(t *testing.T) {
|
||||
validPlanIDs := []models.PlanID{
|
||||
models.PlanBasic,
|
||||
models.PlanStandard,
|
||||
models.PlanPremium,
|
||||
}
|
||||
|
||||
expected := []string{"basic", "standard", "premium"}
|
||||
|
||||
for i, planID := range validPlanIDs {
|
||||
if string(planID) != expected[i] {
|
||||
t.Errorf("PlanID should be '%s', got '%s'", expected[i], planID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlanFeatures_JSONSerialization(t *testing.T) {
|
||||
features := models.PlanFeatures{
|
||||
MonthlyTaskAllowance: 100,
|
||||
MaxTaskBalance: 500,
|
||||
FeatureFlags: []string{"basic_ai", "templates"},
|
||||
MaxTeamMembers: 3,
|
||||
PrioritySupport: false,
|
||||
CustomBranding: false,
|
||||
BatchProcessing: true,
|
||||
CustomTemplates: true,
|
||||
FairUseMode: false,
|
||||
}
|
||||
|
||||
// Test JSON serialization
|
||||
data, err := json.Marshal(features)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal PlanFeatures: %v", err)
|
||||
}
|
||||
|
||||
// Test JSON deserialization
|
||||
var decoded models.PlanFeatures
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal PlanFeatures: %v", err)
|
||||
}
|
||||
|
||||
// Verify fields
|
||||
if decoded.MonthlyTaskAllowance != features.MonthlyTaskAllowance {
|
||||
t.Errorf("MonthlyTaskAllowance mismatch: got %d, expected %d",
|
||||
decoded.MonthlyTaskAllowance, features.MonthlyTaskAllowance)
|
||||
}
|
||||
if decoded.MaxTaskBalance != features.MaxTaskBalance {
|
||||
t.Errorf("MaxTaskBalance mismatch: got %d, expected %d",
|
||||
decoded.MaxTaskBalance, features.MaxTaskBalance)
|
||||
}
|
||||
if decoded.BatchProcessing != features.BatchProcessing {
|
||||
t.Errorf("BatchProcessing mismatch: got %v, expected %v",
|
||||
decoded.BatchProcessing, features.BatchProcessing)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingPlan_DefaultPlansAreValid(t *testing.T) {
|
||||
plans := models.GetDefaultPlans()
|
||||
|
||||
if len(plans) != 3 {
|
||||
t.Fatalf("Expected 3 default plans, got %d", len(plans))
|
||||
}
|
||||
|
||||
// Verify all plans have required fields
|
||||
for _, plan := range plans {
|
||||
if plan.ID == "" {
|
||||
t.Errorf("Plan ID should not be empty")
|
||||
}
|
||||
if plan.Name == "" {
|
||||
t.Errorf("Plan '%s' should have a name", plan.ID)
|
||||
}
|
||||
if plan.Description == "" {
|
||||
t.Errorf("Plan '%s' should have a description", plan.ID)
|
||||
}
|
||||
if plan.PriceCents <= 0 {
|
||||
t.Errorf("Plan '%s' should have a positive price, got %d", plan.ID, plan.PriceCents)
|
||||
}
|
||||
if plan.Currency != "eur" {
|
||||
t.Errorf("Plan '%s' currency should be 'eur', got '%s'", plan.ID, plan.Currency)
|
||||
}
|
||||
if plan.Interval != "month" {
|
||||
t.Errorf("Plan '%s' interval should be 'month', got '%s'", plan.ID, plan.Interval)
|
||||
}
|
||||
if !plan.IsActive {
|
||||
t.Errorf("Plan '%s' should be active", plan.ID)
|
||||
}
|
||||
if plan.SortOrder <= 0 {
|
||||
t.Errorf("Plan '%s' should have a positive sort order, got %d", plan.ID, plan.SortOrder)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingPlan_TaskAllowanceProgression(t *testing.T) {
|
||||
plans := models.GetDefaultPlans()
|
||||
|
||||
// Basic should have lowest allowance
|
||||
basic := plans[0]
|
||||
standard := plans[1]
|
||||
premium := plans[2]
|
||||
|
||||
if basic.Features.MonthlyTaskAllowance >= standard.Features.MonthlyTaskAllowance {
|
||||
t.Error("Standard plan should have more tasks than Basic")
|
||||
}
|
||||
|
||||
if standard.Features.MonthlyTaskAllowance >= premium.Features.MonthlyTaskAllowance {
|
||||
t.Error("Premium plan should have more tasks than Standard")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingPlan_PriceProgression(t *testing.T) {
|
||||
plans := models.GetDefaultPlans()
|
||||
|
||||
// Prices should increase with each tier
|
||||
if plans[0].PriceCents >= plans[1].PriceCents {
|
||||
t.Error("Standard should cost more than Basic")
|
||||
}
|
||||
if plans[1].PriceCents >= plans[2].PriceCents {
|
||||
t.Error("Premium should cost more than Standard")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingPlan_FairUseModeOnlyForPremium(t *testing.T) {
|
||||
plans := models.GetDefaultPlans()
|
||||
|
||||
for _, plan := range plans {
|
||||
if plan.ID == models.PlanPremium {
|
||||
if !plan.Features.FairUseMode {
|
||||
t.Error("Premium plan should have FairUseMode enabled")
|
||||
}
|
||||
} else {
|
||||
if plan.Features.FairUseMode {
|
||||
t.Errorf("Plan '%s' should not have FairUseMode enabled", plan.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingPlan_MaxTaskBalanceCalculation(t *testing.T) {
|
||||
plans := models.GetDefaultPlans()
|
||||
|
||||
for _, plan := range plans {
|
||||
expected := plan.Features.MonthlyTaskAllowance * models.CarryoverMonthsCap
|
||||
if plan.Features.MaxTaskBalance != expected {
|
||||
t.Errorf("Plan '%s' MaxTaskBalance should be %d (allowance * 5), got %d",
|
||||
plan.ID, expected, plan.Features.MaxTaskBalance)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuditLogJSON_Marshaling(t *testing.T) {
|
||||
// Test that audit log values can be properly serialized
|
||||
oldValue := map[string]interface{}{
|
||||
"plan_id": "basic",
|
||||
"status": "active",
|
||||
}
|
||||
|
||||
newValue := map[string]interface{}{
|
||||
"plan_id": "standard",
|
||||
"status": "active",
|
||||
}
|
||||
|
||||
metadata := map[string]interface{}{
|
||||
"reason": "upgrade",
|
||||
}
|
||||
|
||||
// Marshal all values
|
||||
oldJSON, err := json.Marshal(oldValue)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal oldValue: %v", err)
|
||||
}
|
||||
|
||||
newJSON, err := json.Marshal(newValue)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal newValue: %v", err)
|
||||
}
|
||||
|
||||
metaJSON, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal metadata: %v", err)
|
||||
}
|
||||
|
||||
// Verify non-empty
|
||||
if len(oldJSON) == 0 || len(newJSON) == 0 || len(metaJSON) == 0 {
|
||||
t.Error("JSON outputs should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionTrialCalculation(t *testing.T) {
|
||||
// Test trial days calculation logic
|
||||
trialDays := 7
|
||||
|
||||
if trialDays <= 0 {
|
||||
t.Error("Trial days should be positive")
|
||||
}
|
||||
|
||||
if trialDays > 30 {
|
||||
t.Error("Trial days should not exceed 30")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionInfo_TrialingStatus(t *testing.T) {
|
||||
info := models.SubscriptionInfo{
|
||||
PlanID: models.PlanBasic,
|
||||
PlanName: "Basic",
|
||||
Status: models.StatusTrialing,
|
||||
IsTrialing: true,
|
||||
TrialDaysLeft: 5,
|
||||
CancelAtPeriodEnd: false,
|
||||
PriceCents: 990,
|
||||
Currency: "eur",
|
||||
}
|
||||
|
||||
if !info.IsTrialing {
|
||||
t.Error("Should be trialing")
|
||||
}
|
||||
if info.Status != models.StatusTrialing {
|
||||
t.Errorf("Status should be 'trialing', got '%s'", info.Status)
|
||||
}
|
||||
if info.TrialDaysLeft <= 0 {
|
||||
t.Error("TrialDaysLeft should be positive during trial")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionInfo_ActiveStatus(t *testing.T) {
|
||||
info := models.SubscriptionInfo{
|
||||
PlanID: models.PlanStandard,
|
||||
PlanName: "Standard",
|
||||
Status: models.StatusActive,
|
||||
IsTrialing: false,
|
||||
TrialDaysLeft: 0,
|
||||
CancelAtPeriodEnd: false,
|
||||
PriceCents: 1990,
|
||||
Currency: "eur",
|
||||
}
|
||||
|
||||
if info.IsTrialing {
|
||||
t.Error("Should not be trialing")
|
||||
}
|
||||
if info.Status != models.StatusActive {
|
||||
t.Errorf("Status should be 'active', got '%s'", info.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionInfo_CanceledStatus(t *testing.T) {
|
||||
info := models.SubscriptionInfo{
|
||||
PlanID: models.PlanStandard,
|
||||
PlanName: "Standard",
|
||||
Status: models.StatusActive,
|
||||
IsTrialing: false,
|
||||
CancelAtPeriodEnd: true, // Scheduled for cancellation
|
||||
PriceCents: 1990,
|
||||
Currency: "eur",
|
||||
}
|
||||
|
||||
if !info.CancelAtPeriodEnd {
|
||||
t.Error("CancelAtPeriodEnd should be true")
|
||||
}
|
||||
// Status remains active until period end
|
||||
if info.Status != models.StatusActive {
|
||||
t.Errorf("Status should still be 'active', got '%s'", info.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebhookEventTypes(t *testing.T) {
|
||||
// Test common Stripe webhook event types we handle
|
||||
eventTypes := []string{
|
||||
"checkout.session.completed",
|
||||
"customer.subscription.created",
|
||||
"customer.subscription.updated",
|
||||
"customer.subscription.deleted",
|
||||
"invoice.paid",
|
||||
"invoice.payment_failed",
|
||||
}
|
||||
|
||||
for _, eventType := range eventTypes {
|
||||
if eventType == "" {
|
||||
t.Error("Event type should not be empty")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdempotencyKey_Format(t *testing.T) {
|
||||
// Test that we can handle Stripe event IDs
|
||||
sampleEventIDs := []string{
|
||||
"evt_1234567890abcdef",
|
||||
"evt_test_abc123xyz789",
|
||||
"evt_live_real_event_id",
|
||||
}
|
||||
|
||||
for _, eventID := range sampleEventIDs {
|
||||
if len(eventID) < 10 {
|
||||
t.Errorf("Event ID '%s' seems too short", eventID)
|
||||
}
|
||||
// Stripe event IDs typically start with "evt_"
|
||||
if eventID[:4] != "evt_" {
|
||||
t.Errorf("Event ID '%s' should start with 'evt_'", eventID)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,352 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/billing-service/internal/database"
|
||||
"github.com/breakpilot/billing-service/internal/models"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrTaskLimitReached is returned when task balance is 0
|
||||
ErrTaskLimitReached = errors.New("TASK_LIMIT_REACHED")
|
||||
// ErrNoSubscription is returned when user has no subscription
|
||||
ErrNoSubscription = errors.New("NO_SUBSCRIPTION")
|
||||
)
|
||||
|
||||
// TaskService handles task consumption and balance management
|
||||
type TaskService struct {
|
||||
db *database.DB
|
||||
subService *SubscriptionService
|
||||
}
|
||||
|
||||
// NewTaskService creates a new TaskService
|
||||
func NewTaskService(db *database.DB, subService *SubscriptionService) *TaskService {
|
||||
return &TaskService{
|
||||
db: db,
|
||||
subService: subService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccountUsage retrieves or creates account usage for a user
|
||||
func (s *TaskService) GetAccountUsage(ctx context.Context, userID uuid.UUID) (*models.AccountUsage, error) {
|
||||
query := `
|
||||
SELECT id, account_id, plan, monthly_task_allowance, carryover_months_cap,
|
||||
max_task_balance, task_balance, last_renewal_at, created_at, updated_at
|
||||
FROM account_usage
|
||||
WHERE account_id = $1
|
||||
`
|
||||
|
||||
var usage models.AccountUsage
|
||||
err := s.db.Pool.QueryRow(ctx, query, userID).Scan(
|
||||
&usage.ID, &usage.AccountID, &usage.PlanID, &usage.MonthlyTaskAllowance,
|
||||
&usage.CarryoverMonthsCap, &usage.MaxTaskBalance, &usage.TaskBalance,
|
||||
&usage.LastRenewalAt, &usage.CreatedAt, &usage.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if err.Error() == "no rows in result set" {
|
||||
// Create new account usage based on subscription
|
||||
return s.createAccountUsage(ctx, userID)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if month renewal is needed
|
||||
if err := s.checkAndApplyMonthRenewal(ctx, &usage); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
// createAccountUsage creates account usage based on user's subscription
|
||||
func (s *TaskService) createAccountUsage(ctx context.Context, userID uuid.UUID) (*models.AccountUsage, error) {
|
||||
// Get subscription to determine plan
|
||||
sub, err := s.subService.GetByUserID(ctx, userID)
|
||||
if err != nil || sub == nil {
|
||||
return nil, ErrNoSubscription
|
||||
}
|
||||
|
||||
// Get plan features
|
||||
plan, err := s.subService.GetPlanByID(ctx, string(sub.PlanID))
|
||||
if err != nil || plan == nil {
|
||||
return nil, fmt.Errorf("plan not found: %s", sub.PlanID)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
usage := &models.AccountUsage{
|
||||
AccountID: userID,
|
||||
PlanID: sub.PlanID,
|
||||
MonthlyTaskAllowance: plan.Features.MonthlyTaskAllowance,
|
||||
CarryoverMonthsCap: models.CarryoverMonthsCap,
|
||||
MaxTaskBalance: plan.Features.MaxTaskBalance,
|
||||
TaskBalance: plan.Features.MonthlyTaskAllowance, // Start with one month's worth
|
||||
LastRenewalAt: now,
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO account_usage (
|
||||
account_id, plan, monthly_task_allowance, carryover_months_cap,
|
||||
max_task_balance, task_balance, last_renewal_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
|
||||
err = s.db.Pool.QueryRow(ctx, query,
|
||||
usage.AccountID, usage.PlanID, usage.MonthlyTaskAllowance,
|
||||
usage.CarryoverMonthsCap, usage.MaxTaskBalance, usage.TaskBalance, usage.LastRenewalAt,
|
||||
).Scan(&usage.ID, &usage.CreatedAt, &usage.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// checkAndApplyMonthRenewal checks if a month has passed and adds allowance
|
||||
// Implements the carryover logic: tasks accumulate up to max_task_balance
|
||||
func (s *TaskService) checkAndApplyMonthRenewal(ctx context.Context, usage *models.AccountUsage) error {
|
||||
now := time.Now()
|
||||
|
||||
// Check if at least one month has passed since last renewal
|
||||
monthsSinceRenewal := monthsBetween(usage.LastRenewalAt, now)
|
||||
if monthsSinceRenewal < 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Calculate new balance with carryover
|
||||
// Add monthly allowance for each month that passed
|
||||
newBalance := usage.TaskBalance
|
||||
for i := 0; i < monthsSinceRenewal; i++ {
|
||||
newBalance += usage.MonthlyTaskAllowance
|
||||
// Cap at max balance
|
||||
if newBalance > usage.MaxTaskBalance {
|
||||
newBalance = usage.MaxTaskBalance
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate new renewal date (add the number of months)
|
||||
newRenewalAt := usage.LastRenewalAt.AddDate(0, monthsSinceRenewal, 0)
|
||||
|
||||
// Update in database
|
||||
query := `
|
||||
UPDATE account_usage
|
||||
SET task_balance = $2, last_renewal_at = $3, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
_, err := s.db.Pool.Exec(ctx, query, usage.ID, newBalance, newRenewalAt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update local struct
|
||||
usage.TaskBalance = newBalance
|
||||
usage.LastRenewalAt = newRenewalAt
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// monthsBetween calculates full months between two dates
|
||||
func monthsBetween(start, end time.Time) int {
|
||||
months := 0
|
||||
for start.AddDate(0, months+1, 0).Before(end) || start.AddDate(0, months+1, 0).Equal(end) {
|
||||
months++
|
||||
}
|
||||
return months
|
||||
}
|
||||
|
||||
// CheckTaskAllowed checks if a task can be consumed (balance > 0)
|
||||
func (s *TaskService) CheckTaskAllowed(ctx context.Context, userID uuid.UUID) (*models.CheckTaskAllowedResponse, error) {
|
||||
usage, err := s.GetAccountUsage(ctx, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNoSubscription) {
|
||||
return &models.CheckTaskAllowedResponse{
|
||||
Allowed: false,
|
||||
PlanID: "",
|
||||
Message: "Kein aktives Abonnement gefunden.",
|
||||
}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Premium Fair Use mode - always allow
|
||||
plan, _ := s.subService.GetPlanByID(ctx, string(usage.PlanID))
|
||||
if plan != nil && plan.Features.FairUseMode {
|
||||
return &models.CheckTaskAllowedResponse{
|
||||
Allowed: true,
|
||||
TasksAvailable: usage.TaskBalance,
|
||||
MaxTasks: usage.MaxTaskBalance,
|
||||
PlanID: usage.PlanID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
allowed := usage.TaskBalance > 0
|
||||
|
||||
response := &models.CheckTaskAllowedResponse{
|
||||
Allowed: allowed,
|
||||
TasksAvailable: usage.TaskBalance,
|
||||
MaxTasks: usage.MaxTaskBalance,
|
||||
PlanID: usage.PlanID,
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
response.Message = "Dein Aufgaben-Kontingent ist aufgebraucht."
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// ConsumeTask consumes one task from the balance
|
||||
// Returns error if balance is 0
|
||||
func (s *TaskService) ConsumeTask(ctx context.Context, userID uuid.UUID, taskType models.TaskType) (*models.ConsumeTaskResponse, error) {
|
||||
// First check if allowed
|
||||
checkResponse, err := s.CheckTaskAllowed(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !checkResponse.Allowed {
|
||||
return &models.ConsumeTaskResponse{
|
||||
Success: false,
|
||||
TasksRemaining: 0,
|
||||
Message: checkResponse.Message,
|
||||
}, ErrTaskLimitReached
|
||||
}
|
||||
|
||||
// Get current usage
|
||||
usage, err := s.GetAccountUsage(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Start transaction
|
||||
tx, err := s.db.Pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
// Decrement balance (only if not Premium Fair Use)
|
||||
plan, _ := s.subService.GetPlanByID(ctx, string(usage.PlanID))
|
||||
newBalance := usage.TaskBalance
|
||||
if plan == nil || !plan.Features.FairUseMode {
|
||||
newBalance = usage.TaskBalance - 1
|
||||
_, err = tx.Exec(ctx, `
|
||||
UPDATE account_usage
|
||||
SET task_balance = $2, updated_at = NOW()
|
||||
WHERE account_id = $1
|
||||
`, userID, newBalance)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Create task record
|
||||
taskID := uuid.New()
|
||||
_, err = tx.Exec(ctx, `
|
||||
INSERT INTO tasks (id, account_id, task_type, consumed, created_at)
|
||||
VALUES ($1, $2, $3, true, NOW())
|
||||
`, taskID, userID, taskType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Commit transaction
|
||||
if err = tx.Commit(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &models.ConsumeTaskResponse{
|
||||
Success: true,
|
||||
TaskID: taskID.String(),
|
||||
TasksRemaining: newBalance,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetTaskUsageInfo returns formatted task usage info for display
|
||||
func (s *TaskService) GetTaskUsageInfo(ctx context.Context, userID uuid.UUID) (*models.TaskUsageInfo, error) {
|
||||
usage, err := s.GetAccountUsage(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check for Fair Use mode (Premium)
|
||||
plan, _ := s.subService.GetPlanByID(ctx, string(usage.PlanID))
|
||||
if plan != nil && plan.Features.FairUseMode {
|
||||
return &models.TaskUsageInfo{
|
||||
TasksAvailable: usage.TaskBalance,
|
||||
MaxTasks: usage.MaxTaskBalance,
|
||||
InfoText: "Unbegrenzte Aufgaben (Fair Use)",
|
||||
TooltipText: "Im Premium-Tarif gibt es keine praktische Begrenzung.",
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &models.TaskUsageInfo{
|
||||
TasksAvailable: usage.TaskBalance,
|
||||
MaxTasks: usage.MaxTaskBalance,
|
||||
InfoText: fmt.Sprintf("Aufgaben verfuegbar: %d von max. %d", usage.TaskBalance, usage.MaxTaskBalance),
|
||||
TooltipText: "Aufgaben koennen sich bis zu 5 Monate ansammeln.",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdatePlanForUser updates the plan and adjusts allowances
|
||||
func (s *TaskService) UpdatePlanForUser(ctx context.Context, userID uuid.UUID, newPlanID models.PlanID) error {
|
||||
plan, err := s.subService.GetPlanByID(ctx, string(newPlanID))
|
||||
if err != nil || plan == nil {
|
||||
return fmt.Errorf("plan not found: %s", newPlanID)
|
||||
}
|
||||
|
||||
// Update account usage with new plan limits
|
||||
query := `
|
||||
UPDATE account_usage
|
||||
SET plan = $2,
|
||||
monthly_task_allowance = $3,
|
||||
max_task_balance = $4,
|
||||
updated_at = NOW()
|
||||
WHERE account_id = $1
|
||||
`
|
||||
|
||||
_, err = s.db.Pool.Exec(ctx, query,
|
||||
userID, newPlanID, plan.Features.MonthlyTaskAllowance, plan.Features.MaxTaskBalance)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetTaskHistory returns task history for a user
|
||||
func (s *TaskService) GetTaskHistory(ctx context.Context, userID uuid.UUID, limit int) ([]models.Task, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT id, account_id, task_type, created_at, consumed
|
||||
FROM tasks
|
||||
WHERE account_id = $1
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $2
|
||||
`
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, query, userID, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tasks []models.Task
|
||||
for rows.Next() {
|
||||
var task models.Task
|
||||
err := rows.Scan(&task.ID, &task.AccountID, &task.TaskType, &task.CreatedAt, &task.Consumed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tasks = append(tasks, task)
|
||||
}
|
||||
|
||||
return tasks, nil
|
||||
}
|
||||
@@ -0,0 +1,397 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMonthsBetween(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
start time.Time
|
||||
end time.Time
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "Same day",
|
||||
start: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
|
||||
end: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "Less than one month",
|
||||
start: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
|
||||
end: time.Date(2025, 2, 10, 0, 0, 0, 0, time.UTC),
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "Exactly one month",
|
||||
start: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
|
||||
end: time.Date(2025, 2, 15, 0, 0, 0, 0, time.UTC),
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "One month and one day",
|
||||
start: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
|
||||
end: time.Date(2025, 2, 16, 0, 0, 0, 0, time.UTC),
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "Two months",
|
||||
start: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
|
||||
end: time.Date(2025, 3, 15, 0, 0, 0, 0, time.UTC),
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
name: "Five months exactly",
|
||||
start: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
end: time.Date(2025, 6, 1, 0, 0, 0, 0, time.UTC),
|
||||
expected: 5,
|
||||
},
|
||||
{
|
||||
name: "Year boundary",
|
||||
start: time.Date(2024, 11, 15, 0, 0, 0, 0, time.UTC),
|
||||
end: time.Date(2025, 2, 15, 0, 0, 0, 0, time.UTC),
|
||||
expected: 3,
|
||||
},
|
||||
{
|
||||
name: "Leap year February to March",
|
||||
start: time.Date(2024, 2, 29, 0, 0, 0, 0, time.UTC),
|
||||
end: time.Date(2024, 3, 29, 0, 0, 0, 0, time.UTC),
|
||||
expected: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := monthsBetween(tt.start, tt.end)
|
||||
if result != tt.expected {
|
||||
t.Errorf("monthsBetween(%v, %v) = %d, expected %d",
|
||||
tt.start.Format("2006-01-02"), tt.end.Format("2006-01-02"),
|
||||
result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCarryoverLogic(t *testing.T) {
|
||||
// Test the carryover calculation logic
|
||||
tests := []struct {
|
||||
name string
|
||||
currentBalance int
|
||||
monthlyAllowance int
|
||||
maxBalance int
|
||||
monthsSinceRenewal int
|
||||
expectedNewBalance int
|
||||
}{
|
||||
{
|
||||
name: "Normal renewal - add allowance",
|
||||
currentBalance: 50,
|
||||
monthlyAllowance: 30,
|
||||
maxBalance: 150,
|
||||
monthsSinceRenewal: 1,
|
||||
expectedNewBalance: 80,
|
||||
},
|
||||
{
|
||||
name: "Two months missed",
|
||||
currentBalance: 50,
|
||||
monthlyAllowance: 30,
|
||||
maxBalance: 150,
|
||||
monthsSinceRenewal: 2,
|
||||
expectedNewBalance: 110,
|
||||
},
|
||||
{
|
||||
name: "Cap at max balance",
|
||||
currentBalance: 140,
|
||||
monthlyAllowance: 30,
|
||||
maxBalance: 150,
|
||||
monthsSinceRenewal: 1,
|
||||
expectedNewBalance: 150,
|
||||
},
|
||||
{
|
||||
name: "Already at max - no change",
|
||||
currentBalance: 150,
|
||||
monthlyAllowance: 30,
|
||||
maxBalance: 150,
|
||||
monthsSinceRenewal: 1,
|
||||
expectedNewBalance: 150,
|
||||
},
|
||||
{
|
||||
name: "Multiple months - cap applies",
|
||||
currentBalance: 100,
|
||||
monthlyAllowance: 30,
|
||||
maxBalance: 150,
|
||||
monthsSinceRenewal: 5,
|
||||
expectedNewBalance: 150,
|
||||
},
|
||||
{
|
||||
name: "Empty balance - add one month",
|
||||
currentBalance: 0,
|
||||
monthlyAllowance: 30,
|
||||
maxBalance: 150,
|
||||
monthsSinceRenewal: 1,
|
||||
expectedNewBalance: 30,
|
||||
},
|
||||
{
|
||||
name: "Empty balance - add five months",
|
||||
currentBalance: 0,
|
||||
monthlyAllowance: 30,
|
||||
maxBalance: 150,
|
||||
monthsSinceRenewal: 5,
|
||||
expectedNewBalance: 150,
|
||||
},
|
||||
{
|
||||
name: "Standard plan - normal case",
|
||||
currentBalance: 200,
|
||||
monthlyAllowance: 100,
|
||||
maxBalance: 500,
|
||||
monthsSinceRenewal: 1,
|
||||
expectedNewBalance: 300,
|
||||
},
|
||||
{
|
||||
name: "Premium plan - Fair Use",
|
||||
currentBalance: 1000,
|
||||
monthlyAllowance: 1000,
|
||||
maxBalance: 5000,
|
||||
monthsSinceRenewal: 1,
|
||||
expectedNewBalance: 2000,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Simulate the carryover logic
|
||||
newBalance := tt.currentBalance
|
||||
for i := 0; i < tt.monthsSinceRenewal; i++ {
|
||||
newBalance += tt.monthlyAllowance
|
||||
if newBalance > tt.maxBalance {
|
||||
newBalance = tt.maxBalance
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if newBalance != tt.expectedNewBalance {
|
||||
t.Errorf("Carryover for balance=%d, allowance=%d, max=%d, months=%d = %d, expected %d",
|
||||
tt.currentBalance, tt.monthlyAllowance, tt.maxBalance, tt.monthsSinceRenewal,
|
||||
newBalance, tt.expectedNewBalance)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskBalanceAfterConsumption(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
currentBalance int
|
||||
tasksToConsume int
|
||||
expectedBalance int
|
||||
shouldBeAllowed bool
|
||||
}{
|
||||
{
|
||||
name: "Normal consumption",
|
||||
currentBalance: 50,
|
||||
tasksToConsume: 1,
|
||||
expectedBalance: 49,
|
||||
shouldBeAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "Last task",
|
||||
currentBalance: 1,
|
||||
tasksToConsume: 1,
|
||||
expectedBalance: 0,
|
||||
shouldBeAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "Empty balance - not allowed",
|
||||
currentBalance: 0,
|
||||
tasksToConsume: 1,
|
||||
expectedBalance: 0,
|
||||
shouldBeAllowed: false,
|
||||
},
|
||||
{
|
||||
name: "Multiple tasks",
|
||||
currentBalance: 50,
|
||||
tasksToConsume: 5,
|
||||
expectedBalance: 45,
|
||||
shouldBeAllowed: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test if allowed
|
||||
allowed := tt.currentBalance > 0
|
||||
if allowed != tt.shouldBeAllowed {
|
||||
t.Errorf("Task allowed with balance=%d: got %v, expected %v",
|
||||
tt.currentBalance, allowed, tt.shouldBeAllowed)
|
||||
}
|
||||
|
||||
// Test balance calculation
|
||||
if allowed {
|
||||
newBalance := tt.currentBalance - tt.tasksToConsume
|
||||
if newBalance != tt.expectedBalance {
|
||||
t.Errorf("Balance after consuming %d tasks from %d: got %d, expected %d",
|
||||
tt.tasksToConsume, tt.currentBalance, newBalance, tt.expectedBalance)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskServiceErrors(t *testing.T) {
|
||||
// Test error constants
|
||||
if ErrTaskLimitReached == nil {
|
||||
t.Error("ErrTaskLimitReached should not be nil")
|
||||
}
|
||||
if ErrTaskLimitReached.Error() != "TASK_LIMIT_REACHED" {
|
||||
t.Errorf("ErrTaskLimitReached should be 'TASK_LIMIT_REACHED', got '%s'", ErrTaskLimitReached.Error())
|
||||
}
|
||||
|
||||
if ErrNoSubscription == nil {
|
||||
t.Error("ErrNoSubscription should not be nil")
|
||||
}
|
||||
if ErrNoSubscription.Error() != "NO_SUBSCRIPTION" {
|
||||
t.Errorf("ErrNoSubscription should be 'NO_SUBSCRIPTION', got '%s'", ErrNoSubscription.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenewalDateCalculation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
lastRenewal time.Time
|
||||
monthsToAdd int
|
||||
expectedRenewal time.Time
|
||||
}{
|
||||
{
|
||||
name: "Add one month",
|
||||
lastRenewal: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
|
||||
monthsToAdd: 1,
|
||||
expectedRenewal: time.Date(2025, 2, 15, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "Add three months",
|
||||
lastRenewal: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC),
|
||||
monthsToAdd: 3,
|
||||
expectedRenewal: time.Date(2025, 4, 15, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "Year boundary",
|
||||
lastRenewal: time.Date(2024, 11, 15, 0, 0, 0, 0, time.UTC),
|
||||
monthsToAdd: 3,
|
||||
expectedRenewal: time.Date(2025, 2, 15, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "End of month adjustment",
|
||||
lastRenewal: time.Date(2025, 1, 31, 0, 0, 0, 0, time.UTC),
|
||||
monthsToAdd: 1,
|
||||
// Go's AddDate handles this - February doesn't have 31 days
|
||||
expectedRenewal: time.Date(2025, 3, 3, 0, 0, 0, 0, time.UTC), // Feb 31 -> March 3
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.lastRenewal.AddDate(0, tt.monthsToAdd, 0)
|
||||
if !result.Equal(tt.expectedRenewal) {
|
||||
t.Errorf("AddDate(%v, %d months) = %v, expected %v",
|
||||
tt.lastRenewal.Format("2006-01-02"), tt.monthsToAdd,
|
||||
result.Format("2006-01-02"), tt.expectedRenewal.Format("2006-01-02"))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFairUseModeLogic(t *testing.T) {
|
||||
// Test that Fair Use mode always allows tasks regardless of balance
|
||||
tests := []struct {
|
||||
name string
|
||||
fairUseMode bool
|
||||
balance int
|
||||
shouldAllow bool
|
||||
}{
|
||||
{
|
||||
name: "Fair Use - zero balance still allowed",
|
||||
fairUseMode: true,
|
||||
balance: 0,
|
||||
shouldAllow: true,
|
||||
},
|
||||
{
|
||||
name: "Fair Use - normal balance allowed",
|
||||
fairUseMode: true,
|
||||
balance: 1000,
|
||||
shouldAllow: true,
|
||||
},
|
||||
{
|
||||
name: "Not Fair Use - zero balance not allowed",
|
||||
fairUseMode: false,
|
||||
balance: 0,
|
||||
shouldAllow: false,
|
||||
},
|
||||
{
|
||||
name: "Not Fair Use - positive balance allowed",
|
||||
fairUseMode: false,
|
||||
balance: 50,
|
||||
shouldAllow: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Simulate the check logic
|
||||
var allowed bool
|
||||
if tt.fairUseMode {
|
||||
allowed = true // Fair Use always allows
|
||||
} else {
|
||||
allowed = tt.balance > 0
|
||||
}
|
||||
|
||||
if allowed != tt.shouldAllow {
|
||||
t.Errorf("FairUseMode=%v, balance=%d: allowed=%v, expected=%v",
|
||||
tt.fairUseMode, tt.balance, allowed, tt.shouldAllow)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBalanceDecrementLogic(t *testing.T) {
|
||||
// Test that Fair Use mode doesn't decrement balance
|
||||
tests := []struct {
|
||||
name string
|
||||
fairUseMode bool
|
||||
initialBalance int
|
||||
expectedAfter int
|
||||
}{
|
||||
{
|
||||
name: "Normal plan - decrement",
|
||||
fairUseMode: false,
|
||||
initialBalance: 50,
|
||||
expectedAfter: 49,
|
||||
},
|
||||
{
|
||||
name: "Fair Use - no decrement",
|
||||
fairUseMode: true,
|
||||
initialBalance: 1000,
|
||||
expectedAfter: 1000,
|
||||
},
|
||||
{
|
||||
name: "Normal plan - last task",
|
||||
fairUseMode: false,
|
||||
initialBalance: 1,
|
||||
expectedAfter: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
newBalance := tt.initialBalance
|
||||
if !tt.fairUseMode {
|
||||
newBalance = tt.initialBalance - 1
|
||||
}
|
||||
|
||||
if newBalance != tt.expectedAfter {
|
||||
t.Errorf("FairUseMode=%v, initial=%d: got %d, expected %d",
|
||||
tt.fairUseMode, tt.initialBalance, newBalance, tt.expectedAfter)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,194 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/billing-service/internal/database"
|
||||
"github.com/breakpilot/billing-service/internal/models"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// UsageService handles usage tracking operations
|
||||
type UsageService struct {
|
||||
db *database.DB
|
||||
entitlementService *EntitlementService
|
||||
}
|
||||
|
||||
// NewUsageService creates a new UsageService
|
||||
func NewUsageService(db *database.DB, entitlementService *EntitlementService) *UsageService {
|
||||
return &UsageService{
|
||||
db: db,
|
||||
entitlementService: entitlementService,
|
||||
}
|
||||
}
|
||||
|
||||
// TrackUsage tracks usage for a user
|
||||
func (s *UsageService) TrackUsage(ctx context.Context, userIDStr, usageType string, quantity int) error {
|
||||
userID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid user ID: %w", err)
|
||||
}
|
||||
|
||||
// Get current period start (beginning of current month)
|
||||
now := time.Now()
|
||||
periodStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
// Upsert usage summary
|
||||
query := `
|
||||
INSERT INTO usage_summary (user_id, usage_type, period_start, total_count)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (user_id, usage_type, period_start) DO UPDATE SET
|
||||
total_count = usage_summary.total_count + EXCLUDED.total_count,
|
||||
updated_at = NOW()
|
||||
`
|
||||
|
||||
_, err = s.db.Pool.Exec(ctx, query, userID, usageType, periodStart, quantity)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to track usage: %w", err)
|
||||
}
|
||||
|
||||
// Also update entitlements cache
|
||||
return s.entitlementService.IncrementUsage(ctx, userID, usageType, quantity)
|
||||
}
|
||||
|
||||
// GetUsageSummary returns usage summary for a user
|
||||
func (s *UsageService) GetUsageSummary(ctx context.Context, userID uuid.UUID) (*models.UsageInfo, error) {
|
||||
// Get entitlements (which include current usage)
|
||||
ent, err := s.entitlementService.getUserEntitlements(ctx, userID)
|
||||
if err != nil || ent == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Calculate percentages
|
||||
aiPercent := 0.0
|
||||
if ent.AIRequestsLimit > 0 {
|
||||
aiPercent = float64(ent.AIRequestsUsed) / float64(ent.AIRequestsLimit) * 100
|
||||
}
|
||||
|
||||
docPercent := 0.0
|
||||
if ent.DocumentsLimit > 0 {
|
||||
docPercent = float64(ent.DocumentsUsed) / float64(ent.DocumentsLimit) * 100
|
||||
}
|
||||
|
||||
// Get period dates
|
||||
now := time.Now()
|
||||
periodStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
periodEnd := periodStart.AddDate(0, 1, 0).Add(-time.Second)
|
||||
|
||||
return &models.UsageInfo{
|
||||
AIRequestsUsed: ent.AIRequestsUsed,
|
||||
AIRequestsLimit: ent.AIRequestsLimit,
|
||||
AIRequestsPercent: aiPercent,
|
||||
DocumentsUsed: ent.DocumentsUsed,
|
||||
DocumentsLimit: ent.DocumentsLimit,
|
||||
DocumentsPercent: docPercent,
|
||||
PeriodStart: periodStart.Format("2006-01-02"),
|
||||
PeriodEnd: periodEnd.Format("2006-01-02"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CheckUsageAllowed checks if a user is allowed to perform a usage action
|
||||
func (s *UsageService) CheckUsageAllowed(ctx context.Context, userIDStr, usageType string) (*models.CheckUsageResponse, error) {
|
||||
userID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
return &models.CheckUsageResponse{
|
||||
Allowed: false,
|
||||
Message: "Invalid user ID",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Get entitlements
|
||||
ent, err := s.entitlementService.getUserEntitlements(ctx, userID)
|
||||
if err != nil {
|
||||
return &models.CheckUsageResponse{
|
||||
Allowed: false,
|
||||
Message: "Failed to get entitlements",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if ent == nil {
|
||||
return &models.CheckUsageResponse{
|
||||
Allowed: false,
|
||||
Message: "No subscription found",
|
||||
}, nil
|
||||
}
|
||||
|
||||
var currentUsage, limit int
|
||||
switch usageType {
|
||||
case "ai_request":
|
||||
currentUsage = ent.AIRequestsUsed
|
||||
limit = ent.AIRequestsLimit
|
||||
case "document_created":
|
||||
currentUsage = ent.DocumentsUsed
|
||||
limit = ent.DocumentsLimit
|
||||
default:
|
||||
return &models.CheckUsageResponse{
|
||||
Allowed: true,
|
||||
Message: "Unknown usage type - allowing",
|
||||
}, nil
|
||||
}
|
||||
|
||||
remaining := limit - currentUsage
|
||||
allowed := remaining > 0
|
||||
|
||||
response := &models.CheckUsageResponse{
|
||||
Allowed: allowed,
|
||||
CurrentUsage: currentUsage,
|
||||
Limit: limit,
|
||||
Remaining: remaining,
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
response.Message = fmt.Sprintf("Usage limit reached for %s (%d/%d)", usageType, currentUsage, limit)
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// GetUsageHistory returns usage history for a user
|
||||
func (s *UsageService) GetUsageHistory(ctx context.Context, userID uuid.UUID, months int) ([]models.UsageSummary, error) {
|
||||
query := `
|
||||
SELECT id, user_id, usage_type, period_start, total_count, created_at, updated_at
|
||||
FROM usage_summary
|
||||
WHERE user_id = $1
|
||||
AND period_start >= $2
|
||||
ORDER BY period_start DESC, usage_type
|
||||
`
|
||||
|
||||
// Calculate start date
|
||||
startDate := time.Now().AddDate(0, -months, 0)
|
||||
startDate = time.Date(startDate.Year(), startDate.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, query, userID, startDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var summaries []models.UsageSummary
|
||||
for rows.Next() {
|
||||
var summary models.UsageSummary
|
||||
err := rows.Scan(
|
||||
&summary.ID, &summary.UserID, &summary.UsageType,
|
||||
&summary.PeriodStart, &summary.TotalCount,
|
||||
&summary.CreatedAt, &summary.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
summaries = append(summaries, summary)
|
||||
}
|
||||
|
||||
return summaries, nil
|
||||
}
|
||||
|
||||
// ResetPeriodUsage resets usage for a new billing period
|
||||
func (s *UsageService) ResetPeriodUsage(ctx context.Context, userID uuid.UUID) error {
|
||||
now := time.Now()
|
||||
newPeriodStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
newPeriodEnd := newPeriodStart.AddDate(0, 1, 0).Add(-time.Second)
|
||||
|
||||
return s.entitlementService.ResetUsageCounters(ctx, userID, &newPeriodStart, &newPeriodEnd)
|
||||
}
|
||||
@@ -0,0 +1,171 @@
|
||||
# BreakPilot BPMN Prozesse
|
||||
|
||||
Dieses Verzeichnis enthaelt die BPMN 2.0 Prozessdefinitionen fuer BreakPilot.
|
||||
|
||||
## Prozess-Uebersicht
|
||||
|
||||
| Datei | Prozess | Beschreibung | Status |
|
||||
|-------|---------|--------------|--------|
|
||||
| `classroom-lesson.bpmn` | Unterrichtsstunde | Phasenbasierte Unterrichtssteuerung | Entwurf |
|
||||
| `consent-document.bpmn` | Consent-Dokument | DSB-Approval, Publishing, Monitoring | Entwurf |
|
||||
| `klausur-korrektur.bpmn` | Klausurkorrektur | OCR, AI-Grading, Export | Entwurf |
|
||||
| `dsr-request.bpmn` | DSR/GDPR | Betroffenenanfragen (Art. 15-20) | Entwurf |
|
||||
|
||||
## Verwendung
|
||||
|
||||
### Im BPMN Editor laden
|
||||
|
||||
1. Navigiere zu http://localhost:3000/admin/workflow oder http://localhost:8000/app (Workflow)
|
||||
2. Klicke "Oeffnen" und waehle eine .bpmn Datei
|
||||
3. Bearbeite den Prozess im Editor
|
||||
4. Speichere und deploye zu Camunda
|
||||
|
||||
### In Camunda deployen
|
||||
|
||||
```bash
|
||||
# Camunda starten (falls noch nicht aktiv)
|
||||
docker compose --profile bpmn up -d camunda
|
||||
|
||||
# Prozess deployen via API
|
||||
curl -X POST http://localhost:8000/api/bpmn/deployment/create \
|
||||
-F "deployment-name=breakpilot-processes" \
|
||||
-F "data=@classroom-lesson.bpmn"
|
||||
```
|
||||
|
||||
### Prozess starten
|
||||
|
||||
```bash
|
||||
# Unterrichtsstunde starten
|
||||
curl -X POST http://localhost:8000/api/bpmn/process-definition/ClassroomLessonProcess/start \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"variables": {
|
||||
"teacherId": {"value": "teacher-123"},
|
||||
"classId": {"value": "class-7a"},
|
||||
"subject": {"value": "Mathematik"}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
## Prozess-Details
|
||||
|
||||
### 1. Classroom Lesson (classroom-lesson.bpmn)
|
||||
|
||||
**Phasen:**
|
||||
- Einstieg (Motivation, Problemstellung)
|
||||
- Erarbeitung I (Einzelarbeit, Partnerarbeit, Gruppenarbeit)
|
||||
- Erarbeitung II (optional)
|
||||
- Sicherung (Tafel, Digital, Schueler-Praesentation)
|
||||
- Transfer (Anwendungsaufgaben)
|
||||
- Reflexion & Abschluss (Hausaufgaben, Notizen)
|
||||
|
||||
**Service Tasks:**
|
||||
- `contentSuggestionDelegate` - Content-Vorschlaege basierend auf Phase
|
||||
- `lessonProtocolDelegate` - Automatisches Stundenprotokoll
|
||||
|
||||
**Timer Events:**
|
||||
- Phasen-Timer mit Warnungen
|
||||
|
||||
---
|
||||
|
||||
### 2. Consent Document (consent-document.bpmn)
|
||||
|
||||
**Workflow:**
|
||||
1. Dokument bearbeiten (Autor)
|
||||
2. DSB-Pruefung (Vier-Augen-Prinzip)
|
||||
3. Bei Ablehnung: Zurueck an Autor
|
||||
4. Bei Genehmigung: Veroeffentlichen
|
||||
5. Benutzer benachrichtigen
|
||||
6. Consent sammeln mit Deadline-Timer
|
||||
7. Monitoring-Subprocess fuer jaehrliche Erneuerung
|
||||
8. Archivierung bei neuer Version
|
||||
|
||||
**Service Tasks:**
|
||||
- `publishConsentDocumentDelegate`
|
||||
- `notifyUsersDelegate`
|
||||
- `sendConsentReminderDelegate`
|
||||
- `checkConsentStatusDelegate`
|
||||
- `triggerRenewalDelegate`
|
||||
- `archiveDocumentDelegate`
|
||||
|
||||
---
|
||||
|
||||
### 3. Klausur Korrektur (klausur-korrektur.bpmn)
|
||||
|
||||
**Workflow:**
|
||||
1. OCR-Verarbeitung der hochgeladenen Klausuren
|
||||
2. Qualitaets-Check (Confidence >= 85%)
|
||||
3. Bei schlechter Qualitaet: Manuelle Nachbearbeitung
|
||||
4. Erwartungshorizont definieren
|
||||
5. AI-Bewertung mit Claude
|
||||
6. Lehrer-Review mit Anpassungsmoeglichkeit
|
||||
7. Noten berechnen (15-Punkte-Skala)
|
||||
8. Notenbuch aktualisieren
|
||||
9. Export (PDF, Excel)
|
||||
10. Optional: Eltern benachrichtigen
|
||||
11. Archivierung
|
||||
|
||||
**Service Tasks:**
|
||||
- `ocrProcessingDelegate`
|
||||
- `ocrQualityCheckDelegate`
|
||||
- `aiGradingDelegate`
|
||||
- `calculateGradesDelegate`
|
||||
- `updateGradebookDelegate`
|
||||
- `generateExportDelegate`
|
||||
- `notifyParentsDelegate`
|
||||
- `archiveExamDelegate`
|
||||
- `deadlineWarningDelegate`
|
||||
|
||||
---
|
||||
|
||||
### 4. DSR Request (dsr-request.bpmn)
|
||||
|
||||
**GDPR Artikel:**
|
||||
- Art. 15: Recht auf Auskunft (Access)
|
||||
- Art. 16: Recht auf Berichtigung (Rectification)
|
||||
- Art. 17: Recht auf Loeschung (Deletion)
|
||||
- Art. 20: Recht auf Datenuebertragbarkeit (Portability)
|
||||
|
||||
**Workflow:**
|
||||
1. Anfrage validieren
|
||||
2. Bei ungueltig: Ablehnen
|
||||
3. Je nach Typ:
|
||||
- Access: Daten sammeln → Anonymisieren → Review → Export
|
||||
- Deletion: Identifizieren → Genehmigen → Loeschen → Verifizieren
|
||||
- Portability: Sammeln → JSON formatieren
|
||||
- Rectification: Pruefen → Anwenden
|
||||
4. Betroffenen benachrichtigen
|
||||
5. Audit Log erstellen
|
||||
|
||||
**30-Tage Frist:**
|
||||
- Timer-Event nach 25 Tagen fuer Eskalation an DSB
|
||||
|
||||
**Service Tasks:**
|
||||
- `validateDSRDelegate`
|
||||
- `rejectDSRDelegate`
|
||||
- `collectUserDataDelegate`
|
||||
- `anonymizeDataDelegate`
|
||||
- `prepareExportDelegate`
|
||||
- `identifyUserDataDelegate`
|
||||
- `executeDataDeletionDelegate`
|
||||
- `verifyDeletionDelegate`
|
||||
- `collectPortableDataDelegate`
|
||||
- `formatPortableDataDelegate`
|
||||
- `applyRectificationDelegate`
|
||||
- `notifyDataSubjectDelegate`
|
||||
- `createAuditLogDelegate`
|
||||
- `escalateToDSBDelegate`
|
||||
|
||||
## Naechste Schritte
|
||||
|
||||
1. **Delegates implementieren**: Java/Python Service Tasks
|
||||
2. **Camunda Connect**: REST-Aufrufe zu Backend-APIs
|
||||
3. **User Task Forms**: Camunda Forms oder Custom UI
|
||||
4. **Timer konfigurieren**: Realistische Dauern setzen
|
||||
5. **Testing**: Prozesse mit Testdaten durchlaufen
|
||||
|
||||
## Referenzen
|
||||
|
||||
- [Camunda 7 Docs](https://docs.camunda.org/manual/7.21/)
|
||||
- [BPMN 2.0 Spec](https://www.omg.org/spec/BPMN/2.0/)
|
||||
- [bpmn-js](https://bpmn.io/toolkit/bpmn-js/)
|
||||
@@ -0,0 +1,181 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<bpmn:definitions xmlns:bpmn="http://www.omg.org/spec/BPMN/20100524/MODEL"
|
||||
xmlns:bpmndi="http://www.omg.org/spec/BPMN/20100524/DI"
|
||||
xmlns:dc="http://www.omg.org/spec/DD/20100524/DC"
|
||||
xmlns:di="http://www.omg.org/spec/DD/20100524/DI"
|
||||
xmlns:camunda="http://camunda.org/schema/1.0/bpmn"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
id="Definitions_Classroom"
|
||||
targetNamespace="http://breakpilot.de/bpmn/classroom">
|
||||
|
||||
<bpmn:process id="ClassroomLessonProcess" name="Unterrichtsstunde" isExecutable="true">
|
||||
|
||||
<!-- Start Event -->
|
||||
<bpmn:startEvent id="start" name="Stunde beginnen">
|
||||
<bpmn:outgoing>flow_to_einstieg</bpmn:outgoing>
|
||||
</bpmn:startEvent>
|
||||
|
||||
<!-- Phase 1: Einstieg -->
|
||||
<bpmn:userTask id="phase_einstieg" name="Einstiegsphase" camunda:assignee="${teacherId}">
|
||||
<bpmn:extensionElements>
|
||||
<camunda:formData>
|
||||
<camunda:formField id="duration" label="Dauer (Minuten)" type="long" defaultValue="10" />
|
||||
<camunda:formField id="activity" label="Aktivitaet" type="string" />
|
||||
</camunda:formData>
|
||||
</bpmn:extensionElements>
|
||||
<bpmn:incoming>flow_to_einstieg</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_erarbeitung1</bpmn:outgoing>
|
||||
</bpmn:userTask>
|
||||
|
||||
<!-- Service Task: Content-Vorschlaege Einstieg -->
|
||||
<bpmn:serviceTask id="suggest_einstieg" name="Content-Vorschlaege" camunda:delegateExpression="${contentSuggestionDelegate}">
|
||||
<bpmn:incoming>flow_suggest_einstieg</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_from_suggest_einstieg</bpmn:outgoing>
|
||||
</bpmn:serviceTask>
|
||||
|
||||
<!-- Phase 2: Erarbeitung I -->
|
||||
<bpmn:userTask id="phase_erarbeitung1" name="Erarbeitung I" camunda:assignee="${teacherId}">
|
||||
<bpmn:extensionElements>
|
||||
<camunda:formData>
|
||||
<camunda:formField id="duration" label="Dauer (Minuten)" type="long" defaultValue="15" />
|
||||
<camunda:formField id="sozialform" label="Sozialform" type="enum">
|
||||
<camunda:value id="einzelarbeit" name="Einzelarbeit" />
|
||||
<camunda:value id="partnerarbeit" name="Partnerarbeit" />
|
||||
<camunda:value id="gruppenarbeit" name="Gruppenarbeit" />
|
||||
</camunda:formField>
|
||||
<camunda:formField id="content" label="Lerneinheit" type="string" />
|
||||
</camunda:formData>
|
||||
</bpmn:extensionElements>
|
||||
<bpmn:incoming>flow_to_erarbeitung1</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_erarbeitung_gateway</bpmn:outgoing>
|
||||
</bpmn:userTask>
|
||||
|
||||
<!-- Gateway: Weitere Erarbeitung? -->
|
||||
<bpmn:exclusiveGateway id="erarbeitung_gateway" name="Weitere Erarbeitung?">
|
||||
<bpmn:incoming>flow_to_erarbeitung_gateway</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_erarbeitung2</bpmn:outgoing>
|
||||
<bpmn:outgoing>flow_to_sicherung</bpmn:outgoing>
|
||||
</bpmn:exclusiveGateway>
|
||||
|
||||
<!-- Phase 2b: Erarbeitung II (optional) -->
|
||||
<bpmn:userTask id="phase_erarbeitung2" name="Erarbeitung II" camunda:assignee="${teacherId}">
|
||||
<bpmn:incoming>flow_to_erarbeitung2</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_from_erarbeitung2</bpmn:outgoing>
|
||||
</bpmn:userTask>
|
||||
|
||||
<!-- Phase 3: Sicherung -->
|
||||
<bpmn:userTask id="phase_sicherung" name="Sicherungsphase" camunda:assignee="${teacherId}">
|
||||
<bpmn:extensionElements>
|
||||
<camunda:formData>
|
||||
<camunda:formField id="duration" label="Dauer (Minuten)" type="long" defaultValue="10" />
|
||||
<camunda:formField id="method" label="Methode" type="enum">
|
||||
<camunda:value id="tafel" name="Tafelanschrieb" />
|
||||
<camunda:value id="digital" name="Digitale Zusammenfassung" />
|
||||
<camunda:value id="schueler" name="Schueler-Praesentation" />
|
||||
</camunda:formField>
|
||||
</camunda:formData>
|
||||
</bpmn:extensionElements>
|
||||
<bpmn:incoming>flow_to_sicherung</bpmn:incoming>
|
||||
<bpmn:incoming>flow_from_erarbeitung2</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_transfer</bpmn:outgoing>
|
||||
</bpmn:userTask>
|
||||
|
||||
<!-- Phase 4: Transfer -->
|
||||
<bpmn:userTask id="phase_transfer" name="Transferphase" camunda:assignee="${teacherId}">
|
||||
<bpmn:extensionElements>
|
||||
<camunda:formData>
|
||||
<camunda:formField id="duration" label="Dauer (Minuten)" type="long" defaultValue="8" />
|
||||
<camunda:formField id="aufgabe" label="Transfer-Aufgabe" type="string" />
|
||||
</camunda:formData>
|
||||
</bpmn:extensionElements>
|
||||
<bpmn:incoming>flow_to_transfer</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_reflexion</bpmn:outgoing>
|
||||
</bpmn:userTask>
|
||||
|
||||
<!-- Phase 5: Reflexion -->
|
||||
<bpmn:userTask id="phase_reflexion" name="Reflexion & Abschluss" camunda:assignee="${teacherId}">
|
||||
<bpmn:extensionElements>
|
||||
<camunda:formData>
|
||||
<camunda:formField id="duration" label="Dauer (Minuten)" type="long" defaultValue="5" />
|
||||
<camunda:formField id="hausaufgabe" label="Hausaufgabe" type="string" />
|
||||
<camunda:formField id="notizen" label="Stundennotizen" type="string" />
|
||||
</camunda:formData>
|
||||
</bpmn:extensionElements>
|
||||
<bpmn:incoming>flow_to_reflexion</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_protokoll</bpmn:outgoing>
|
||||
</bpmn:userTask>
|
||||
|
||||
<!-- Service Task: Stundenprotokoll -->
|
||||
<bpmn:serviceTask id="create_protokoll" name="Stundenprotokoll erstellen" camunda:delegateExpression="${lessonProtocolDelegate}">
|
||||
<bpmn:incoming>flow_to_protokoll</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_end</bpmn:outgoing>
|
||||
</bpmn:serviceTask>
|
||||
|
||||
<!-- End Event -->
|
||||
<bpmn:endEvent id="end" name="Stunde beendet">
|
||||
<bpmn:incoming>flow_to_end</bpmn:incoming>
|
||||
</bpmn:endEvent>
|
||||
|
||||
<!-- Boundary Timer Events -->
|
||||
<bpmn:boundaryEvent id="timer_einstieg" attachedToRef="phase_einstieg" cancelActivity="false">
|
||||
<bpmn:timerEventDefinition>
|
||||
<bpmn:timeDuration>PT${einstiegDuration}M</bpmn:timeDuration>
|
||||
</bpmn:timerEventDefinition>
|
||||
<bpmn:outgoing>flow_timer_warning</bpmn:outgoing>
|
||||
</bpmn:boundaryEvent>
|
||||
|
||||
<!-- Sequence Flows -->
|
||||
<bpmn:sequenceFlow id="flow_to_einstieg" sourceRef="start" targetRef="phase_einstieg" />
|
||||
<bpmn:sequenceFlow id="flow_to_erarbeitung1" sourceRef="phase_einstieg" targetRef="phase_erarbeitung1" />
|
||||
<bpmn:sequenceFlow id="flow_to_erarbeitung_gateway" sourceRef="phase_erarbeitung1" targetRef="erarbeitung_gateway" />
|
||||
<bpmn:sequenceFlow id="flow_to_erarbeitung2" sourceRef="erarbeitung_gateway" targetRef="phase_erarbeitung2">
|
||||
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${needsMoreWork == true}</bpmn:conditionExpression>
|
||||
</bpmn:sequenceFlow>
|
||||
<bpmn:sequenceFlow id="flow_to_sicherung" sourceRef="erarbeitung_gateway" targetRef="phase_sicherung">
|
||||
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${needsMoreWork == false}</bpmn:conditionExpression>
|
||||
</bpmn:sequenceFlow>
|
||||
<bpmn:sequenceFlow id="flow_from_erarbeitung2" sourceRef="phase_erarbeitung2" targetRef="phase_sicherung" />
|
||||
<bpmn:sequenceFlow id="flow_to_transfer" sourceRef="phase_sicherung" targetRef="phase_transfer" />
|
||||
<bpmn:sequenceFlow id="flow_to_reflexion" sourceRef="phase_transfer" targetRef="phase_reflexion" />
|
||||
<bpmn:sequenceFlow id="flow_to_protokoll" sourceRef="phase_reflexion" targetRef="create_protokoll" />
|
||||
<bpmn:sequenceFlow id="flow_to_end" sourceRef="create_protokoll" targetRef="end" />
|
||||
|
||||
</bpmn:process>
|
||||
|
||||
<!-- BPMN Diagram -->
|
||||
<bpmndi:BPMNDiagram id="BPMNDiagram_1">
|
||||
<bpmndi:BPMNPlane id="BPMNPlane_1" bpmnElement="ClassroomLessonProcess">
|
||||
<bpmndi:BPMNShape id="start_di" bpmnElement="start">
|
||||
<dc:Bounds x="152" y="102" width="36" height="36" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="phase_einstieg_di" bpmnElement="phase_einstieg">
|
||||
<dc:Bounds x="240" y="80" width="100" height="80" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="phase_erarbeitung1_di" bpmnElement="phase_erarbeitung1">
|
||||
<dc:Bounds x="390" y="80" width="100" height="80" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="erarbeitung_gateway_di" bpmnElement="erarbeitung_gateway" isMarkerVisible="true">
|
||||
<dc:Bounds x="545" y="95" width="50" height="50" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="phase_erarbeitung2_di" bpmnElement="phase_erarbeitung2">
|
||||
<dc:Bounds x="520" y="200" width="100" height="80" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="phase_sicherung_di" bpmnElement="phase_sicherung">
|
||||
<dc:Bounds x="670" y="80" width="100" height="80" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="phase_transfer_di" bpmnElement="phase_transfer">
|
||||
<dc:Bounds x="820" y="80" width="100" height="80" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="phase_reflexion_di" bpmnElement="phase_reflexion">
|
||||
<dc:Bounds x="970" y="80" width="100" height="80" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="create_protokoll_di" bpmnElement="create_protokoll">
|
||||
<dc:Bounds x="1120" y="80" width="100" height="80" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="end_di" bpmnElement="end">
|
||||
<dc:Bounds x="1272" y="102" width="36" height="36" />
|
||||
</bpmndi:BPMNShape>
|
||||
</bpmndi:BPMNPlane>
|
||||
</bpmndi:BPMNDiagram>
|
||||
|
||||
</bpmn:definitions>
|
||||
@@ -0,0 +1,206 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<bpmn:definitions xmlns:bpmn="http://www.omg.org/spec/BPMN/20100524/MODEL"
|
||||
xmlns:bpmndi="http://www.omg.org/spec/BPMN/20100524/DI"
|
||||
xmlns:dc="http://www.omg.org/spec/DD/20100524/DC"
|
||||
xmlns:di="http://www.omg.org/spec/DD/20100524/DI"
|
||||
xmlns:camunda="http://camunda.org/schema/1.0/bpmn"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
id="Definitions_Consent"
|
||||
targetNamespace="http://breakpilot.de/bpmn/consent">
|
||||
|
||||
<bpmn:process id="ConsentDocumentProcess" name="Consent-Dokument Workflow" isExecutable="true">
|
||||
|
||||
<!-- Start Event -->
|
||||
<bpmn:startEvent id="start" name="Dokument erstellt">
|
||||
<bpmn:outgoing>flow_to_edit</bpmn:outgoing>
|
||||
</bpmn:startEvent>
|
||||
|
||||
<!-- User Task: Dokument bearbeiten -->
|
||||
<bpmn:userTask id="edit_document" name="Dokument bearbeiten" camunda:assignee="${authorId}">
|
||||
<bpmn:extensionElements>
|
||||
<camunda:formData>
|
||||
<camunda:formField id="title" label="Titel" type="string" />
|
||||
<camunda:formField id="type" label="Dokumenttyp" type="enum">
|
||||
<camunda:value id="terms" name="AGB" />
|
||||
<camunda:value id="privacy" name="Datenschutzerklaerung" />
|
||||
<camunda:value id="cookies" name="Cookie-Richtlinie" />
|
||||
<camunda:value id="consent_form" name="Einwilligungserklaerung" />
|
||||
</camunda:formField>
|
||||
<camunda:formField id="content" label="Inhalt" type="string" />
|
||||
<camunda:formField id="version" label="Version" type="string" />
|
||||
</camunda:formData>
|
||||
</bpmn:extensionElements>
|
||||
<bpmn:incoming>flow_to_edit</bpmn:incoming>
|
||||
<bpmn:incoming>flow_rejected_to_edit</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_review</bpmn:outgoing>
|
||||
</bpmn:userTask>
|
||||
|
||||
<!-- User Task: DSB Review -->
|
||||
<bpmn:userTask id="dsb_review" name="DSB Pruefung" camunda:candidateGroups="data_protection_officer">
|
||||
<bpmn:extensionElements>
|
||||
<camunda:formData>
|
||||
<camunda:formField id="approved" label="Genehmigt" type="boolean" />
|
||||
<camunda:formField id="comments" label="Kommentare" type="string" />
|
||||
<camunda:formField id="legalCheck" label="Rechtliche Pruefung OK" type="boolean" />
|
||||
</camunda:formData>
|
||||
</bpmn:extensionElements>
|
||||
<bpmn:incoming>flow_to_review</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_approval_gateway</bpmn:outgoing>
|
||||
</bpmn:userTask>
|
||||
|
||||
<!-- Gateway: Genehmigt? -->
|
||||
<bpmn:exclusiveGateway id="approval_gateway" name="Genehmigt?">
|
||||
<bpmn:incoming>flow_to_approval_gateway</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_approved</bpmn:outgoing>
|
||||
<bpmn:outgoing>flow_rejected</bpmn:outgoing>
|
||||
</bpmn:exclusiveGateway>
|
||||
|
||||
<!-- Service Task: Veroeffentlichen -->
|
||||
<bpmn:serviceTask id="publish_document" name="Dokument veroeffentlichen" camunda:delegateExpression="${publishConsentDocumentDelegate}">
|
||||
<bpmn:incoming>flow_approved</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_notify</bpmn:outgoing>
|
||||
</bpmn:serviceTask>
|
||||
|
||||
<!-- Service Task: Benutzer benachrichtigen -->
|
||||
<bpmn:serviceTask id="notify_users" name="Benutzer benachrichtigen" camunda:delegateExpression="${notifyUsersDelegate}">
|
||||
<bpmn:incoming>flow_to_notify</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_collect_consent</bpmn:outgoing>
|
||||
</bpmn:serviceTask>
|
||||
|
||||
<!-- User Task: Consent sammeln (wartet auf Benutzer-Zustimmungen) -->
|
||||
<bpmn:receiveTask id="collect_consent" name="Auf Zustimmungen warten">
|
||||
<bpmn:incoming>flow_to_collect_consent</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_check_deadline</bpmn:outgoing>
|
||||
</bpmn:receiveTask>
|
||||
|
||||
<!-- Boundary Timer: Deadline -->
|
||||
<bpmn:boundaryEvent id="consent_deadline" attachedToRef="collect_consent" cancelActivity="false">
|
||||
<bpmn:timerEventDefinition>
|
||||
<bpmn:timeDuration>P${consentDeadlineDays}D</bpmn:timeDuration>
|
||||
</bpmn:timerEventDefinition>
|
||||
<bpmn:outgoing>flow_to_reminder</bpmn:outgoing>
|
||||
</bpmn:boundaryEvent>
|
||||
|
||||
<!-- Service Task: Reminder senden -->
|
||||
<bpmn:serviceTask id="send_reminder" name="Reminder senden" camunda:delegateExpression="${sendConsentReminderDelegate}">
|
||||
<bpmn:incoming>flow_to_reminder</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_back_to_collect</bpmn:outgoing>
|
||||
</bpmn:serviceTask>
|
||||
|
||||
<!-- Service Task: Consent-Status pruefen -->
|
||||
<bpmn:serviceTask id="check_consent_status" name="Consent-Status pruefen" camunda:delegateExpression="${checkConsentStatusDelegate}">
|
||||
<bpmn:incoming>flow_to_check_deadline</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_active</bpmn:outgoing>
|
||||
</bpmn:serviceTask>
|
||||
|
||||
<!-- Intermediate Event: Dokument aktiv -->
|
||||
<bpmn:intermediateThrowEvent id="document_active" name="Dokument aktiv">
|
||||
<bpmn:incoming>flow_to_active</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_monitor</bpmn:outgoing>
|
||||
</bpmn:intermediateThrowEvent>
|
||||
|
||||
<!-- Sub-Process: Monitoring -->
|
||||
<bpmn:subProcess id="monitoring_subprocess" name="Consent Monitoring">
|
||||
<bpmn:incoming>flow_to_monitor</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_archive</bpmn:outgoing>
|
||||
|
||||
<bpmn:startEvent id="monitoring_start" />
|
||||
|
||||
<!-- Event-based Gateway: Warte auf Events -->
|
||||
<bpmn:eventBasedGateway id="event_gateway">
|
||||
<bpmn:incoming>flow_from_monitoring_start</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_renewal_timer</bpmn:outgoing>
|
||||
<bpmn:outgoing>flow_to_supersede_event</bpmn:outgoing>
|
||||
</bpmn:eventBasedGateway>
|
||||
|
||||
<!-- Timer: Jaehrliche Erneuerung -->
|
||||
<bpmn:intermediateCatchEvent id="renewal_timer" name="Erneuerungsdatum">
|
||||
<bpmn:incoming>flow_to_renewal_timer</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_renewal_task</bpmn:outgoing>
|
||||
<bpmn:timerEventDefinition>
|
||||
<bpmn:timeDuration>P1Y</bpmn:timeDuration>
|
||||
</bpmn:timerEventDefinition>
|
||||
</bpmn:intermediateCatchEvent>
|
||||
|
||||
<!-- Message: Dokument ersetzt -->
|
||||
<bpmn:intermediateCatchEvent id="supersede_event" name="Neue Version">
|
||||
<bpmn:incoming>flow_to_supersede_event</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_monitoring_end</bpmn:outgoing>
|
||||
<bpmn:messageEventDefinition messageRef="Message_Supersede" />
|
||||
</bpmn:intermediateCatchEvent>
|
||||
|
||||
<bpmn:serviceTask id="trigger_renewal" name="Erneuerung anfordern" camunda:delegateExpression="${triggerRenewalDelegate}">
|
||||
<bpmn:incoming>flow_to_renewal_task</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_back_to_gateway</bpmn:outgoing>
|
||||
</bpmn:serviceTask>
|
||||
|
||||
<bpmn:endEvent id="monitoring_end" />
|
||||
</bpmn:subProcess>
|
||||
|
||||
<!-- Service Task: Archivieren -->
|
||||
<bpmn:serviceTask id="archive_document" name="Dokument archivieren" camunda:delegateExpression="${archiveDocumentDelegate}">
|
||||
<bpmn:incoming>flow_to_archive</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_end</bpmn:outgoing>
|
||||
</bpmn:serviceTask>
|
||||
|
||||
<!-- End Event -->
|
||||
<bpmn:endEvent id="end" name="Workflow beendet">
|
||||
<bpmn:incoming>flow_to_end</bpmn:incoming>
|
||||
</bpmn:endEvent>
|
||||
|
||||
<!-- Sequence Flows -->
|
||||
<bpmn:sequenceFlow id="flow_to_edit" sourceRef="start" targetRef="edit_document" />
|
||||
<bpmn:sequenceFlow id="flow_to_review" sourceRef="edit_document" targetRef="dsb_review" />
|
||||
<bpmn:sequenceFlow id="flow_to_approval_gateway" sourceRef="dsb_review" targetRef="approval_gateway" />
|
||||
<bpmn:sequenceFlow id="flow_approved" sourceRef="approval_gateway" targetRef="publish_document">
|
||||
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${approved == true}</bpmn:conditionExpression>
|
||||
</bpmn:sequenceFlow>
|
||||
<bpmn:sequenceFlow id="flow_rejected" sourceRef="approval_gateway" targetRef="edit_document">
|
||||
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${approved == false}</bpmn:conditionExpression>
|
||||
</bpmn:sequenceFlow>
|
||||
<bpmn:sequenceFlow id="flow_rejected_to_edit" sourceRef="approval_gateway" targetRef="edit_document" />
|
||||
<bpmn:sequenceFlow id="flow_to_notify" sourceRef="publish_document" targetRef="notify_users" />
|
||||
<bpmn:sequenceFlow id="flow_to_collect_consent" sourceRef="notify_users" targetRef="collect_consent" />
|
||||
<bpmn:sequenceFlow id="flow_to_reminder" sourceRef="consent_deadline" targetRef="send_reminder" />
|
||||
<bpmn:sequenceFlow id="flow_to_check_deadline" sourceRef="collect_consent" targetRef="check_consent_status" />
|
||||
<bpmn:sequenceFlow id="flow_to_active" sourceRef="check_consent_status" targetRef="document_active" />
|
||||
<bpmn:sequenceFlow id="flow_to_monitor" sourceRef="document_active" targetRef="monitoring_subprocess" />
|
||||
<bpmn:sequenceFlow id="flow_to_archive" sourceRef="monitoring_subprocess" targetRef="archive_document" />
|
||||
<bpmn:sequenceFlow id="flow_to_end" sourceRef="archive_document" targetRef="end" />
|
||||
|
||||
</bpmn:process>
|
||||
|
||||
<!-- Messages -->
|
||||
<bpmn:message id="Message_Supersede" name="DocumentSuperseded" />
|
||||
|
||||
<!-- BPMN Diagram -->
|
||||
<bpmndi:BPMNDiagram id="BPMNDiagram_1">
|
||||
<bpmndi:BPMNPlane id="BPMNPlane_1" bpmnElement="ConsentDocumentProcess">
|
||||
<bpmndi:BPMNShape id="start_di" bpmnElement="start">
|
||||
<dc:Bounds x="152" y="102" width="36" height="36" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="edit_document_di" bpmnElement="edit_document">
|
||||
<dc:Bounds x="240" y="80" width="100" height="80" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="dsb_review_di" bpmnElement="dsb_review">
|
||||
<dc:Bounds x="390" y="80" width="100" height="80" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="approval_gateway_di" bpmnElement="approval_gateway" isMarkerVisible="true">
|
||||
<dc:Bounds x="545" y="95" width="50" height="50" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="publish_document_di" bpmnElement="publish_document">
|
||||
<dc:Bounds x="650" y="80" width="100" height="80" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="notify_users_di" bpmnElement="notify_users">
|
||||
<dc:Bounds x="800" y="80" width="100" height="80" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="collect_consent_di" bpmnElement="collect_consent">
|
||||
<dc:Bounds x="950" y="80" width="100" height="80" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="end_di" bpmnElement="end">
|
||||
<dc:Bounds x="1502" y="102" width="36" height="36" />
|
||||
</bpmndi:BPMNShape>
|
||||
</bpmndi:BPMNPlane>
|
||||
</bpmndi:BPMNDiagram>
|
||||
|
||||
</bpmn:definitions>
|
||||
@@ -0,0 +1,222 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<bpmn:definitions xmlns:bpmn="http://www.omg.org/spec/BPMN/20100524/MODEL"
|
||||
xmlns:bpmndi="http://www.omg.org/spec/BPMN/20100524/DI"
|
||||
xmlns:dc="http://www.omg.org/spec/DD/20100524/DC"
|
||||
xmlns:di="http://www.omg.org/spec/DD/20100524/DI"
|
||||
xmlns:camunda="http://camunda.org/schema/1.0/bpmn"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
id="Definitions_DSR"
|
||||
targetNamespace="http://breakpilot.de/bpmn/dsr">
|
||||
|
||||
<bpmn:process id="DSRRequestProcess" name="Data Subject Request (GDPR)" isExecutable="true">
|
||||
|
||||
<!-- Start Event -->
|
||||
<bpmn:startEvent id="start" name="DSR eingereicht">
|
||||
<bpmn:outgoing>flow_to_validate</bpmn:outgoing>
|
||||
</bpmn:startEvent>
|
||||
|
||||
<!-- Service Task: Anfrage validieren -->
|
||||
<bpmn:serviceTask id="validate_request" name="Anfrage validieren" camunda:delegateExpression="${validateDSRDelegate}">
|
||||
<bpmn:incoming>flow_to_validate</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_validation_gateway</bpmn:outgoing>
|
||||
</bpmn:serviceTask>
|
||||
|
||||
<!-- Gateway: Anfrage gueltig? -->
|
||||
<bpmn:exclusiveGateway id="validation_gateway" name="Anfrage gueltig?">
|
||||
<bpmn:incoming>flow_to_validation_gateway</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_valid</bpmn:outgoing>
|
||||
<bpmn:outgoing>flow_invalid</bpmn:outgoing>
|
||||
</bpmn:exclusiveGateway>
|
||||
|
||||
<!-- Service Task: Anfrage ablehnen -->
|
||||
<bpmn:serviceTask id="reject_request" name="Anfrage ablehnen" camunda:delegateExpression="${rejectDSRDelegate}">
|
||||
<bpmn:incoming>flow_invalid</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_reject_end</bpmn:outgoing>
|
||||
</bpmn:serviceTask>
|
||||
|
||||
<!-- End Event: Abgelehnt -->
|
||||
<bpmn:endEvent id="end_rejected" name="DSR abgelehnt">
|
||||
<bpmn:incoming>flow_to_reject_end</bpmn:incoming>
|
||||
</bpmn:endEvent>
|
||||
|
||||
<!-- Gateway: Request-Typ -->
|
||||
<bpmn:exclusiveGateway id="type_gateway" name="Request-Typ?">
|
||||
<bpmn:incoming>flow_valid</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_access</bpmn:outgoing>
|
||||
<bpmn:outgoing>flow_deletion</bpmn:outgoing>
|
||||
<bpmn:outgoing>flow_portability</bpmn:outgoing>
|
||||
<bpmn:outgoing>flow_rectification</bpmn:outgoing>
|
||||
</bpmn:exclusiveGateway>
|
||||
|
||||
<!-- Sub-Process: Daten-Zugang (Art. 15) -->
|
||||
<bpmn:subProcess id="access_subprocess" name="Daten-Zugang (Art. 15)">
|
||||
<bpmn:incoming>flow_access</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_access_done</bpmn:outgoing>
|
||||
|
||||
<bpmn:startEvent id="access_start" />
|
||||
|
||||
<bpmn:serviceTask id="collect_data" name="Daten sammeln" camunda:delegateExpression="${collectUserDataDelegate}" />
|
||||
|
||||
<bpmn:serviceTask id="anonymize_data" name="Daten anonymisieren" camunda:delegateExpression="${anonymizeDataDelegate}" />
|
||||
|
||||
<bpmn:userTask id="review_data" name="Daten pruefen" camunda:candidateGroups="data_protection_officer">
|
||||
<bpmn:extensionElements>
|
||||
<camunda:formData>
|
||||
<camunda:formField id="dataComplete" label="Daten vollstaendig" type="boolean" />
|
||||
<camunda:formField id="sensitivePII" label="Sensible PII entfernt" type="boolean" />
|
||||
</camunda:formData>
|
||||
</bpmn:extensionElements>
|
||||
</bpmn:userTask>
|
||||
|
||||
<bpmn:serviceTask id="prepare_export" name="Export vorbereiten" camunda:delegateExpression="${prepareExportDelegate}" />
|
||||
|
||||
<bpmn:endEvent id="access_end" />
|
||||
</bpmn:subProcess>
|
||||
|
||||
<!-- Sub-Process: Daten-Loeschung (Art. 17) -->
|
||||
<bpmn:subProcess id="deletion_subprocess" name="Daten-Loeschung (Art. 17)">
|
||||
<bpmn:incoming>flow_deletion</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_deletion_done</bpmn:outgoing>
|
||||
|
||||
<bpmn:startEvent id="deletion_start" />
|
||||
|
||||
<bpmn:serviceTask id="identify_data" name="Daten identifizieren" camunda:delegateExpression="${identifyUserDataDelegate}" />
|
||||
|
||||
<bpmn:userTask id="approve_deletion" name="Loeschung genehmigen" camunda:candidateGroups="data_protection_officer">
|
||||
<bpmn:extensionElements>
|
||||
<camunda:formData>
|
||||
<camunda:formField id="legalRetention" label="Aufbewahrungspflicht?" type="boolean" />
|
||||
<camunda:formField id="deletionApproved" label="Loeschung genehmigt" type="boolean" />
|
||||
</camunda:formData>
|
||||
</bpmn:extensionElements>
|
||||
</bpmn:userTask>
|
||||
|
||||
<bpmn:serviceTask id="execute_deletion" name="Daten loeschen" camunda:delegateExpression="${executeDataDeletionDelegate}" />
|
||||
|
||||
<bpmn:serviceTask id="verify_deletion" name="Loeschung verifizieren" camunda:delegateExpression="${verifyDeletionDelegate}" />
|
||||
|
||||
<bpmn:endEvent id="deletion_end" />
|
||||
</bpmn:subProcess>
|
||||
|
||||
<!-- Sub-Process: Daten-Portabilitaet (Art. 20) -->
|
||||
<bpmn:subProcess id="portability_subprocess" name="Daten-Portabilitaet (Art. 20)">
|
||||
<bpmn:incoming>flow_portability</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_portability_done</bpmn:outgoing>
|
||||
|
||||
<bpmn:startEvent id="portability_start" />
|
||||
|
||||
<bpmn:serviceTask id="collect_portable_data" name="Portable Daten sammeln" camunda:delegateExpression="${collectPortableDataDelegate}" />
|
||||
|
||||
<bpmn:serviceTask id="format_data" name="Daten formatieren (JSON)" camunda:delegateExpression="${formatPortableDataDelegate}" />
|
||||
|
||||
<bpmn:endEvent id="portability_end" />
|
||||
</bpmn:subProcess>
|
||||
|
||||
<!-- Sub-Process: Berichtigung (Art. 16) -->
|
||||
<bpmn:subProcess id="rectification_subprocess" name="Berichtigung (Art. 16)">
|
||||
<bpmn:incoming>flow_rectification</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_rectification_done</bpmn:outgoing>
|
||||
|
||||
<bpmn:startEvent id="rectification_start" />
|
||||
|
||||
<bpmn:userTask id="review_rectification" name="Berichtigung pruefen" camunda:candidateGroups="data_protection_officer" />
|
||||
|
||||
<bpmn:serviceTask id="apply_rectification" name="Daten berichtigen" camunda:delegateExpression="${applyRectificationDelegate}" />
|
||||
|
||||
<bpmn:endEvent id="rectification_end" />
|
||||
</bpmn:subProcess>
|
||||
|
||||
<!-- Gateway: Zusammenfuehrung -->
|
||||
<bpmn:exclusiveGateway id="merge_gateway">
|
||||
<bpmn:incoming>flow_access_done</bpmn:incoming>
|
||||
<bpmn:incoming>flow_deletion_done</bpmn:incoming>
|
||||
<bpmn:incoming>flow_portability_done</bpmn:incoming>
|
||||
<bpmn:incoming>flow_rectification_done</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_notify</bpmn:outgoing>
|
||||
</bpmn:exclusiveGateway>
|
||||
|
||||
<!-- Service Task: Betroffenen benachrichtigen -->
|
||||
<bpmn:serviceTask id="notify_subject" name="Betroffenen benachrichtigen" camunda:delegateExpression="${notifyDataSubjectDelegate}">
|
||||
<bpmn:incoming>flow_to_notify</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_audit</bpmn:outgoing>
|
||||
</bpmn:serviceTask>
|
||||
|
||||
<!-- Service Task: Audit Log -->
|
||||
<bpmn:serviceTask id="create_audit" name="Audit Log erstellen" camunda:delegateExpression="${createAuditLogDelegate}">
|
||||
<bpmn:incoming>flow_to_audit</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_end</bpmn:outgoing>
|
||||
</bpmn:serviceTask>
|
||||
|
||||
<!-- End Event -->
|
||||
<bpmn:endEvent id="end" name="DSR abgeschlossen">
|
||||
<bpmn:incoming>flow_to_end</bpmn:incoming>
|
||||
</bpmn:endEvent>
|
||||
|
||||
<!-- Boundary Timer: 30-Tage GDPR Frist -->
|
||||
<bpmn:boundaryEvent id="gdpr_deadline" attachedToRef="access_subprocess" cancelActivity="false">
|
||||
<bpmn:timerEventDefinition>
|
||||
<bpmn:timeDuration>P25D</bpmn:timeDuration>
|
||||
</bpmn:timerEventDefinition>
|
||||
<bpmn:outgoing>flow_deadline_escalation</bpmn:outgoing>
|
||||
</bpmn:boundaryEvent>
|
||||
|
||||
<!-- Service Task: Eskalation an DSB -->
|
||||
<bpmn:serviceTask id="escalate_dsb" name="Eskalation an DSB" camunda:delegateExpression="${escalateToDSBDelegate}">
|
||||
<bpmn:incoming>flow_deadline_escalation</bpmn:incoming>
|
||||
</bpmn:serviceTask>
|
||||
|
||||
<!-- Sequence Flows -->
|
||||
<bpmn:sequenceFlow id="flow_to_validate" sourceRef="start" targetRef="validate_request" />
|
||||
<bpmn:sequenceFlow id="flow_to_validation_gateway" sourceRef="validate_request" targetRef="validation_gateway" />
|
||||
<bpmn:sequenceFlow id="flow_valid" sourceRef="validation_gateway" targetRef="type_gateway">
|
||||
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${valid == true}</bpmn:conditionExpression>
|
||||
</bpmn:sequenceFlow>
|
||||
<bpmn:sequenceFlow id="flow_invalid" sourceRef="validation_gateway" targetRef="reject_request">
|
||||
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${valid == false}</bpmn:conditionExpression>
|
||||
</bpmn:sequenceFlow>
|
||||
<bpmn:sequenceFlow id="flow_to_reject_end" sourceRef="reject_request" targetRef="end_rejected" />
|
||||
<bpmn:sequenceFlow id="flow_access" sourceRef="type_gateway" targetRef="access_subprocess">
|
||||
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${requestType == 'access'}</bpmn:conditionExpression>
|
||||
</bpmn:sequenceFlow>
|
||||
<bpmn:sequenceFlow id="flow_deletion" sourceRef="type_gateway" targetRef="deletion_subprocess">
|
||||
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${requestType == 'deletion'}</bpmn:conditionExpression>
|
||||
</bpmn:sequenceFlow>
|
||||
<bpmn:sequenceFlow id="flow_portability" sourceRef="type_gateway" targetRef="portability_subprocess">
|
||||
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${requestType == 'portability'}</bpmn:conditionExpression>
|
||||
</bpmn:sequenceFlow>
|
||||
<bpmn:sequenceFlow id="flow_rectification" sourceRef="type_gateway" targetRef="rectification_subprocess">
|
||||
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${requestType == 'rectification'}</bpmn:conditionExpression>
|
||||
</bpmn:sequenceFlow>
|
||||
<bpmn:sequenceFlow id="flow_access_done" sourceRef="access_subprocess" targetRef="merge_gateway" />
|
||||
<bpmn:sequenceFlow id="flow_deletion_done" sourceRef="deletion_subprocess" targetRef="merge_gateway" />
|
||||
<bpmn:sequenceFlow id="flow_portability_done" sourceRef="portability_subprocess" targetRef="merge_gateway" />
|
||||
<bpmn:sequenceFlow id="flow_rectification_done" sourceRef="rectification_subprocess" targetRef="merge_gateway" />
|
||||
<bpmn:sequenceFlow id="flow_to_notify" sourceRef="merge_gateway" targetRef="notify_subject" />
|
||||
<bpmn:sequenceFlow id="flow_to_audit" sourceRef="notify_subject" targetRef="create_audit" />
|
||||
<bpmn:sequenceFlow id="flow_to_end" sourceRef="create_audit" targetRef="end" />
|
||||
<bpmn:sequenceFlow id="flow_deadline_escalation" sourceRef="gdpr_deadline" targetRef="escalate_dsb" />
|
||||
|
||||
</bpmn:process>
|
||||
|
||||
<!-- BPMN Diagram -->
|
||||
<bpmndi:BPMNDiagram id="BPMNDiagram_1">
|
||||
<bpmndi:BPMNPlane id="BPMNPlane_1" bpmnElement="DSRRequestProcess">
|
||||
<bpmndi:BPMNShape id="start_di" bpmnElement="start">
|
||||
<dc:Bounds x="152" y="252" width="36" height="36" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="validate_request_di" bpmnElement="validate_request">
|
||||
<dc:Bounds x="240" y="230" width="100" height="80" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="validation_gateway_di" bpmnElement="validation_gateway" isMarkerVisible="true">
|
||||
<dc:Bounds x="395" y="245" width="50" height="50" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="type_gateway_di" bpmnElement="type_gateway" isMarkerVisible="true">
|
||||
<dc:Bounds x="545" y="245" width="50" height="50" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="end_di" bpmnElement="end">
|
||||
<dc:Bounds x="1502" y="252" width="36" height="36" />
|
||||
</bpmndi:BPMNShape>
|
||||
</bpmndi:BPMNPlane>
|
||||
</bpmndi:BPMNDiagram>
|
||||
|
||||
</bpmn:definitions>
|
||||
@@ -0,0 +1,215 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<bpmn:definitions xmlns:bpmn="http://www.omg.org/spec/BPMN/20100524/MODEL"
|
||||
xmlns:bpmndi="http://www.omg.org/spec/BPMN/20100524/DI"
|
||||
xmlns:dc="http://www.omg.org/spec/DD/20100524/DC"
|
||||
xmlns:di="http://www.omg.org/spec/DD/20100524/DI"
|
||||
xmlns:camunda="http://camunda.org/schema/1.0/bpmn"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
id="Definitions_Klausur"
|
||||
targetNamespace="http://breakpilot.de/bpmn/klausur">
|
||||
|
||||
<bpmn:process id="KlausurKorrekturProcess" name="Klausurkorrektur Workflow" isExecutable="true">
|
||||
|
||||
<!-- Start Event -->
|
||||
<bpmn:startEvent id="start" name="Klausuren hochgeladen">
|
||||
<bpmn:outgoing>flow_to_ocr</bpmn:outgoing>
|
||||
</bpmn:startEvent>
|
||||
|
||||
<!-- Service Task: OCR Verarbeitung -->
|
||||
<bpmn:serviceTask id="ocr_processing" name="OCR Verarbeitung" camunda:delegateExpression="${ocrProcessingDelegate}">
|
||||
<bpmn:incoming>flow_to_ocr</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_quality_check</bpmn:outgoing>
|
||||
</bpmn:serviceTask>
|
||||
|
||||
<!-- Service Task: Qualitaets-Check -->
|
||||
<bpmn:serviceTask id="quality_check" name="OCR Qualitaets-Check" camunda:delegateExpression="${ocrQualityCheckDelegate}">
|
||||
<bpmn:incoming>flow_to_quality_check</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_quality_gateway</bpmn:outgoing>
|
||||
</bpmn:serviceTask>
|
||||
|
||||
<!-- Gateway: OCR Qualitaet ausreichend? -->
|
||||
<bpmn:exclusiveGateway id="quality_gateway" name="OCR Qualitaet OK?">
|
||||
<bpmn:incoming>flow_to_quality_gateway</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_quality_ok</bpmn:outgoing>
|
||||
<bpmn:outgoing>flow_quality_bad</bpmn:outgoing>
|
||||
</bpmn:exclusiveGateway>
|
||||
|
||||
<!-- User Task: Manuelle Nachbearbeitung -->
|
||||
<bpmn:userTask id="manual_ocr_fix" name="Manuelle OCR-Korrektur" camunda:assignee="${teacherId}">
|
||||
<bpmn:extensionElements>
|
||||
<camunda:formData>
|
||||
<camunda:formField id="correctedText" label="Korrigierter Text" type="string" />
|
||||
</camunda:formData>
|
||||
</bpmn:extensionElements>
|
||||
<bpmn:incoming>flow_quality_bad</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_from_manual_fix</bpmn:outgoing>
|
||||
</bpmn:userTask>
|
||||
|
||||
<!-- User Task: Erwartungshorizont definieren -->
|
||||
<bpmn:userTask id="define_expectations" name="Erwartungshorizont" camunda:assignee="${teacherId}">
|
||||
<bpmn:extensionElements>
|
||||
<camunda:formData>
|
||||
<camunda:formField id="criteria" label="Bewertungskriterien" type="string" />
|
||||
<camunda:formField id="maxPoints" label="Maximale Punktzahl" type="long" />
|
||||
<camunda:formField id="useTemplate" label="Template verwenden" type="boolean" />
|
||||
</camunda:formData>
|
||||
</bpmn:extensionElements>
|
||||
<bpmn:incoming>flow_quality_ok</bpmn:incoming>
|
||||
<bpmn:incoming>flow_from_manual_fix</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_ai_grading</bpmn:outgoing>
|
||||
</bpmn:userTask>
|
||||
|
||||
<!-- Service Task: AI Bewertung -->
|
||||
<bpmn:serviceTask id="ai_grading" name="AI-Bewertung (Claude)" camunda:delegateExpression="${aiGradingDelegate}">
|
||||
<bpmn:incoming>flow_to_ai_grading</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_teacher_review</bpmn:outgoing>
|
||||
</bpmn:serviceTask>
|
||||
|
||||
<!-- User Task: Lehrer-Review -->
|
||||
<bpmn:userTask id="teacher_review" name="Lehrer-Review" camunda:assignee="${teacherId}">
|
||||
<bpmn:extensionElements>
|
||||
<camunda:formData>
|
||||
<camunda:formField id="adjustedGrade" label="Angepasste Bewertung" type="long" />
|
||||
<camunda:formField id="comments" label="Kommentare" type="string" />
|
||||
<camunda:formField id="approved" label="Bewertung final" type="boolean" />
|
||||
</camunda:formData>
|
||||
</bpmn:extensionElements>
|
||||
<bpmn:incoming>flow_to_teacher_review</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_review_gateway</bpmn:outgoing>
|
||||
</bpmn:userTask>
|
||||
|
||||
<!-- Gateway: Review abgeschlossen? -->
|
||||
<bpmn:exclusiveGateway id="review_gateway" name="Review OK?">
|
||||
<bpmn:incoming>flow_to_review_gateway</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_review_ok</bpmn:outgoing>
|
||||
<bpmn:outgoing>flow_review_adjust</bpmn:outgoing>
|
||||
</bpmn:exclusiveGateway>
|
||||
|
||||
<!-- Service Task: Noten berechnen -->
|
||||
<bpmn:serviceTask id="calculate_grades" name="Noten berechnen" camunda:delegateExpression="${calculateGradesDelegate}">
|
||||
<bpmn:incoming>flow_review_ok</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_gradebook</bpmn:outgoing>
|
||||
</bpmn:serviceTask>
|
||||
|
||||
<!-- Service Task: In Notenbuch uebertragen -->
|
||||
<bpmn:serviceTask id="update_gradebook" name="Notenbuch aktualisieren" camunda:delegateExpression="${updateGradebookDelegate}">
|
||||
<bpmn:incoming>flow_to_gradebook</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_export</bpmn:outgoing>
|
||||
</bpmn:serviceTask>
|
||||
|
||||
<!-- Service Task: Export generieren -->
|
||||
<bpmn:serviceTask id="generate_export" name="Export generieren" camunda:delegateExpression="${generateExportDelegate}">
|
||||
<bpmn:incoming>flow_to_export</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_notify_gateway</bpmn:outgoing>
|
||||
</bpmn:serviceTask>
|
||||
|
||||
<!-- Gateway: Eltern benachrichtigen? -->
|
||||
<bpmn:exclusiveGateway id="notify_gateway" name="Eltern benachrichtigen?">
|
||||
<bpmn:incoming>flow_to_notify_gateway</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_notify_yes</bpmn:outgoing>
|
||||
<bpmn:outgoing>flow_notify_no</bpmn:outgoing>
|
||||
</bpmn:exclusiveGateway>
|
||||
|
||||
<!-- Service Task: Eltern benachrichtigen -->
|
||||
<bpmn:serviceTask id="notify_parents" name="Eltern benachrichtigen" camunda:delegateExpression="${notifyParentsDelegate}">
|
||||
<bpmn:incoming>flow_notify_yes</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_from_notify</bpmn:outgoing>
|
||||
</bpmn:serviceTask>
|
||||
|
||||
<!-- Service Task: Archivieren -->
|
||||
<bpmn:serviceTask id="archive_exam" name="Klausur archivieren" camunda:delegateExpression="${archiveExamDelegate}">
|
||||
<bpmn:incoming>flow_notify_no</bpmn:incoming>
|
||||
<bpmn:incoming>flow_from_notify</bpmn:incoming>
|
||||
<bpmn:outgoing>flow_to_end</bpmn:outgoing>
|
||||
</bpmn:serviceTask>
|
||||
|
||||
<!-- End Event -->
|
||||
<bpmn:endEvent id="end" name="Korrektur abgeschlossen">
|
||||
<bpmn:incoming>flow_to_end</bpmn:incoming>
|
||||
</bpmn:endEvent>
|
||||
|
||||
<!-- Boundary Timer: Korrektur-Deadline -->
|
||||
<bpmn:boundaryEvent id="correction_deadline" attachedToRef="teacher_review" cancelActivity="false">
|
||||
<bpmn:timerEventDefinition>
|
||||
<bpmn:timeDuration>P${correctionDeadlineDays}D</bpmn:timeDuration>
|
||||
</bpmn:timerEventDefinition>
|
||||
<bpmn:outgoing>flow_deadline_warning</bpmn:outgoing>
|
||||
</bpmn:boundaryEvent>
|
||||
|
||||
<!-- Service Task: Deadline-Warnung -->
|
||||
<bpmn:serviceTask id="deadline_warning" name="Deadline-Warnung" camunda:delegateExpression="${deadlineWarningDelegate}">
|
||||
<bpmn:incoming>flow_deadline_warning</bpmn:incoming>
|
||||
</bpmn:serviceTask>
|
||||
|
||||
<!-- Sequence Flows -->
|
||||
<bpmn:sequenceFlow id="flow_to_ocr" sourceRef="start" targetRef="ocr_processing" />
|
||||
<bpmn:sequenceFlow id="flow_to_quality_check" sourceRef="ocr_processing" targetRef="quality_check" />
|
||||
<bpmn:sequenceFlow id="flow_to_quality_gateway" sourceRef="quality_check" targetRef="quality_gateway" />
|
||||
<bpmn:sequenceFlow id="flow_quality_ok" sourceRef="quality_gateway" targetRef="define_expectations">
|
||||
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${ocrConfidence >= 0.85}</bpmn:conditionExpression>
|
||||
</bpmn:sequenceFlow>
|
||||
<bpmn:sequenceFlow id="flow_quality_bad" sourceRef="quality_gateway" targetRef="manual_ocr_fix">
|
||||
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${ocrConfidence < 0.85}</bpmn:conditionExpression>
|
||||
</bpmn:sequenceFlow>
|
||||
<bpmn:sequenceFlow id="flow_from_manual_fix" sourceRef="manual_ocr_fix" targetRef="define_expectations" />
|
||||
<bpmn:sequenceFlow id="flow_to_ai_grading" sourceRef="define_expectations" targetRef="ai_grading" />
|
||||
<bpmn:sequenceFlow id="flow_to_teacher_review" sourceRef="ai_grading" targetRef="teacher_review" />
|
||||
<bpmn:sequenceFlow id="flow_to_review_gateway" sourceRef="teacher_review" targetRef="review_gateway" />
|
||||
<bpmn:sequenceFlow id="flow_review_ok" sourceRef="review_gateway" targetRef="calculate_grades">
|
||||
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${approved == true}</bpmn:conditionExpression>
|
||||
</bpmn:sequenceFlow>
|
||||
<bpmn:sequenceFlow id="flow_review_adjust" sourceRef="review_gateway" targetRef="ai_grading">
|
||||
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${approved == false}</bpmn:conditionExpression>
|
||||
</bpmn:sequenceFlow>
|
||||
<bpmn:sequenceFlow id="flow_to_gradebook" sourceRef="calculate_grades" targetRef="update_gradebook" />
|
||||
<bpmn:sequenceFlow id="flow_to_export" sourceRef="update_gradebook" targetRef="generate_export" />
|
||||
<bpmn:sequenceFlow id="flow_to_notify_gateway" sourceRef="generate_export" targetRef="notify_gateway" />
|
||||
<bpmn:sequenceFlow id="flow_notify_yes" sourceRef="notify_gateway" targetRef="notify_parents">
|
||||
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${notifyParents == true}</bpmn:conditionExpression>
|
||||
</bpmn:sequenceFlow>
|
||||
<bpmn:sequenceFlow id="flow_notify_no" sourceRef="notify_gateway" targetRef="archive_exam">
|
||||
<bpmn:conditionExpression xsi:type="bpmn:tFormalExpression">${notifyParents == false}</bpmn:conditionExpression>
|
||||
</bpmn:sequenceFlow>
|
||||
<bpmn:sequenceFlow id="flow_from_notify" sourceRef="notify_parents" targetRef="archive_exam" />
|
||||
<bpmn:sequenceFlow id="flow_to_end" sourceRef="archive_exam" targetRef="end" />
|
||||
<bpmn:sequenceFlow id="flow_deadline_warning" sourceRef="correction_deadline" targetRef="deadline_warning" />
|
||||
|
||||
</bpmn:process>
|
||||
|
||||
<!-- BPMN Diagram -->
|
||||
<bpmndi:BPMNDiagram id="BPMNDiagram_1">
|
||||
<bpmndi:BPMNPlane id="BPMNPlane_1" bpmnElement="KlausurKorrekturProcess">
|
||||
<bpmndi:BPMNShape id="start_di" bpmnElement="start">
|
||||
<dc:Bounds x="152" y="102" width="36" height="36" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="ocr_processing_di" bpmnElement="ocr_processing">
|
||||
<dc:Bounds x="240" y="80" width="100" height="80" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="quality_check_di" bpmnElement="quality_check">
|
||||
<dc:Bounds x="390" y="80" width="100" height="80" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="quality_gateway_di" bpmnElement="quality_gateway" isMarkerVisible="true">
|
||||
<dc:Bounds x="545" y="95" width="50" height="50" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="manual_ocr_fix_di" bpmnElement="manual_ocr_fix">
|
||||
<dc:Bounds x="520" y="200" width="100" height="80" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="define_expectations_di" bpmnElement="define_expectations">
|
||||
<dc:Bounds x="650" y="80" width="100" height="80" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="ai_grading_di" bpmnElement="ai_grading">
|
||||
<dc:Bounds x="800" y="80" width="100" height="80" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="teacher_review_di" bpmnElement="teacher_review">
|
||||
<dc:Bounds x="950" y="80" width="100" height="80" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="calculate_grades_di" bpmnElement="calculate_grades">
|
||||
<dc:Bounds x="1100" y="80" width="100" height="80" />
|
||||
</bpmndi:BPMNShape>
|
||||
<bpmndi:BPMNShape id="end_di" bpmnElement="end">
|
||||
<dc:Bounds x="1702" y="102" width="36" height="36" />
|
||||
</bpmndi:BPMNShape>
|
||||
</bpmndi:BPMNPlane>
|
||||
</bpmndi:BPMNDiagram>
|
||||
|
||||
</bpmn:definitions>
|
||||
@@ -0,0 +1,48 @@
|
||||
# Binaries
|
||||
*.exe
|
||||
*.exe~
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
server
|
||||
|
||||
# Test binary
|
||||
*.test
|
||||
|
||||
# Output of go coverage tool
|
||||
*.out
|
||||
|
||||
# Go workspace file
|
||||
go.work
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# Local config
|
||||
.env
|
||||
.env.local
|
||||
*.local
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
logs/
|
||||
|
||||
# Temp files
|
||||
*.tmp
|
||||
*.temp
|
||||
.DS_Store
|
||||
|
||||
# Git
|
||||
.git/
|
||||
.gitignore
|
||||
|
||||
# Docker
|
||||
Dockerfile
|
||||
docker-compose*.yml
|
||||
|
||||
# Vendor (if using)
|
||||
vendor/
|
||||
@@ -0,0 +1,21 @@
|
||||
# Server Configuration
|
||||
PORT=8081
|
||||
ENVIRONMENT=development
|
||||
|
||||
# Database Configuration
|
||||
# PostgreSQL connection string
|
||||
DATABASE_URL=postgres://user:password@localhost:5432/consent_db?sslmode=disable
|
||||
|
||||
# JWT Configuration (should match BreakPilot's JWT secret for token validation)
|
||||
JWT_SECRET=your-jwt-secret-here
|
||||
JWT_REFRESH_SECRET=your-refresh-secret-here
|
||||
|
||||
# CORS Configuration
|
||||
ALLOWED_ORIGINS=http://localhost:3000,http://localhost:8000,https://breakpilot.app
|
||||
|
||||
# Rate Limiting
|
||||
RATE_LIMIT_REQUESTS=100
|
||||
RATE_LIMIT_WINDOW=60
|
||||
|
||||
# BreakPilot Integration
|
||||
BREAKPILOT_API_URL=http://localhost:8000
|
||||
@@ -0,0 +1,42 @@
|
||||
# Build stage
|
||||
FROM golang:1.23-alpine AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install build dependencies
|
||||
RUN apk add --no-cache git ca-certificates
|
||||
|
||||
# Copy go mod files
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
# Copy source code
|
||||
COPY . .
|
||||
|
||||
# Build the binary
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo -o consent-service ./cmd/server
|
||||
|
||||
# Runtime stage
|
||||
FROM alpine:3.19
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install runtime dependencies
|
||||
RUN apk --no-cache add ca-certificates tzdata
|
||||
|
||||
# Copy binary from builder
|
||||
COPY --from=builder /app/consent-service .
|
||||
|
||||
# Create non-root user
|
||||
RUN adduser -D -g '' appuser
|
||||
USER appuser
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8081
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
||||
CMD wget --no-verbose --tries=1 --spider http://localhost:8081/health || exit 1
|
||||
|
||||
# Run the binary
|
||||
CMD ["./consent-service"]
|
||||
@@ -0,0 +1,471 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/config"
|
||||
"github.com/breakpilot/consent-service/internal/database"
|
||||
"github.com/breakpilot/consent-service/internal/handlers"
|
||||
"github.com/breakpilot/consent-service/internal/middleware"
|
||||
"github.com/breakpilot/consent-service/internal/services"
|
||||
"github.com/breakpilot/consent-service/internal/services/jitsi"
|
||||
"github.com/breakpilot/consent-service/internal/services/matrix"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Load configuration
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Initialize database
|
||||
db, err := database.Connect(cfg.DatabaseURL)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to connect to database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Run migrations
|
||||
if err := database.Migrate(db); err != nil {
|
||||
log.Fatalf("Failed to run migrations: %v", err)
|
||||
}
|
||||
|
||||
// Setup Gin router
|
||||
if cfg.Environment == "production" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
|
||||
router := gin.Default()
|
||||
|
||||
// Global middleware
|
||||
router.Use(middleware.CORS())
|
||||
router.Use(middleware.RequestLogger())
|
||||
router.Use(middleware.RateLimiter())
|
||||
|
||||
// Health check
|
||||
router.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(200, gin.H{
|
||||
"status": "healthy",
|
||||
"service": "consent-service",
|
||||
"version": "1.0.0",
|
||||
})
|
||||
})
|
||||
|
||||
// Initialize services
|
||||
authService := services.NewAuthService(db.Pool, cfg.JWTSecret, cfg.JWTRefreshSecret)
|
||||
oauthService := services.NewOAuthService(db.Pool, cfg.JWTSecret)
|
||||
totpService := services.NewTOTPService(db.Pool, "BreakPilot")
|
||||
emailService := services.NewEmailService(services.EmailConfig{
|
||||
Host: cfg.SMTPHost,
|
||||
Port: cfg.SMTPPort,
|
||||
Username: cfg.SMTPUsername,
|
||||
Password: cfg.SMTPPassword,
|
||||
FromName: cfg.SMTPFromName,
|
||||
FromAddr: cfg.SMTPFromAddr,
|
||||
BaseURL: cfg.FrontendURL,
|
||||
})
|
||||
notificationService := services.NewNotificationService(db.Pool, emailService)
|
||||
deadlineService := services.NewDeadlineService(db.Pool, notificationService)
|
||||
emailTemplateService := services.NewEmailTemplateService(db.Pool)
|
||||
dsrService := services.NewDSRService(db.Pool, notificationService, emailService)
|
||||
|
||||
// Initialize handlers
|
||||
h := handlers.New(db)
|
||||
authHandler := handlers.NewAuthHandler(authService, emailService)
|
||||
oauthHandler := handlers.NewOAuthHandler(oauthService, totpService, authService)
|
||||
notificationHandler := handlers.NewNotificationHandler(notificationService)
|
||||
deadlineHandler := handlers.NewDeadlineHandler(deadlineService)
|
||||
emailTemplateHandler := handlers.NewEmailTemplateHandler(emailTemplateService)
|
||||
dsrHandler := handlers.NewDSRHandler(dsrService)
|
||||
|
||||
// Initialize Matrix service (if enabled)
|
||||
var matrixService *matrix.MatrixService
|
||||
if cfg.MatrixEnabled && cfg.MatrixAccessToken != "" {
|
||||
matrixService = matrix.NewMatrixService(matrix.Config{
|
||||
HomeserverURL: cfg.MatrixHomeserverURL,
|
||||
AccessToken: cfg.MatrixAccessToken,
|
||||
ServerName: cfg.MatrixServerName,
|
||||
})
|
||||
log.Println("Matrix service initialized")
|
||||
} else {
|
||||
log.Println("Matrix service disabled or not configured")
|
||||
}
|
||||
|
||||
// Initialize Jitsi service (if enabled)
|
||||
var jitsiService *jitsi.JitsiService
|
||||
if cfg.JitsiEnabled {
|
||||
jitsiService = jitsi.NewJitsiService(jitsi.Config{
|
||||
BaseURL: cfg.JitsiBaseURL,
|
||||
AppID: cfg.JitsiAppID,
|
||||
AppSecret: cfg.JitsiAppSecret,
|
||||
})
|
||||
log.Println("Jitsi service initialized")
|
||||
} else {
|
||||
log.Println("Jitsi service disabled")
|
||||
}
|
||||
|
||||
// Initialize communication handlers
|
||||
communicationHandler := handlers.NewCommunicationHandlers(matrixService, jitsiService)
|
||||
|
||||
// Initialize default email templates (runs only once)
|
||||
if err := emailTemplateService.InitDefaultTemplates(context.Background()); err != nil {
|
||||
log.Printf("Warning: Failed to initialize default email templates: %v", err)
|
||||
}
|
||||
|
||||
// API v1 routes
|
||||
v1 := router.Group("/api/v1")
|
||||
{
|
||||
// =============================================
|
||||
// OAuth 2.0 Endpoints (RFC 6749)
|
||||
// =============================================
|
||||
oauth := v1.Group("/oauth")
|
||||
{
|
||||
// Authorization endpoint (requires user auth for consent)
|
||||
oauth.GET("/authorize", middleware.AuthMiddleware(cfg.JWTSecret), oauthHandler.Authorize)
|
||||
// Token endpoint (public)
|
||||
oauth.POST("/token", oauthHandler.Token)
|
||||
// Revocation endpoint (RFC 7009)
|
||||
oauth.POST("/revoke", oauthHandler.Revoke)
|
||||
// Introspection endpoint (RFC 7662)
|
||||
oauth.POST("/introspect", oauthHandler.Introspect)
|
||||
}
|
||||
|
||||
// =============================================
|
||||
// Authentication Routes (with 2FA support)
|
||||
// =============================================
|
||||
auth := v1.Group("/auth")
|
||||
{
|
||||
// Registration with mandatory 2FA setup
|
||||
auth.POST("/register", oauthHandler.RegisterWith2FA)
|
||||
// Login with 2FA support
|
||||
auth.POST("/login", oauthHandler.LoginWith2FA)
|
||||
// 2FA challenge verification (during login)
|
||||
auth.POST("/2fa/verify", oauthHandler.Verify2FAChallenge)
|
||||
// Legacy endpoints (kept for compatibility)
|
||||
auth.POST("/logout", authHandler.Logout)
|
||||
auth.POST("/refresh", authHandler.RefreshToken)
|
||||
auth.POST("/verify-email", authHandler.VerifyEmail)
|
||||
auth.POST("/resend-verification", authHandler.ResendVerification)
|
||||
auth.POST("/forgot-password", authHandler.ForgotPassword)
|
||||
auth.POST("/reset-password", authHandler.ResetPassword)
|
||||
}
|
||||
|
||||
// =============================================
|
||||
// 2FA Management Routes (require auth)
|
||||
// =============================================
|
||||
twoFA := v1.Group("/auth/2fa")
|
||||
twoFA.Use(middleware.AuthMiddleware(cfg.JWTSecret))
|
||||
{
|
||||
twoFA.GET("/status", oauthHandler.Get2FAStatus)
|
||||
twoFA.POST("/setup", oauthHandler.Setup2FA)
|
||||
twoFA.POST("/verify-setup", oauthHandler.Verify2FASetup)
|
||||
twoFA.POST("/disable", oauthHandler.Disable2FA)
|
||||
twoFA.POST("/recovery-codes", oauthHandler.RegenerateRecoveryCodes)
|
||||
}
|
||||
|
||||
// =============================================
|
||||
// Profile Routes (require auth)
|
||||
// =============================================
|
||||
profile := v1.Group("/profile")
|
||||
profile.Use(middleware.AuthMiddleware(cfg.JWTSecret))
|
||||
{
|
||||
profile.GET("", authHandler.GetProfile)
|
||||
profile.PUT("", authHandler.UpdateProfile)
|
||||
profile.PUT("/password", authHandler.ChangePassword)
|
||||
profile.GET("/sessions", authHandler.GetActiveSessions)
|
||||
profile.DELETE("/sessions/:id", authHandler.RevokeSession)
|
||||
}
|
||||
|
||||
// =============================================
|
||||
// Public consent routes (require user auth)
|
||||
// =============================================
|
||||
public := v1.Group("")
|
||||
public.Use(middleware.AuthMiddleware(cfg.JWTSecret))
|
||||
{
|
||||
// Documents
|
||||
public.GET("/documents", h.GetDocuments)
|
||||
public.GET("/documents/:type", h.GetDocumentByType)
|
||||
public.GET("/documents/:type/latest", h.GetLatestDocumentVersion)
|
||||
|
||||
// User Consent
|
||||
public.POST("/consent", h.CreateConsent)
|
||||
public.GET("/consent/my", h.GetMyConsents)
|
||||
public.GET("/consent/check/:documentType", h.CheckConsent)
|
||||
public.DELETE("/consent/:id", h.WithdrawConsent)
|
||||
|
||||
// Cookie Consent
|
||||
public.GET("/cookies/categories", h.GetCookieCategories)
|
||||
public.POST("/cookies/consent", h.SetCookieConsent)
|
||||
public.GET("/cookies/consent/my", h.GetMyCookieConsent)
|
||||
|
||||
// GDPR / Data Subject Rights
|
||||
public.GET("/privacy/my-data", h.GetMyData)
|
||||
public.POST("/privacy/export", h.RequestDataExport)
|
||||
public.POST("/privacy/delete", h.RequestDataDeletion)
|
||||
|
||||
// Data Subject Requests (User-facing)
|
||||
public.POST("/dsr", dsrHandler.CreateDSR)
|
||||
public.GET("/dsr", dsrHandler.GetMyDSRs)
|
||||
public.GET("/dsr/:id", dsrHandler.GetMyDSR)
|
||||
public.POST("/dsr/:id/cancel", dsrHandler.CancelMyDSR)
|
||||
|
||||
// Notifications
|
||||
public.GET("/notifications", notificationHandler.GetNotifications)
|
||||
public.GET("/notifications/unread-count", notificationHandler.GetUnreadCount)
|
||||
public.PUT("/notifications/:id/read", notificationHandler.MarkAsRead)
|
||||
public.PUT("/notifications/read-all", notificationHandler.MarkAllAsRead)
|
||||
public.DELETE("/notifications/:id", notificationHandler.DeleteNotification)
|
||||
public.GET("/notifications/preferences", notificationHandler.GetPreferences)
|
||||
public.PUT("/notifications/preferences", notificationHandler.UpdatePreferences)
|
||||
|
||||
// Consent Deadlines & Suspension Status
|
||||
public.GET("/consent/deadlines", deadlineHandler.GetPendingDeadlines)
|
||||
public.GET("/account/suspension-status", deadlineHandler.GetSuspensionStatus)
|
||||
}
|
||||
|
||||
// Admin routes (require admin auth)
|
||||
admin := v1.Group("/admin")
|
||||
admin.Use(middleware.AuthMiddleware(cfg.JWTSecret))
|
||||
admin.Use(middleware.AdminOnly())
|
||||
{
|
||||
// Document Management
|
||||
admin.GET("/documents", h.AdminGetDocuments)
|
||||
admin.POST("/documents", h.AdminCreateDocument)
|
||||
admin.PUT("/documents/:id", h.AdminUpdateDocument)
|
||||
admin.DELETE("/documents/:id", h.AdminDeleteDocument)
|
||||
admin.GET("/documents/:docId/versions", h.AdminGetVersions)
|
||||
|
||||
// Version Management
|
||||
admin.POST("/versions", h.AdminCreateVersion)
|
||||
admin.PUT("/versions/:id", h.AdminUpdateVersion)
|
||||
admin.DELETE("/versions/:id", h.AdminDeleteVersion)
|
||||
admin.POST("/versions/:id/archive", h.AdminArchiveVersion)
|
||||
admin.POST("/versions/:id/submit-review", h.AdminSubmitForReview)
|
||||
admin.POST("/versions/:id/approve", h.AdminApproveVersion)
|
||||
admin.POST("/versions/:id/reject", h.AdminRejectVersion)
|
||||
admin.GET("/versions/:id/compare", h.AdminCompareVersions)
|
||||
admin.GET("/versions/:id/approval-history", h.AdminGetApprovalHistory)
|
||||
|
||||
// Publishing (DSB role recommended but Admin can also do it in dev)
|
||||
admin.POST("/versions/:id/publish", h.AdminPublishVersion)
|
||||
|
||||
// Cookie Categories
|
||||
admin.GET("/cookies/categories", h.AdminGetCookieCategories)
|
||||
admin.POST("/cookies/categories", h.AdminCreateCookieCategory)
|
||||
admin.PUT("/cookies/categories/:id", h.AdminUpdateCookieCategory)
|
||||
admin.DELETE("/cookies/categories/:id", h.AdminDeleteCookieCategory)
|
||||
|
||||
// Statistics & Audit
|
||||
admin.GET("/stats/consents", h.GetConsentStats)
|
||||
admin.GET("/stats/cookies", h.GetCookieStats)
|
||||
admin.GET("/audit-log", h.GetAuditLog)
|
||||
|
||||
// Deadline Management (for testing/manual trigger)
|
||||
admin.POST("/deadlines/process", deadlineHandler.TriggerDeadlineProcessing)
|
||||
|
||||
// Scheduled Publishing
|
||||
admin.GET("/scheduled-versions", h.GetScheduledVersions)
|
||||
admin.POST("/scheduled-publishing/process", h.ProcessScheduledPublishing)
|
||||
|
||||
// OAuth Client Management
|
||||
admin.GET("/oauth/clients", oauthHandler.AdminGetClients)
|
||||
admin.POST("/oauth/clients", oauthHandler.AdminCreateClient)
|
||||
|
||||
// =============================================
|
||||
// E-Mail Template Management
|
||||
// =============================================
|
||||
admin.GET("/email-templates/types", emailTemplateHandler.GetAllTemplateTypes)
|
||||
admin.GET("/email-templates", emailTemplateHandler.GetAllTemplates)
|
||||
admin.GET("/email-templates/settings", emailTemplateHandler.GetSettings)
|
||||
admin.PUT("/email-templates/settings", emailTemplateHandler.UpdateSettings)
|
||||
admin.GET("/email-templates/stats", emailTemplateHandler.GetEmailStats)
|
||||
admin.GET("/email-templates/logs", emailTemplateHandler.GetSendLogs)
|
||||
admin.GET("/email-templates/default/:type", emailTemplateHandler.GetDefaultContent)
|
||||
admin.POST("/email-templates/initialize", emailTemplateHandler.InitializeTemplates)
|
||||
admin.GET("/email-templates/:id", emailTemplateHandler.GetTemplate)
|
||||
admin.POST("/email-templates", emailTemplateHandler.CreateTemplate)
|
||||
admin.GET("/email-templates/:id/versions", emailTemplateHandler.GetTemplateVersions)
|
||||
|
||||
// E-Mail Template Versions
|
||||
admin.GET("/email-template-versions/:id", emailTemplateHandler.GetVersion)
|
||||
admin.POST("/email-template-versions", emailTemplateHandler.CreateVersion)
|
||||
admin.PUT("/email-template-versions/:id", emailTemplateHandler.UpdateVersion)
|
||||
admin.POST("/email-template-versions/:id/submit", emailTemplateHandler.SubmitForReview)
|
||||
admin.POST("/email-template-versions/:id/approve", emailTemplateHandler.ApproveVersion)
|
||||
admin.POST("/email-template-versions/:id/reject", emailTemplateHandler.RejectVersion)
|
||||
admin.POST("/email-template-versions/:id/publish", emailTemplateHandler.PublishVersion)
|
||||
admin.GET("/email-template-versions/:id/approvals", emailTemplateHandler.GetApprovals)
|
||||
admin.POST("/email-template-versions/:id/preview", emailTemplateHandler.PreviewVersion)
|
||||
admin.POST("/email-template-versions/:id/send-test", emailTemplateHandler.SendTestEmail)
|
||||
|
||||
// =============================================
|
||||
// Data Subject Requests (DSR) Management
|
||||
// =============================================
|
||||
admin.GET("/dsr", dsrHandler.AdminListDSR)
|
||||
admin.GET("/dsr/stats", dsrHandler.AdminGetDSRStats)
|
||||
admin.POST("/dsr", dsrHandler.AdminCreateDSR)
|
||||
admin.GET("/dsr/:id", dsrHandler.AdminGetDSR)
|
||||
admin.PUT("/dsr/:id", dsrHandler.AdminUpdateDSR)
|
||||
admin.POST("/dsr/:id/status", dsrHandler.AdminUpdateDSRStatus)
|
||||
admin.POST("/dsr/:id/verify-identity", dsrHandler.AdminVerifyIdentity)
|
||||
admin.POST("/dsr/:id/assign", dsrHandler.AdminAssignDSR)
|
||||
admin.POST("/dsr/:id/extend", dsrHandler.AdminExtendDSRDeadline)
|
||||
admin.POST("/dsr/:id/complete", dsrHandler.AdminCompleteDSR)
|
||||
admin.POST("/dsr/:id/reject", dsrHandler.AdminRejectDSR)
|
||||
admin.GET("/dsr/:id/history", dsrHandler.AdminGetDSRHistory)
|
||||
admin.GET("/dsr/:id/communications", dsrHandler.AdminGetDSRCommunications)
|
||||
admin.POST("/dsr/:id/communicate", dsrHandler.AdminSendDSRCommunication)
|
||||
admin.GET("/dsr/:id/exception-checks", dsrHandler.AdminGetExceptionChecks)
|
||||
admin.POST("/dsr/:id/exception-checks/init", dsrHandler.AdminInitExceptionChecks)
|
||||
admin.PUT("/dsr/:id/exception-checks/:checkId", dsrHandler.AdminUpdateExceptionCheck)
|
||||
admin.POST("/dsr/deadlines/process", dsrHandler.ProcessDeadlines)
|
||||
|
||||
// DSR Templates
|
||||
admin.GET("/dsr-templates", dsrHandler.AdminGetDSRTemplates)
|
||||
admin.GET("/dsr-templates/published", dsrHandler.AdminGetPublishedDSRTemplates)
|
||||
admin.GET("/dsr-templates/:id/versions", dsrHandler.AdminGetDSRTemplateVersions)
|
||||
admin.POST("/dsr-templates/:id/versions", dsrHandler.AdminCreateDSRTemplateVersion)
|
||||
admin.POST("/dsr-template-versions/:versionId/publish", dsrHandler.AdminPublishDSRTemplateVersion)
|
||||
}
|
||||
|
||||
// =============================================
|
||||
// Communication Routes (Matrix + Jitsi)
|
||||
// =============================================
|
||||
communicationHandler.RegisterRoutes(v1, cfg.JWTSecret, middleware.AuthMiddleware(cfg.JWTSecret))
|
||||
|
||||
// =============================================
|
||||
// Cookie Banner SDK Routes (Public - Anonymous)
|
||||
// =============================================
|
||||
// Diese Endpoints werden vom @breakpilot/consent-sdk verwendet
|
||||
// für anonyme (device-basierte) Cookie-Einwilligungen.
|
||||
banner := v1.Group("/banner")
|
||||
{
|
||||
// Public Endpoints (keine Auth erforderlich)
|
||||
banner.POST("/consent", h.CreateBannerConsent)
|
||||
banner.GET("/consent", h.GetBannerConsent)
|
||||
banner.DELETE("/consent/:consentId", h.RevokeBannerConsent)
|
||||
banner.GET("/config/:siteId", h.GetSiteConfig)
|
||||
banner.GET("/consent/export", h.ExportBannerConsent)
|
||||
}
|
||||
|
||||
// Banner Admin Routes (require admin auth)
|
||||
bannerAdmin := v1.Group("/banner/admin")
|
||||
bannerAdmin.Use(middleware.AuthMiddleware(cfg.JWTSecret))
|
||||
bannerAdmin.Use(middleware.AdminOnly())
|
||||
{
|
||||
bannerAdmin.GET("/stats/:siteId", h.GetBannerStats)
|
||||
}
|
||||
}
|
||||
|
||||
// Start background scheduler for scheduled publishing
|
||||
go startScheduledPublishingWorker(db)
|
||||
|
||||
// Start DSR deadline monitoring worker
|
||||
go startDSRDeadlineWorker(dsrService)
|
||||
|
||||
// Start server
|
||||
port := cfg.Port
|
||||
if port == "" {
|
||||
port = "8080"
|
||||
}
|
||||
|
||||
log.Printf("Starting Consent Service on port %s", port)
|
||||
if err := router.Run(":" + port); err != nil {
|
||||
log.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// startScheduledPublishingWorker runs every minute to check for scheduled versions
|
||||
func startScheduledPublishingWorker(db *database.DB) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
log.Println("Scheduled publishing worker started (checking every minute)")
|
||||
|
||||
for range ticker.C {
|
||||
processScheduledVersions(db)
|
||||
}
|
||||
}
|
||||
|
||||
func processScheduledVersions(db *database.DB) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Find all scheduled versions that are due
|
||||
rows, err := db.Pool.Query(ctx, `
|
||||
SELECT id, document_id, version
|
||||
FROM document_versions
|
||||
WHERE status = 'scheduled'
|
||||
AND scheduled_publish_at IS NOT NULL
|
||||
AND scheduled_publish_at <= NOW()
|
||||
`)
|
||||
if err != nil {
|
||||
log.Printf("Scheduler: Error fetching scheduled versions: %v", err)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var publishedCount int
|
||||
for rows.Next() {
|
||||
var versionID, docID string
|
||||
var version string
|
||||
if err := rows.Scan(&versionID, &docID, &version); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Publish this version
|
||||
_, err := db.Pool.Exec(ctx, `
|
||||
UPDATE document_versions
|
||||
SET status = 'published', published_at = NOW(), updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`, versionID)
|
||||
|
||||
if err == nil {
|
||||
// Archive previous published versions for this document
|
||||
db.Pool.Exec(ctx, `
|
||||
UPDATE document_versions
|
||||
SET status = 'archived', updated_at = NOW()
|
||||
WHERE document_id = $1 AND id != $2 AND status = 'published'
|
||||
`, docID, versionID)
|
||||
|
||||
// Log the publishing
|
||||
details := fmt.Sprintf("Version %s automatically published by scheduler", version)
|
||||
db.Pool.Exec(ctx, `
|
||||
INSERT INTO consent_audit_log (action, entity_type, entity_id, details, user_agent)
|
||||
VALUES ('version_scheduled_published', 'document_version', $1, $2, 'scheduler')
|
||||
`, versionID, details)
|
||||
|
||||
publishedCount++
|
||||
log.Printf("Scheduler: Published version %s (ID: %s)", version, versionID)
|
||||
}
|
||||
}
|
||||
|
||||
if publishedCount > 0 {
|
||||
log.Printf("Scheduler: Published %d version(s)", publishedCount)
|
||||
}
|
||||
}
|
||||
|
||||
// startDSRDeadlineWorker monitors DSR deadlines and sends notifications
|
||||
func startDSRDeadlineWorker(dsrService *services.DSRService) {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
log.Println("DSR deadline monitoring worker started (checking every hour)")
|
||||
|
||||
// Run immediately on startup
|
||||
ctx := context.Background()
|
||||
if err := dsrService.ProcessDeadlines(ctx); err != nil {
|
||||
log.Printf("DSR Worker: Error processing deadlines: %v", err)
|
||||
}
|
||||
|
||||
for range ticker.C {
|
||||
ctx := context.Background()
|
||||
if err := dsrService.ProcessDeadlines(ctx); err != nil {
|
||||
log.Printf("DSR Worker: Error processing deadlines: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
consent-service:
|
||||
build: .
|
||||
ports:
|
||||
- "8081:8081"
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- DATABASE_URL=postgres://consent:consent123@postgres:5432/consent_db?sslmode=disable
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
- consent-network
|
||||
|
||||
postgres:
|
||||
image: postgres:16-alpine
|
||||
ports:
|
||||
- "5433:5432"
|
||||
environment:
|
||||
- POSTGRES_USER=consent
|
||||
- POSTGRES_PASSWORD=consent123
|
||||
- POSTGRES_DB=consent_db
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U consent -d consent_db"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
networks:
|
||||
- consent-network
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
|
||||
networks:
|
||||
consent-network:
|
||||
driver: bridge
|
||||
@@ -0,0 +1,49 @@
|
||||
module github.com/breakpilot/consent-service
|
||||
|
||||
go 1.23.0
|
||||
|
||||
require (
|
||||
github.com/gin-gonic/gin v1.11.0
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.7.6
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||
golang.org/x/crypto v0.40.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/bytedance/sonic v1.14.0 // indirect
|
||||
github.com/bytedance/sonic/loader v0.3.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
|
||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.27.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/goccy/go-yaml v1.18.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/quic-go/qpack v0.5.1 // indirect
|
||||
github.com/quic-go/quic-go v0.54.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.3.0 // indirect
|
||||
go.uber.org/mock v0.5.0 // indirect
|
||||
golang.org/x/arch v0.20.0 // indirect
|
||||
golang.org/x/mod v0.25.0 // indirect
|
||||
golang.org/x/net v0.42.0 // indirect
|
||||
golang.org/x/sync v0.16.0 // indirect
|
||||
golang.org/x/sys v0.35.0 // indirect
|
||||
golang.org/x/text v0.27.0 // indirect
|
||||
golang.org/x/tools v0.34.0 // indirect
|
||||
google.golang.org/protobuf v1.36.9 // indirect
|
||||
)
|
||||
@@ -0,0 +1,105 @@
|
||||
github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ=
|
||||
github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA=
|
||||
github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA=
|
||||
github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI=
|
||||
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
||||
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
|
||||
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
|
||||
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
|
||||
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
|
||||
github.com/gin-gonic/gin v1.11.0 h1:OW/6PLjyusp2PPXtyxKHU0RbX6I/l28FTdDlae5ueWk=
|
||||
github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.27.0 h1:w8+XrWVMhGkxOaaowyKH35gFydVHOvC0/uWoy2Fzwn4=
|
||||
github.com/go-playground/validator/v10 v10.27.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
|
||||
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk=
|
||||
github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
|
||||
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
|
||||
github.com/quic-go/quic-go v0.54.0 h1:6s1YB9QotYI6Ospeiguknbp2Znb/jZYjZLRXn9kMQBg=
|
||||
github.com/quic-go/quic-go v0.54.0/go.mod h1:e68ZEaCdyviluZmy44P6Iey98v/Wfz6HCjQEm+l8zTY=
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.3.0 h1:Qd2W2sQawAfG8XSvzwhBeoGq71zXOC/Q1E9y/wUcsUA=
|
||||
github.com/ugorji/go/codec v1.3.0/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
|
||||
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
|
||||
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
|
||||
golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c=
|
||||
golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk=
|
||||
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
|
||||
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
|
||||
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
|
||||
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
|
||||
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
|
||||
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
||||
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
|
||||
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
|
||||
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
|
||||
golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=
|
||||
google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw=
|
||||
google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
@@ -0,0 +1,170 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
// Config holds all configuration for the service
|
||||
type Config struct {
|
||||
// Server
|
||||
Port string
|
||||
Environment string
|
||||
|
||||
// Database
|
||||
DatabaseURL string
|
||||
|
||||
// JWT
|
||||
JWTSecret string
|
||||
JWTRefreshSecret string
|
||||
|
||||
// CORS
|
||||
AllowedOrigins []string
|
||||
|
||||
// Rate Limiting
|
||||
RateLimitRequests int
|
||||
RateLimitWindow int // in seconds
|
||||
|
||||
// BreakPilot Integration
|
||||
BreakPilotAPIURL string
|
||||
FrontendURL string
|
||||
|
||||
// SMTP Email Configuration
|
||||
SMTPHost string
|
||||
SMTPPort int
|
||||
SMTPUsername string
|
||||
SMTPPassword string
|
||||
SMTPFromName string
|
||||
SMTPFromAddr string
|
||||
|
||||
// Consent Settings
|
||||
ConsentDeadlineDays int
|
||||
ConsentReminderEnabled bool
|
||||
|
||||
// VAPID Keys for Web Push
|
||||
VAPIDPublicKey string
|
||||
VAPIDPrivateKey string
|
||||
|
||||
// Matrix (Synapse) Configuration
|
||||
MatrixHomeserverURL string
|
||||
MatrixAccessToken string
|
||||
MatrixServerName string
|
||||
MatrixEnabled bool
|
||||
|
||||
// Jitsi Configuration
|
||||
JitsiBaseURL string
|
||||
JitsiAppID string
|
||||
JitsiAppSecret string
|
||||
JitsiEnabled bool
|
||||
}
|
||||
|
||||
// Load loads configuration from environment variables
|
||||
func Load() (*Config, error) {
|
||||
// Load .env file if exists (for development)
|
||||
_ = godotenv.Load()
|
||||
|
||||
cfg := &Config{
|
||||
Port: getEnv("PORT", "8080"),
|
||||
Environment: getEnv("ENVIRONMENT", "development"),
|
||||
DatabaseURL: getEnv("DATABASE_URL", ""),
|
||||
JWTSecret: getEnv("JWT_SECRET", ""),
|
||||
JWTRefreshSecret: getEnv("JWT_REFRESH_SECRET", ""),
|
||||
RateLimitRequests: getEnvInt("RATE_LIMIT_REQUESTS", 100),
|
||||
RateLimitWindow: getEnvInt("RATE_LIMIT_WINDOW", 60),
|
||||
BreakPilotAPIURL: getEnv("BREAKPILOT_API_URL", "http://localhost:8000"),
|
||||
FrontendURL: getEnv("FRONTEND_URL", "http://localhost:8000"),
|
||||
|
||||
// SMTP Configuration
|
||||
SMTPHost: getEnv("SMTP_HOST", ""),
|
||||
SMTPPort: getEnvInt("SMTP_PORT", 587),
|
||||
SMTPUsername: getEnv("SMTP_USERNAME", ""),
|
||||
SMTPPassword: getEnv("SMTP_PASSWORD", ""),
|
||||
SMTPFromName: getEnv("SMTP_FROM_NAME", "BreakPilot"),
|
||||
SMTPFromAddr: getEnv("SMTP_FROM_ADDR", "noreply@breakpilot.app"),
|
||||
|
||||
// Consent Settings
|
||||
ConsentDeadlineDays: getEnvInt("CONSENT_DEADLINE_DAYS", 30),
|
||||
ConsentReminderEnabled: getEnvBool("CONSENT_REMINDER_ENABLED", true),
|
||||
|
||||
// VAPID Keys
|
||||
VAPIDPublicKey: getEnv("VAPID_PUBLIC_KEY", ""),
|
||||
VAPIDPrivateKey: getEnv("VAPID_PRIVATE_KEY", ""),
|
||||
|
||||
// Matrix Configuration
|
||||
MatrixHomeserverURL: getEnv("MATRIX_HOMESERVER_URL", "http://synapse:8008"),
|
||||
MatrixAccessToken: getEnv("MATRIX_ACCESS_TOKEN", ""),
|
||||
MatrixServerName: getEnv("MATRIX_SERVER_NAME", "breakpilot.local"),
|
||||
MatrixEnabled: getEnvBool("MATRIX_ENABLED", true),
|
||||
|
||||
// Jitsi Configuration
|
||||
JitsiBaseURL: getEnv("JITSI_BASE_URL", "http://localhost:8443"),
|
||||
JitsiAppID: getEnv("JITSI_APP_ID", "breakpilot"),
|
||||
JitsiAppSecret: getEnv("JITSI_APP_SECRET", ""),
|
||||
JitsiEnabled: getEnvBool("JITSI_ENABLED", true),
|
||||
}
|
||||
|
||||
// Parse allowed origins
|
||||
originsStr := getEnv("ALLOWED_ORIGINS", "http://localhost:3000,http://localhost:8000")
|
||||
cfg.AllowedOrigins = parseCommaSeparated(originsStr)
|
||||
|
||||
// Validate required fields
|
||||
if cfg.DatabaseURL == "" {
|
||||
return nil, fmt.Errorf("DATABASE_URL is required")
|
||||
}
|
||||
|
||||
if cfg.JWTSecret == "" {
|
||||
return nil, fmt.Errorf("JWT_SECRET is required")
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func getEnv(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvInt(key string, defaultValue int) int {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
var result int
|
||||
fmt.Sscanf(value, "%d", &result)
|
||||
return result
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvBool(key string, defaultValue bool) bool {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value == "true" || value == "1" || value == "yes"
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func parseCommaSeparated(s string) []string {
|
||||
if s == "" {
|
||||
return []string{}
|
||||
}
|
||||
var result []string
|
||||
start := 0
|
||||
for i := 0; i <= len(s); i++ {
|
||||
if i == len(s) || s[i] == ',' {
|
||||
item := s[start:i]
|
||||
// Trim whitespace
|
||||
for len(item) > 0 && item[0] == ' ' {
|
||||
item = item[1:]
|
||||
}
|
||||
for len(item) > 0 && item[len(item)-1] == ' ' {
|
||||
item = item[:len(item)-1]
|
||||
}
|
||||
if item != "" {
|
||||
result = append(result, item)
|
||||
}
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,322 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestGetEnv tests the getEnv helper function
|
||||
func TestGetEnv(t *testing.T) {
|
||||
// Test with default value when env var not set
|
||||
result := getEnv("TEST_NONEXISTENT_VAR_12345", "default")
|
||||
if result != "default" {
|
||||
t.Errorf("Expected 'default', got '%s'", result)
|
||||
}
|
||||
|
||||
// Test with set env var
|
||||
os.Setenv("TEST_ENV_VAR", "custom_value")
|
||||
defer os.Unsetenv("TEST_ENV_VAR")
|
||||
|
||||
result = getEnv("TEST_ENV_VAR", "default")
|
||||
if result != "custom_value" {
|
||||
t.Errorf("Expected 'custom_value', got '%s'", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetEnvInt tests the getEnvInt helper function
|
||||
func TestGetEnvInt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envValue string
|
||||
defaultValue int
|
||||
expected int
|
||||
}{
|
||||
{"default when not set", "", 100, 100},
|
||||
{"parse valid int", "42", 0, 42},
|
||||
{"parse zero", "0", 100, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.envValue != "" {
|
||||
os.Setenv("TEST_INT_VAR", tt.envValue)
|
||||
defer os.Unsetenv("TEST_INT_VAR")
|
||||
} else {
|
||||
os.Unsetenv("TEST_INT_VAR")
|
||||
}
|
||||
|
||||
result := getEnvInt("TEST_INT_VAR", tt.defaultValue)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %d, got %d", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetEnvBool tests the getEnvBool helper function
|
||||
func TestGetEnvBool(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envValue string
|
||||
defaultValue bool
|
||||
expected bool
|
||||
}{
|
||||
{"default when not set", "", true, true},
|
||||
{"default false when not set", "", false, false},
|
||||
{"parse true", "true", false, true},
|
||||
{"parse 1", "1", false, true},
|
||||
{"parse yes", "yes", false, true},
|
||||
{"parse false", "false", true, false},
|
||||
{"parse 0", "0", true, false},
|
||||
{"parse no", "no", true, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.envValue != "" {
|
||||
os.Setenv("TEST_BOOL_VAR", tt.envValue)
|
||||
defer os.Unsetenv("TEST_BOOL_VAR")
|
||||
} else {
|
||||
os.Unsetenv("TEST_BOOL_VAR")
|
||||
}
|
||||
|
||||
result := getEnvBool("TEST_BOOL_VAR", tt.defaultValue)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseCommaSeparated tests the parseCommaSeparated helper function
|
||||
func TestParseCommaSeparated(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{"empty string", "", []string{}},
|
||||
{"single value", "value1", []string{"value1"}},
|
||||
{"multiple values", "value1,value2,value3", []string{"value1", "value2", "value3"}},
|
||||
{"with spaces", "value1, value2, value3", []string{"value1", "value2", "value3"}},
|
||||
{"with trailing comma", "value1,value2,", []string{"value1", "value2"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := parseCommaSeparated(tt.input)
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("Expected length %d, got %d", len(tt.expected), len(result))
|
||||
return
|
||||
}
|
||||
for i := range result {
|
||||
if result[i] != tt.expected[i] {
|
||||
t.Errorf("At index %d: expected '%s', got '%s'", i, tt.expected[i], result[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigEnvironmentDefaults tests default environment values
|
||||
func TestConfigEnvironmentDefaults(t *testing.T) {
|
||||
// Clear any existing env vars that might interfere
|
||||
varsToUnset := []string{
|
||||
"PORT", "ENVIRONMENT", "DATABASE_URL", "JWT_SECRET", "JWT_REFRESH_SECRET",
|
||||
}
|
||||
for _, v := range varsToUnset {
|
||||
os.Unsetenv(v)
|
||||
}
|
||||
|
||||
// Set required vars
|
||||
os.Setenv("DATABASE_URL", "postgres://test:test@localhost:5432/test")
|
||||
os.Setenv("JWT_SECRET", "test-secret-32-chars-minimum-here")
|
||||
defer func() {
|
||||
os.Unsetenv("DATABASE_URL")
|
||||
os.Unsetenv("JWT_SECRET")
|
||||
}()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Test defaults
|
||||
if cfg.Port != "8080" {
|
||||
t.Errorf("Expected default port '8080', got '%s'", cfg.Port)
|
||||
}
|
||||
|
||||
if cfg.Environment != "development" {
|
||||
t.Errorf("Expected default environment 'development', got '%s'", cfg.Environment)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigLoadWithEnvironment tests loading config with different environments
|
||||
func TestConfigLoadWithEnvironment(t *testing.T) {
|
||||
// Set required vars
|
||||
os.Setenv("DATABASE_URL", "postgres://test:test@localhost:5432/test")
|
||||
os.Setenv("JWT_SECRET", "test-secret-32-chars-minimum-here")
|
||||
defer func() {
|
||||
os.Unsetenv("DATABASE_URL")
|
||||
os.Unsetenv("JWT_SECRET")
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
environment string
|
||||
}{
|
||||
{"development", "development"},
|
||||
{"staging", "staging"},
|
||||
{"production", "production"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
os.Setenv("ENVIRONMENT", tt.environment)
|
||||
defer os.Unsetenv("ENVIRONMENT")
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Environment != tt.environment {
|
||||
t.Errorf("Expected environment '%s', got '%s'", tt.environment, cfg.Environment)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigMissingRequiredVars tests that missing required vars return errors
|
||||
func TestConfigMissingRequiredVars(t *testing.T) {
|
||||
// Clear all env vars
|
||||
os.Unsetenv("DATABASE_URL")
|
||||
os.Unsetenv("JWT_SECRET")
|
||||
|
||||
_, err := Load()
|
||||
if err == nil {
|
||||
t.Error("Expected error when DATABASE_URL is missing")
|
||||
}
|
||||
|
||||
// Set DATABASE_URL but not JWT_SECRET
|
||||
os.Setenv("DATABASE_URL", "postgres://test:test@localhost:5432/test")
|
||||
defer os.Unsetenv("DATABASE_URL")
|
||||
|
||||
_, err = Load()
|
||||
if err == nil {
|
||||
t.Error("Expected error when JWT_SECRET is missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigAllowedOrigins tests that allowed origins are parsed correctly
|
||||
func TestConfigAllowedOrigins(t *testing.T) {
|
||||
// Set required vars
|
||||
os.Setenv("DATABASE_URL", "postgres://test:test@localhost:5432/test")
|
||||
os.Setenv("JWT_SECRET", "test-secret-32-chars-minimum-here")
|
||||
os.Setenv("ALLOWED_ORIGINS", "http://localhost:3000,http://localhost:8000,http://localhost:8001")
|
||||
defer func() {
|
||||
os.Unsetenv("DATABASE_URL")
|
||||
os.Unsetenv("JWT_SECRET")
|
||||
os.Unsetenv("ALLOWED_ORIGINS")
|
||||
}()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
expected := []string{"http://localhost:3000", "http://localhost:8000", "http://localhost:8001"}
|
||||
if len(cfg.AllowedOrigins) != len(expected) {
|
||||
t.Errorf("Expected %d origins, got %d", len(expected), len(cfg.AllowedOrigins))
|
||||
}
|
||||
|
||||
for i, origin := range cfg.AllowedOrigins {
|
||||
if origin != expected[i] {
|
||||
t.Errorf("At index %d: expected '%s', got '%s'", i, expected[i], origin)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigDebugSettings tests debug-related settings for different environments
|
||||
func TestConfigDebugSettings(t *testing.T) {
|
||||
// Set required vars
|
||||
os.Setenv("DATABASE_URL", "postgres://test:test@localhost:5432/test")
|
||||
os.Setenv("JWT_SECRET", "test-secret-32-chars-minimum-here")
|
||||
defer func() {
|
||||
os.Unsetenv("DATABASE_URL")
|
||||
os.Unsetenv("JWT_SECRET")
|
||||
}()
|
||||
|
||||
// Test development environment
|
||||
t.Run("development", func(t *testing.T) {
|
||||
os.Setenv("ENVIRONMENT", "development")
|
||||
defer os.Unsetenv("ENVIRONMENT")
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Environment != "development" {
|
||||
t.Errorf("Expected 'development', got '%s'", cfg.Environment)
|
||||
}
|
||||
})
|
||||
|
||||
// Test staging environment
|
||||
t.Run("staging", func(t *testing.T) {
|
||||
os.Setenv("ENVIRONMENT", "staging")
|
||||
defer os.Unsetenv("ENVIRONMENT")
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Environment != "staging" {
|
||||
t.Errorf("Expected 'staging', got '%s'", cfg.Environment)
|
||||
}
|
||||
})
|
||||
|
||||
// Test production environment
|
||||
t.Run("production", func(t *testing.T) {
|
||||
os.Setenv("ENVIRONMENT", "production")
|
||||
defer os.Unsetenv("ENVIRONMENT")
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Environment != "production" {
|
||||
t.Errorf("Expected 'production', got '%s'", cfg.Environment)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestConfigStagingPorts tests that staging uses different ports
|
||||
func TestConfigStagingPorts(t *testing.T) {
|
||||
// Set required vars
|
||||
os.Setenv("DATABASE_URL", "postgres://test:test@localhost:5433/breakpilot_staging")
|
||||
os.Setenv("JWT_SECRET", "test-secret-32-chars-minimum-here")
|
||||
os.Setenv("ENVIRONMENT", "staging")
|
||||
os.Setenv("PORT", "8081")
|
||||
defer func() {
|
||||
os.Unsetenv("DATABASE_URL")
|
||||
os.Unsetenv("JWT_SECRET")
|
||||
os.Unsetenv("ENVIRONMENT")
|
||||
os.Unsetenv("PORT")
|
||||
}()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Port != "8081" {
|
||||
t.Errorf("Expected staging port '8081', got '%s'", cfg.Port)
|
||||
}
|
||||
|
||||
if cfg.Environment != "staging" {
|
||||
t.Errorf("Expected 'staging', got '%s'", cfg.Environment)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,442 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
"github.com/breakpilot/consent-service/internal/services"
|
||||
)
|
||||
|
||||
// AuthHandler handles authentication endpoints
|
||||
type AuthHandler struct {
|
||||
authService *services.AuthService
|
||||
emailService *services.EmailService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler
|
||||
func NewAuthHandler(authService *services.AuthService, emailService *services.EmailService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
authService: authService,
|
||||
emailService: emailService,
|
||||
}
|
||||
}
|
||||
|
||||
// Register handles user registration
|
||||
// @Summary Register a new user
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body models.RegisterRequest true "Registration data"
|
||||
// @Success 201 {object} map[string]interface{}
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 409 {object} map[string]string
|
||||
// @Router /auth/register [post]
|
||||
func (h *AuthHandler) Register(c *gin.Context) {
|
||||
var req models.RegisterRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request", "details": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
user, verificationToken, err := h.authService.Register(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
if err == services.ErrUserExists {
|
||||
c.JSON(http.StatusConflict, gin.H{"error": "User with this email already exists"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to register user"})
|
||||
return
|
||||
}
|
||||
|
||||
// Send verification email (async, don't block response)
|
||||
go func() {
|
||||
var name string
|
||||
if user.Name != nil {
|
||||
name = *user.Name
|
||||
}
|
||||
if err := h.emailService.SendVerificationEmail(user.Email, name, verificationToken); err != nil {
|
||||
// Log error but don't fail registration
|
||||
println("Failed to send verification email:", err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"message": "Registration successful. Please check your email to verify your account.",
|
||||
"user": gin.H{
|
||||
"id": user.ID,
|
||||
"email": user.Email,
|
||||
"name": user.Name,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Login handles user login
|
||||
// @Summary Login user
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body models.LoginRequest true "Login credentials"
|
||||
// @Success 200 {object} models.LoginResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /auth/login [post]
|
||||
func (h *AuthHandler) Login(c *gin.Context) {
|
||||
var req models.LoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request", "details": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ipAddress := c.ClientIP()
|
||||
userAgent := c.Request.UserAgent()
|
||||
|
||||
response, err := h.authService.Login(c.Request.Context(), &req, ipAddress, userAgent)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case services.ErrInvalidCredentials:
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid email or password"})
|
||||
case services.ErrAccountLocked:
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Account is temporarily locked. Please try again later."})
|
||||
case services.ErrAccountSuspended:
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": "Account is suspended",
|
||||
"reason": "consent_required",
|
||||
"redirect": "/consent/pending",
|
||||
})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Login failed"})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// Logout handles user logout
|
||||
// @Summary Logout user
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param Authorization header string true "Bearer token"
|
||||
// @Success 200 {object} map[string]string
|
||||
// @Router /auth/logout [post]
|
||||
func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
var req struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err == nil && req.RefreshToken != "" {
|
||||
_ = h.authService.Logout(c.Request.Context(), req.RefreshToken)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Logged out successfully"})
|
||||
}
|
||||
|
||||
// RefreshToken refreshes the access token
|
||||
// @Summary Refresh access token
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body models.RefreshTokenRequest true "Refresh token"
|
||||
// @Success 200 {object} models.LoginResponse
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Router /auth/refresh [post]
|
||||
func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
||||
var req models.RefreshTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := h.authService.RefreshToken(c.Request.Context(), req.RefreshToken)
|
||||
if err != nil {
|
||||
if err == services.ErrAccountSuspended {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": "Account is suspended",
|
||||
"reason": "consent_required",
|
||||
"redirect": "/consent/pending",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or expired refresh token"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// VerifyEmail verifies user email
|
||||
// @Summary Verify email address
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body models.VerifyEmailRequest true "Verification token"
|
||||
// @Success 200 {object} map[string]string
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Router /auth/verify-email [post]
|
||||
func (h *AuthHandler) VerifyEmail(c *gin.Context) {
|
||||
var req models.VerifyEmailRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.authService.VerifyEmail(c.Request.Context(), req.Token); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid or expired verification token"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Email verified successfully. You can now log in."})
|
||||
}
|
||||
|
||||
// ResendVerification resends verification email
|
||||
// @Summary Resend verification email
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body map[string]string true "Email"
|
||||
// @Success 200 {object} map[string]string
|
||||
// @Router /auth/resend-verification [post]
|
||||
func (h *AuthHandler) ResendVerification(c *gin.Context) {
|
||||
var req struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
// Always return success to prevent email enumeration
|
||||
c.JSON(http.StatusOK, gin.H{"message": "If an account exists with this email, a verification email has been sent."})
|
||||
}
|
||||
|
||||
// ForgotPassword initiates password reset
|
||||
// @Summary Request password reset
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body models.ForgotPasswordRequest true "Email"
|
||||
// @Success 200 {object} map[string]string
|
||||
// @Router /auth/forgot-password [post]
|
||||
func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
||||
var req models.ForgotPasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
token, userID, err := h.authService.CreatePasswordResetToken(c.Request.Context(), req.Email, c.ClientIP())
|
||||
if err == nil && userID != nil {
|
||||
// Send email asynchronously
|
||||
go func() {
|
||||
_ = h.emailService.SendPasswordResetEmail(req.Email, "", token)
|
||||
}()
|
||||
}
|
||||
|
||||
// Always return success to prevent email enumeration
|
||||
c.JSON(http.StatusOK, gin.H{"message": "If an account exists with this email, a password reset link has been sent."})
|
||||
}
|
||||
|
||||
// ResetPassword resets password with token
|
||||
// @Summary Reset password
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body models.ResetPasswordRequest true "Reset token and new password"
|
||||
// @Success 200 {object} map[string]string
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Router /auth/reset-password [post]
|
||||
func (h *AuthHandler) ResetPassword(c *gin.Context) {
|
||||
var req models.ResetPasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request", "details": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.authService.ResetPassword(c.Request.Context(), req.Token, req.NewPassword); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid or expired reset token"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Password reset successfully. You can now log in with your new password."})
|
||||
}
|
||||
|
||||
// GetProfile returns the current user's profile
|
||||
// @Summary Get user profile
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Success 200 {object} models.User
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Router /profile [get]
|
||||
func (h *AuthHandler) GetProfile(c *gin.Context) {
|
||||
userIDStr, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := uuid.Parse(userIDStr.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.authService.GetUserByID(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "User not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, user)
|
||||
}
|
||||
|
||||
// UpdateProfile updates the current user's profile
|
||||
// @Summary Update user profile
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body models.UpdateProfileRequest true "Profile data"
|
||||
// @Success 200 {object} models.User
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Router /profile [put]
|
||||
func (h *AuthHandler) UpdateProfile(c *gin.Context) {
|
||||
userIDStr, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := uuid.Parse(userIDStr.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var req models.UpdateProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.authService.UpdateProfile(c.Request.Context(), userID, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update profile"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, user)
|
||||
}
|
||||
|
||||
// ChangePassword changes the current user's password
|
||||
// @Summary Change password
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body models.ChangePasswordRequest true "Password data"
|
||||
// @Success 200 {object} map[string]string
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Router /profile/password [put]
|
||||
func (h *AuthHandler) ChangePassword(c *gin.Context) {
|
||||
userIDStr, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := uuid.Parse(userIDStr.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var req models.ChangePasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request", "details": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.authService.ChangePassword(c.Request.Context(), userID, req.CurrentPassword, req.NewPassword); err != nil {
|
||||
if err == services.ErrInvalidCredentials {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Current password is incorrect"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to change password"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Password changed successfully"})
|
||||
}
|
||||
|
||||
// GetActiveSessions returns all active sessions for the current user
|
||||
// @Summary Get active sessions
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Success 200 {array} models.UserSession
|
||||
// @Router /profile/sessions [get]
|
||||
func (h *AuthHandler) GetActiveSessions(c *gin.Context) {
|
||||
userIDStr, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := uuid.Parse(userIDStr.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
sessions, err := h.authService.GetActiveSessions(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get sessions"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"sessions": sessions})
|
||||
}
|
||||
|
||||
// RevokeSession revokes a specific session
|
||||
// @Summary Revoke session
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param id path string true "Session ID"
|
||||
// @Success 200 {object} map[string]string
|
||||
// @Failure 404 {object} map[string]string
|
||||
// @Router /profile/sessions/{id} [delete]
|
||||
func (h *AuthHandler) RevokeSession(c *gin.Context) {
|
||||
userIDStr, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := uuid.Parse(userIDStr.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
sessionID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid session ID"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.authService.RevokeSession(c.Request.Context(), userID, sessionID); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Session not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Session revoked successfully"})
|
||||
}
|
||||
@@ -0,0 +1,561 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ========================================
|
||||
// Cookie Banner SDK API Handlers
|
||||
// ========================================
|
||||
// Diese Endpoints werden vom @breakpilot/consent-sdk verwendet
|
||||
// für anonyme (device-basierte) Cookie-Einwilligungen.
|
||||
|
||||
// BannerConsentRecord repräsentiert einen anonymen Consent-Eintrag
|
||||
type BannerConsentRecord struct {
|
||||
ID string `json:"id"`
|
||||
SiteID string `json:"site_id"`
|
||||
DeviceFingerprint string `json:"device_fingerprint"`
|
||||
UserID *string `json:"user_id,omitempty"`
|
||||
Categories map[string]bool `json:"categories"`
|
||||
Vendors map[string]bool `json:"vendors,omitempty"`
|
||||
TCFString *string `json:"tcf_string,omitempty"`
|
||||
IPHash *string `json:"ip_hash,omitempty"`
|
||||
UserAgent *string `json:"user_agent,omitempty"`
|
||||
Language *string `json:"language,omitempty"`
|
||||
Platform *string `json:"platform,omitempty"`
|
||||
AppVersion *string `json:"app_version,omitempty"`
|
||||
Version string `json:"version"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
RevokedAt *time.Time `json:"revoked_at,omitempty"`
|
||||
}
|
||||
|
||||
// BannerConsentRequest ist der Request-Body für POST /consent
|
||||
type BannerConsentRequest struct {
|
||||
SiteID string `json:"siteId" binding:"required"`
|
||||
UserID *string `json:"userId,omitempty"`
|
||||
DeviceFingerprint string `json:"deviceFingerprint" binding:"required"`
|
||||
Consent ConsentData `json:"consent" binding:"required"`
|
||||
Metadata *ConsentMetadata `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ConsentData enthält die eigentlichen Consent-Daten
|
||||
type ConsentData struct {
|
||||
Categories map[string]bool `json:"categories" binding:"required"`
|
||||
Vendors map[string]bool `json:"vendors,omitempty"`
|
||||
}
|
||||
|
||||
// ConsentMetadata enthält optionale Metadaten
|
||||
type ConsentMetadata struct {
|
||||
UserAgent *string `json:"userAgent,omitempty"`
|
||||
Language *string `json:"language,omitempty"`
|
||||
ScreenResolution *string `json:"screenResolution,omitempty"`
|
||||
Platform *string `json:"platform,omitempty"`
|
||||
AppVersion *string `json:"appVersion,omitempty"`
|
||||
}
|
||||
|
||||
// BannerConsentResponse ist die Antwort auf POST /consent
|
||||
type BannerConsentResponse struct {
|
||||
ConsentID string `json:"consentId"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
ExpiresAt string `json:"expiresAt"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// SiteConfig repräsentiert die Konfiguration für eine Site
|
||||
type SiteConfig struct {
|
||||
SiteID string `json:"siteId"`
|
||||
SiteName string `json:"siteName"`
|
||||
Categories []CategoryConfig `json:"categories"`
|
||||
UI UIConfig `json:"ui"`
|
||||
Legal LegalConfig `json:"legal"`
|
||||
TCF *TCFConfig `json:"tcf,omitempty"`
|
||||
}
|
||||
|
||||
// CategoryConfig repräsentiert eine Consent-Kategorie
|
||||
type CategoryConfig struct {
|
||||
ID string `json:"id"`
|
||||
Name map[string]string `json:"name"`
|
||||
Description map[string]string `json:"description"`
|
||||
Required bool `json:"required"`
|
||||
Vendors []VendorConfig `json:"vendors"`
|
||||
}
|
||||
|
||||
// VendorConfig repräsentiert einen Vendor (Third-Party)
|
||||
type VendorConfig struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
PrivacyPolicyURL string `json:"privacyPolicyUrl"`
|
||||
Cookies []CookieInfo `json:"cookies"`
|
||||
}
|
||||
|
||||
// CookieInfo repräsentiert ein Cookie
|
||||
type CookieInfo struct {
|
||||
Name string `json:"name"`
|
||||
Expiration string `json:"expiration"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// UIConfig repräsentiert UI-Einstellungen
|
||||
type UIConfig struct {
|
||||
Theme string `json:"theme"`
|
||||
Position string `json:"position"`
|
||||
}
|
||||
|
||||
// LegalConfig repräsentiert rechtliche Informationen
|
||||
type LegalConfig struct {
|
||||
PrivacyPolicyURL string `json:"privacyPolicyUrl"`
|
||||
ImprintURL string `json:"imprintUrl"`
|
||||
}
|
||||
|
||||
// TCFConfig repräsentiert TCF 2.2 Einstellungen
|
||||
type TCFConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
CmpID int `json:"cmpId"`
|
||||
CmpVersion int `json:"cmpVersion"`
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Handler Methods
|
||||
// ========================================
|
||||
|
||||
// CreateBannerConsent erstellt oder aktualisiert einen Consent-Eintrag
|
||||
// POST /api/v1/banner/consent
|
||||
func (h *Handler) CreateBannerConsent(c *gin.Context) {
|
||||
var req BannerConsentRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_request",
|
||||
"message": "Invalid request body: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// IP-Adresse anonymisieren
|
||||
ipHash := anonymizeIP(c.ClientIP())
|
||||
|
||||
// Consent-ID generieren
|
||||
consentID := uuid.New().String()
|
||||
|
||||
// Ablaufdatum (1 Jahr)
|
||||
expiresAt := time.Now().AddDate(1, 0, 0)
|
||||
|
||||
// Categories und Vendors als JSON
|
||||
categoriesJSON, _ := json.Marshal(req.Consent.Categories)
|
||||
vendorsJSON, _ := json.Marshal(req.Consent.Vendors)
|
||||
|
||||
// Metadaten extrahieren
|
||||
var userAgent, language, platform, appVersion *string
|
||||
if req.Metadata != nil {
|
||||
userAgent = req.Metadata.UserAgent
|
||||
language = req.Metadata.Language
|
||||
platform = req.Metadata.Platform
|
||||
appVersion = req.Metadata.AppVersion
|
||||
}
|
||||
|
||||
// In Datenbank speichern
|
||||
_, err := h.db.Pool.Exec(ctx, `
|
||||
INSERT INTO banner_consents (
|
||||
id, site_id, device_fingerprint, user_id,
|
||||
categories, vendors, ip_hash, user_agent,
|
||||
language, platform, app_version, version,
|
||||
expires_at, created_at, updated_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, NOW(), NOW())
|
||||
ON CONFLICT (site_id, device_fingerprint)
|
||||
DO UPDATE SET
|
||||
categories = $5,
|
||||
vendors = $6,
|
||||
ip_hash = $7,
|
||||
user_agent = $8,
|
||||
language = $9,
|
||||
platform = $10,
|
||||
app_version = $11,
|
||||
version = $12,
|
||||
expires_at = $13,
|
||||
updated_at = NOW()
|
||||
RETURNING id
|
||||
`, consentID, req.SiteID, req.DeviceFingerprint, req.UserID,
|
||||
categoriesJSON, vendorsJSON, ipHash, userAgent,
|
||||
language, platform, appVersion, "1.0.0", expiresAt)
|
||||
|
||||
if err != nil {
|
||||
// Fallback: Existierenden Consent abrufen
|
||||
var existingID string
|
||||
err2 := h.db.Pool.QueryRow(ctx, `
|
||||
SELECT id FROM banner_consents
|
||||
WHERE site_id = $1 AND device_fingerprint = $2
|
||||
`, req.SiteID, req.DeviceFingerprint).Scan(&existingID)
|
||||
|
||||
if err2 == nil {
|
||||
consentID = existingID
|
||||
}
|
||||
}
|
||||
|
||||
// Audit-Log schreiben
|
||||
h.logBannerConsentAudit(ctx, consentID, "created", req, ipHash)
|
||||
|
||||
// Response
|
||||
c.JSON(http.StatusCreated, BannerConsentResponse{
|
||||
ConsentID: consentID,
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
ExpiresAt: expiresAt.UTC().Format(time.RFC3339),
|
||||
Version: "1.0.0",
|
||||
})
|
||||
}
|
||||
|
||||
// GetBannerConsent ruft einen bestehenden Consent ab
|
||||
// GET /api/v1/banner/consent?siteId=xxx&deviceFingerprint=xxx
|
||||
func (h *Handler) GetBannerConsent(c *gin.Context) {
|
||||
siteID := c.Query("siteId")
|
||||
deviceFingerprint := c.Query("deviceFingerprint")
|
||||
|
||||
if siteID == "" || deviceFingerprint == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "missing_parameters",
|
||||
"message": "siteId and deviceFingerprint are required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
var record BannerConsentRecord
|
||||
var categoriesJSON, vendorsJSON []byte
|
||||
|
||||
err := h.db.Pool.QueryRow(ctx, `
|
||||
SELECT id, site_id, device_fingerprint, user_id,
|
||||
categories, vendors, version,
|
||||
created_at, updated_at, expires_at, revoked_at
|
||||
FROM banner_consents
|
||||
WHERE site_id = $1 AND device_fingerprint = $2 AND revoked_at IS NULL
|
||||
`, siteID, deviceFingerprint).Scan(
|
||||
&record.ID, &record.SiteID, &record.DeviceFingerprint, &record.UserID,
|
||||
&categoriesJSON, &vendorsJSON, &record.Version,
|
||||
&record.CreatedAt, &record.UpdatedAt, &record.ExpiresAt, &record.RevokedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "consent_not_found",
|
||||
"message": "No consent record found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// JSON parsen
|
||||
json.Unmarshal(categoriesJSON, &record.Categories)
|
||||
json.Unmarshal(vendorsJSON, &record.Vendors)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"consentId": record.ID,
|
||||
"consent": gin.H{
|
||||
"categories": record.Categories,
|
||||
"vendors": record.Vendors,
|
||||
},
|
||||
"createdAt": record.CreatedAt.UTC().Format(time.RFC3339),
|
||||
"updatedAt": record.UpdatedAt.UTC().Format(time.RFC3339),
|
||||
"expiresAt": record.ExpiresAt.UTC().Format(time.RFC3339),
|
||||
"version": record.Version,
|
||||
})
|
||||
}
|
||||
|
||||
// RevokeBannerConsent widerruft einen Consent
|
||||
// DELETE /api/v1/banner/consent/:consentId
|
||||
func (h *Handler) RevokeBannerConsent(c *gin.Context) {
|
||||
consentID := c.Param("consentId")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := h.db.Pool.Exec(ctx, `
|
||||
UPDATE banner_consents
|
||||
SET revoked_at = NOW(), updated_at = NOW()
|
||||
WHERE id = $1 AND revoked_at IS NULL
|
||||
`, consentID)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "revoke_failed",
|
||||
"message": "Failed to revoke consent",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "consent_not_found",
|
||||
"message": "Consent not found or already revoked",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Audit-Log
|
||||
h.logBannerConsentAudit(ctx, consentID, "revoked", nil, anonymizeIP(c.ClientIP()))
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "revoked",
|
||||
"revokedAt": time.Now().UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
// GetSiteConfig gibt die Konfiguration für eine Site zurück
|
||||
// GET /api/v1/banner/config/:siteId
|
||||
func (h *Handler) GetSiteConfig(c *gin.Context) {
|
||||
siteID := c.Param("siteId")
|
||||
|
||||
// Standard-Kategorien (aus Datenbank oder Default)
|
||||
categories := []CategoryConfig{
|
||||
{
|
||||
ID: "essential",
|
||||
Name: map[string]string{
|
||||
"de": "Essentiell",
|
||||
"en": "Essential",
|
||||
},
|
||||
Description: map[string]string{
|
||||
"de": "Notwendig für die Grundfunktionen der Website.",
|
||||
"en": "Required for basic website functionality.",
|
||||
},
|
||||
Required: true,
|
||||
Vendors: []VendorConfig{},
|
||||
},
|
||||
{
|
||||
ID: "functional",
|
||||
Name: map[string]string{
|
||||
"de": "Funktional",
|
||||
"en": "Functional",
|
||||
},
|
||||
Description: map[string]string{
|
||||
"de": "Ermöglicht Personalisierung und Komfortfunktionen.",
|
||||
"en": "Enables personalization and comfort features.",
|
||||
},
|
||||
Required: false,
|
||||
Vendors: []VendorConfig{},
|
||||
},
|
||||
{
|
||||
ID: "analytics",
|
||||
Name: map[string]string{
|
||||
"de": "Statistik",
|
||||
"en": "Analytics",
|
||||
},
|
||||
Description: map[string]string{
|
||||
"de": "Hilft uns, die Website zu verbessern.",
|
||||
"en": "Helps us improve the website.",
|
||||
},
|
||||
Required: false,
|
||||
Vendors: []VendorConfig{},
|
||||
},
|
||||
{
|
||||
ID: "marketing",
|
||||
Name: map[string]string{
|
||||
"de": "Marketing",
|
||||
"en": "Marketing",
|
||||
},
|
||||
Description: map[string]string{
|
||||
"de": "Ermöglicht personalisierte Werbung.",
|
||||
"en": "Enables personalized advertising.",
|
||||
},
|
||||
Required: false,
|
||||
Vendors: []VendorConfig{},
|
||||
},
|
||||
{
|
||||
ID: "social",
|
||||
Name: map[string]string{
|
||||
"de": "Soziale Medien",
|
||||
"en": "Social Media",
|
||||
},
|
||||
Description: map[string]string{
|
||||
"de": "Ermöglicht Inhalte von sozialen Netzwerken.",
|
||||
"en": "Enables content from social networks.",
|
||||
},
|
||||
Required: false,
|
||||
Vendors: []VendorConfig{},
|
||||
},
|
||||
}
|
||||
|
||||
config := SiteConfig{
|
||||
SiteID: siteID,
|
||||
SiteName: "BreakPilot",
|
||||
Categories: categories,
|
||||
UI: UIConfig{
|
||||
Theme: "auto",
|
||||
Position: "bottom",
|
||||
},
|
||||
Legal: LegalConfig{
|
||||
PrivacyPolicyURL: "/datenschutz",
|
||||
ImprintURL: "/impressum",
|
||||
},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, config)
|
||||
}
|
||||
|
||||
// ExportBannerConsent exportiert alle Consent-Daten eines Nutzers (DSGVO Art. 20)
|
||||
// GET /api/v1/banner/consent/export?userId=xxx
|
||||
func (h *Handler) ExportBannerConsent(c *gin.Context) {
|
||||
userID := c.Query("userId")
|
||||
|
||||
if userID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "missing_user_id",
|
||||
"message": "userId parameter is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
rows, err := h.db.Pool.Query(ctx, `
|
||||
SELECT id, site_id, device_fingerprint, categories, vendors,
|
||||
version, created_at, updated_at, revoked_at
|
||||
FROM banner_consents
|
||||
WHERE user_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`, userID)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "export_failed",
|
||||
"message": "Failed to export consent data",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var consents []map[string]interface{}
|
||||
for rows.Next() {
|
||||
var id, siteID, deviceFingerprint, version string
|
||||
var categoriesJSON, vendorsJSON []byte
|
||||
var createdAt, updatedAt time.Time
|
||||
var revokedAt *time.Time
|
||||
|
||||
rows.Scan(&id, &siteID, &deviceFingerprint, &categoriesJSON, &vendorsJSON,
|
||||
&version, &createdAt, &updatedAt, &revokedAt)
|
||||
|
||||
var categories, vendors map[string]bool
|
||||
json.Unmarshal(categoriesJSON, &categories)
|
||||
json.Unmarshal(vendorsJSON, &vendors)
|
||||
|
||||
consent := map[string]interface{}{
|
||||
"consentId": id,
|
||||
"siteId": siteID,
|
||||
"consent": map[string]interface{}{
|
||||
"categories": categories,
|
||||
"vendors": vendors,
|
||||
},
|
||||
"createdAt": createdAt.UTC().Format(time.RFC3339),
|
||||
"revokedAt": nil,
|
||||
}
|
||||
|
||||
if revokedAt != nil {
|
||||
consent["revokedAt"] = revokedAt.UTC().Format(time.RFC3339)
|
||||
}
|
||||
|
||||
consents = append(consents, consent)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"userId": userID,
|
||||
"exportedAt": time.Now().UTC().Format(time.RFC3339),
|
||||
"consents": consents,
|
||||
})
|
||||
}
|
||||
|
||||
// GetBannerStats gibt anonymisierte Statistiken zurück (Admin)
|
||||
// GET /api/v1/banner/admin/stats/:siteId
|
||||
func (h *Handler) GetBannerStats(c *gin.Context) {
|
||||
siteID := c.Param("siteId")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Gesamtanzahl Consents
|
||||
var totalConsents int
|
||||
h.db.Pool.QueryRow(ctx, `
|
||||
SELECT COUNT(*) FROM banner_consents
|
||||
WHERE site_id = $1 AND revoked_at IS NULL
|
||||
`, siteID).Scan(&totalConsents)
|
||||
|
||||
// Consent-Rate pro Kategorie
|
||||
categoryStats := make(map[string]map[string]interface{})
|
||||
|
||||
rows, _ := h.db.Pool.Query(ctx, `
|
||||
SELECT
|
||||
key as category,
|
||||
COUNT(*) FILTER (WHERE value::text = 'true') as accepted,
|
||||
COUNT(*) as total
|
||||
FROM banner_consents,
|
||||
jsonb_each(categories::jsonb)
|
||||
WHERE site_id = $1 AND revoked_at IS NULL
|
||||
GROUP BY key
|
||||
`, siteID)
|
||||
|
||||
if rows != nil {
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var category string
|
||||
var accepted, total int
|
||||
rows.Scan(&category, &accepted, &total)
|
||||
|
||||
rate := float64(0)
|
||||
if total > 0 {
|
||||
rate = float64(accepted) / float64(total)
|
||||
}
|
||||
|
||||
categoryStats[category] = map[string]interface{}{
|
||||
"accepted": accepted,
|
||||
"rate": rate,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"siteId": siteID,
|
||||
"period": gin.H{
|
||||
"from": time.Now().AddDate(0, -1, 0).Format("2006-01-02"),
|
||||
"to": time.Now().Format("2006-01-02"),
|
||||
},
|
||||
"totalConsents": totalConsents,
|
||||
"consentByCategory": categoryStats,
|
||||
})
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Helper Functions
|
||||
// ========================================
|
||||
|
||||
// anonymizeIP anonymisiert eine IP-Adresse (DSGVO-konform)
|
||||
func anonymizeIP(ip string) string {
|
||||
// IPv4: Letztes Oktett auf 0
|
||||
parts := strings.Split(ip, ".")
|
||||
if len(parts) == 4 {
|
||||
parts[3] = "0"
|
||||
anonymized := strings.Join(parts, ".")
|
||||
hash := sha256.Sum256([]byte(anonymized))
|
||||
return hex.EncodeToString(hash[:])[:16]
|
||||
}
|
||||
|
||||
// IPv6: Hash
|
||||
hash := sha256.Sum256([]byte(ip))
|
||||
return hex.EncodeToString(hash[:])[:16]
|
||||
}
|
||||
|
||||
// logBannerConsentAudit schreibt einen Audit-Log-Eintrag
|
||||
func (h *Handler) logBannerConsentAudit(ctx context.Context, consentID, action string, req interface{}, ipHash string) {
|
||||
details, _ := json.Marshal(req)
|
||||
|
||||
h.db.Pool.Exec(ctx, `
|
||||
INSERT INTO banner_consent_audit_log (
|
||||
id, consent_id, action, details, ip_hash, created_at
|
||||
) VALUES ($1, $2, $3, $4, $5, NOW())
|
||||
`, uuid.New().String(), consentID, action, string(details), ipHash)
|
||||
}
|
||||
@@ -0,0 +1,511 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/services/jitsi"
|
||||
"github.com/breakpilot/consent-service/internal/services/matrix"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CommunicationHandlers handles Matrix and Jitsi API endpoints
|
||||
type CommunicationHandlers struct {
|
||||
matrixService *matrix.MatrixService
|
||||
jitsiService *jitsi.JitsiService
|
||||
}
|
||||
|
||||
// NewCommunicationHandlers creates new communication handlers
|
||||
func NewCommunicationHandlers(matrixSvc *matrix.MatrixService, jitsiSvc *jitsi.JitsiService) *CommunicationHandlers {
|
||||
return &CommunicationHandlers{
|
||||
matrixService: matrixSvc,
|
||||
jitsiService: jitsiSvc,
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Health & Status Endpoints
|
||||
// ========================================
|
||||
|
||||
// GetCommunicationStatus returns status of Matrix and Jitsi services
|
||||
func (h *CommunicationHandlers) GetCommunicationStatus(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
status := gin.H{
|
||||
"timestamp": time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
// Check Matrix
|
||||
if h.matrixService != nil {
|
||||
matrixErr := h.matrixService.HealthCheck(ctx)
|
||||
status["matrix"] = gin.H{
|
||||
"enabled": true,
|
||||
"healthy": matrixErr == nil,
|
||||
"server_name": h.matrixService.GetServerName(),
|
||||
"error": errToString(matrixErr),
|
||||
}
|
||||
} else {
|
||||
status["matrix"] = gin.H{
|
||||
"enabled": false,
|
||||
"healthy": false,
|
||||
}
|
||||
}
|
||||
|
||||
// Check Jitsi
|
||||
if h.jitsiService != nil {
|
||||
jitsiErr := h.jitsiService.HealthCheck(ctx)
|
||||
serverInfo := h.jitsiService.GetServerInfo()
|
||||
status["jitsi"] = gin.H{
|
||||
"enabled": true,
|
||||
"healthy": jitsiErr == nil,
|
||||
"base_url": serverInfo["base_url"],
|
||||
"auth_enabled": serverInfo["auth_enabled"],
|
||||
"error": errToString(jitsiErr),
|
||||
}
|
||||
} else {
|
||||
status["jitsi"] = gin.H{
|
||||
"enabled": false,
|
||||
"healthy": false,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, status)
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Matrix Room Endpoints
|
||||
// ========================================
|
||||
|
||||
// CreateRoomRequest for creating Matrix rooms
|
||||
type CreateRoomRequest struct {
|
||||
Type string `json:"type" binding:"required"` // "class_info", "student_dm", "parent_rep"
|
||||
ClassName string `json:"class_name"`
|
||||
SchoolName string `json:"school_name"`
|
||||
StudentName string `json:"student_name,omitempty"`
|
||||
TeacherIDs []string `json:"teacher_ids"`
|
||||
ParentIDs []string `json:"parent_ids,omitempty"`
|
||||
ParentRepIDs []string `json:"parent_rep_ids,omitempty"`
|
||||
}
|
||||
|
||||
// CreateRoom creates a new Matrix room based on type
|
||||
func (h *CommunicationHandlers) CreateRoom(c *gin.Context) {
|
||||
if h.matrixService == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Matrix service not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
var req CreateRoomRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
var resp *matrix.CreateRoomResponse
|
||||
var err error
|
||||
|
||||
switch req.Type {
|
||||
case "class_info":
|
||||
resp, err = h.matrixService.CreateClassInfoRoom(ctx, req.ClassName, req.SchoolName, req.TeacherIDs)
|
||||
case "student_dm":
|
||||
resp, err = h.matrixService.CreateStudentDMRoom(ctx, req.StudentName, req.ClassName, req.TeacherIDs, req.ParentIDs)
|
||||
case "parent_rep":
|
||||
resp, err = h.matrixService.CreateParentRepRoom(ctx, req.ClassName, req.TeacherIDs, req.ParentRepIDs)
|
||||
default:
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid room type. Use: class_info, student_dm, parent_rep"})
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"room_id": resp.RoomID,
|
||||
"type": req.Type,
|
||||
})
|
||||
}
|
||||
|
||||
// InviteUserRequest for inviting users to rooms
|
||||
type InviteUserRequest struct {
|
||||
RoomID string `json:"room_id" binding:"required"`
|
||||
UserID string `json:"user_id" binding:"required"`
|
||||
}
|
||||
|
||||
// InviteUser invites a user to a Matrix room
|
||||
func (h *CommunicationHandlers) InviteUser(c *gin.Context) {
|
||||
if h.matrixService == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Matrix service not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
var req InviteUserRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
if err := h.matrixService.InviteUser(ctx, req.RoomID, req.UserID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// SendMessageRequest for sending messages
|
||||
type SendMessageRequest struct {
|
||||
RoomID string `json:"room_id" binding:"required"`
|
||||
Message string `json:"message" binding:"required"`
|
||||
HTML string `json:"html,omitempty"`
|
||||
}
|
||||
|
||||
// SendMessage sends a message to a Matrix room
|
||||
func (h *CommunicationHandlers) SendMessage(c *gin.Context) {
|
||||
if h.matrixService == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Matrix service not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
var req SendMessageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
var err error
|
||||
|
||||
if req.HTML != "" {
|
||||
err = h.matrixService.SendHTMLMessage(ctx, req.RoomID, req.Message, req.HTML)
|
||||
} else {
|
||||
err = h.matrixService.SendMessage(ctx, req.RoomID, req.Message)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// SendNotificationRequest for sending school notifications
|
||||
type SendNotificationRequest struct {
|
||||
RoomID string `json:"room_id" binding:"required"`
|
||||
Type string `json:"type" binding:"required"` // "absence", "grade", "announcement"
|
||||
StudentName string `json:"student_name,omitempty"`
|
||||
Date string `json:"date,omitempty"`
|
||||
Lesson int `json:"lesson,omitempty"`
|
||||
Subject string `json:"subject,omitempty"`
|
||||
GradeType string `json:"grade_type,omitempty"`
|
||||
Grade float64 `json:"grade,omitempty"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
TeacherName string `json:"teacher_name,omitempty"`
|
||||
}
|
||||
|
||||
// SendNotification sends a typed notification (absence, grade, announcement)
|
||||
func (h *CommunicationHandlers) SendNotification(c *gin.Context) {
|
||||
if h.matrixService == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Matrix service not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
var req SendNotificationRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
var err error
|
||||
|
||||
switch req.Type {
|
||||
case "absence":
|
||||
err = h.matrixService.SendAbsenceNotification(ctx, req.RoomID, req.StudentName, req.Date, req.Lesson)
|
||||
case "grade":
|
||||
err = h.matrixService.SendGradeNotification(ctx, req.RoomID, req.StudentName, req.Subject, req.GradeType, req.Grade)
|
||||
case "announcement":
|
||||
err = h.matrixService.SendClassAnnouncement(ctx, req.RoomID, req.Title, req.Content, req.TeacherName)
|
||||
default:
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid notification type. Use: absence, grade, announcement"})
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// RegisterUserRequest for user registration
|
||||
type RegisterUserRequest struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
DisplayName string `json:"display_name"`
|
||||
}
|
||||
|
||||
// RegisterMatrixUser registers a new Matrix user
|
||||
func (h *CommunicationHandlers) RegisterMatrixUser(c *gin.Context) {
|
||||
if h.matrixService == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Matrix service not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
var req RegisterUserRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
resp, err := h.matrixService.RegisterUser(ctx, req.Username, req.DisplayName)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"user_id": resp.UserID,
|
||||
})
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Jitsi Video Conference Endpoints
|
||||
// ========================================
|
||||
|
||||
// CreateMeetingRequest for creating Jitsi meetings
|
||||
type CreateMeetingRequest struct {
|
||||
Type string `json:"type" binding:"required"` // "quick", "training", "parent_teacher", "class"
|
||||
Title string `json:"title,omitempty"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Duration int `json:"duration,omitempty"` // minutes
|
||||
ClassName string `json:"class_name,omitempty"`
|
||||
ParentName string `json:"parent_name,omitempty"`
|
||||
StudentName string `json:"student_name,omitempty"`
|
||||
Subject string `json:"subject,omitempty"`
|
||||
StartTime time.Time `json:"start_time,omitempty"`
|
||||
}
|
||||
|
||||
// CreateMeeting creates a new Jitsi meeting
|
||||
func (h *CommunicationHandlers) CreateMeeting(c *gin.Context) {
|
||||
if h.jitsiService == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Jitsi service not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
var req CreateMeetingRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
var link *jitsi.MeetingLink
|
||||
var err error
|
||||
|
||||
switch req.Type {
|
||||
case "quick":
|
||||
link, err = h.jitsiService.CreateQuickMeeting(ctx, req.DisplayName)
|
||||
case "training":
|
||||
link, err = h.jitsiService.CreateTrainingSession(ctx, req.Title, req.DisplayName, req.Email, req.Duration)
|
||||
case "parent_teacher":
|
||||
link, err = h.jitsiService.CreateParentTeacherMeeting(ctx, req.DisplayName, req.ParentName, req.StudentName, req.StartTime)
|
||||
case "class":
|
||||
link, err = h.jitsiService.CreateClassMeeting(ctx, req.ClassName, req.DisplayName, req.Subject)
|
||||
default:
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid meeting type. Use: quick, training, parent_teacher, class"})
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"room_name": link.RoomName,
|
||||
"url": link.URL,
|
||||
"join_url": link.JoinURL,
|
||||
"moderator_url": link.ModeratorURL,
|
||||
"password": link.Password,
|
||||
"expires_at": link.ExpiresAt,
|
||||
})
|
||||
}
|
||||
|
||||
// GetEmbedURLRequest for embedding Jitsi
|
||||
type GetEmbedURLRequest struct {
|
||||
RoomName string `json:"room_name" binding:"required"`
|
||||
DisplayName string `json:"display_name"`
|
||||
AudioMuted bool `json:"audio_muted"`
|
||||
VideoMuted bool `json:"video_muted"`
|
||||
}
|
||||
|
||||
// GetEmbedURL returns an embeddable Jitsi URL
|
||||
func (h *CommunicationHandlers) GetEmbedURL(c *gin.Context) {
|
||||
if h.jitsiService == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Jitsi service not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
var req GetEmbedURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
config := &jitsi.MeetingConfig{
|
||||
StartWithAudioMuted: req.AudioMuted,
|
||||
StartWithVideoMuted: req.VideoMuted,
|
||||
DisableDeepLinking: true,
|
||||
}
|
||||
|
||||
embedURL := h.jitsiService.BuildEmbedURL(req.RoomName, req.DisplayName, config)
|
||||
iframeCode := h.jitsiService.BuildIFrameCode(req.RoomName, 800, 600)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"embed_url": embedURL,
|
||||
"iframe_code": iframeCode,
|
||||
})
|
||||
}
|
||||
|
||||
// GetJitsiInfo returns Jitsi server information
|
||||
func (h *CommunicationHandlers) GetJitsiInfo(c *gin.Context) {
|
||||
if h.jitsiService == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Jitsi service not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
info := h.jitsiService.GetServerInfo()
|
||||
c.JSON(http.StatusOK, info)
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Admin Statistics Endpoints (for Admin Panel)
|
||||
// ========================================
|
||||
|
||||
// CommunicationStats holds communication service statistics
|
||||
type CommunicationStats struct {
|
||||
Matrix MatrixStats `json:"matrix"`
|
||||
Jitsi JitsiStats `json:"jitsi"`
|
||||
}
|
||||
|
||||
// MatrixStats holds Matrix-specific statistics
|
||||
type MatrixStats struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Healthy bool `json:"healthy"`
|
||||
ServerName string `json:"server_name"`
|
||||
// TODO: Add real stats from Matrix Synapse Admin API
|
||||
TotalUsers int `json:"total_users"`
|
||||
TotalRooms int `json:"total_rooms"`
|
||||
ActiveToday int `json:"active_today"`
|
||||
MessagesToday int `json:"messages_today"`
|
||||
}
|
||||
|
||||
// JitsiStats holds Jitsi-specific statistics
|
||||
type JitsiStats struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Healthy bool `json:"healthy"`
|
||||
BaseURL string `json:"base_url"`
|
||||
AuthEnabled bool `json:"auth_enabled"`
|
||||
// TODO: Add real stats from Jitsi SRTP API or Jicofo
|
||||
ActiveMeetings int `json:"active_meetings"`
|
||||
TotalParticipants int `json:"total_participants"`
|
||||
MeetingsToday int `json:"meetings_today"`
|
||||
AvgDurationMin int `json:"avg_duration_min"`
|
||||
}
|
||||
|
||||
// GetAdminStats returns admin statistics for Matrix and Jitsi
|
||||
func (h *CommunicationHandlers) GetAdminStats(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
stats := CommunicationStats{}
|
||||
|
||||
// Matrix Stats
|
||||
if h.matrixService != nil {
|
||||
matrixErr := h.matrixService.HealthCheck(ctx)
|
||||
stats.Matrix = MatrixStats{
|
||||
Enabled: true,
|
||||
Healthy: matrixErr == nil,
|
||||
ServerName: h.matrixService.GetServerName(),
|
||||
// Placeholder stats - in production these would come from Synapse Admin API
|
||||
TotalUsers: 0,
|
||||
TotalRooms: 0,
|
||||
ActiveToday: 0,
|
||||
MessagesToday: 0,
|
||||
}
|
||||
} else {
|
||||
stats.Matrix = MatrixStats{Enabled: false}
|
||||
}
|
||||
|
||||
// Jitsi Stats
|
||||
if h.jitsiService != nil {
|
||||
jitsiErr := h.jitsiService.HealthCheck(ctx)
|
||||
serverInfo := h.jitsiService.GetServerInfo()
|
||||
stats.Jitsi = JitsiStats{
|
||||
Enabled: true,
|
||||
Healthy: jitsiErr == nil,
|
||||
BaseURL: serverInfo["base_url"],
|
||||
AuthEnabled: serverInfo["auth_enabled"] == "true",
|
||||
// Placeholder stats - in production these would come from Jicofo/JVB stats
|
||||
ActiveMeetings: 0,
|
||||
TotalParticipants: 0,
|
||||
MeetingsToday: 0,
|
||||
AvgDurationMin: 0,
|
||||
}
|
||||
} else {
|
||||
stats.Jitsi = JitsiStats{Enabled: false}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Helper Functions
|
||||
// ========================================
|
||||
|
||||
func errToString(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
// RegisterRoutes registers all communication routes
|
||||
func (h *CommunicationHandlers) RegisterRoutes(router *gin.RouterGroup, jwtSecret string, authMiddleware gin.HandlerFunc) {
|
||||
comm := router.Group("/communication")
|
||||
{
|
||||
// Public health check
|
||||
comm.GET("/status", h.GetCommunicationStatus)
|
||||
|
||||
// Protected routes
|
||||
protected := comm.Group("")
|
||||
protected.Use(authMiddleware)
|
||||
{
|
||||
// Matrix
|
||||
protected.POST("/rooms", h.CreateRoom)
|
||||
protected.POST("/rooms/invite", h.InviteUser)
|
||||
protected.POST("/messages", h.SendMessage)
|
||||
protected.POST("/notifications", h.SendNotification)
|
||||
|
||||
// Jitsi
|
||||
protected.POST("/meetings", h.CreateMeeting)
|
||||
protected.POST("/meetings/embed", h.GetEmbedURL)
|
||||
protected.GET("/jitsi/info", h.GetJitsiInfo)
|
||||
}
|
||||
|
||||
// Admin routes (for Matrix user registration and stats)
|
||||
admin := comm.Group("/admin")
|
||||
admin.Use(authMiddleware)
|
||||
// TODO: Add AdminOnly middleware
|
||||
{
|
||||
admin.POST("/matrix/users", h.RegisterMatrixUser)
|
||||
admin.GET("/stats", h.GetAdminStats)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,407 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// TestGetCommunicationStatus_NoServices tests status with no services configured
|
||||
func TestGetCommunicationStatus_NoServices_ReturnsDisabled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Create handler with no services
|
||||
handler := NewCommunicationHandlers(nil, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.GET("/api/v1/communication/status", handler.GetCommunicationStatus)
|
||||
|
||||
req, _ := http.NewRequest("GET", "/api/v1/communication/status", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
// Check matrix is disabled
|
||||
matrix, ok := response["matrix"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Expected matrix in response")
|
||||
}
|
||||
if matrix["enabled"] != false {
|
||||
t.Error("Expected matrix.enabled to be false")
|
||||
}
|
||||
|
||||
// Check jitsi is disabled
|
||||
jitsi, ok := response["jitsi"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Expected jitsi in response")
|
||||
}
|
||||
if jitsi["enabled"] != false {
|
||||
t.Error("Expected jitsi.enabled to be false")
|
||||
}
|
||||
|
||||
// Check timestamp exists
|
||||
if _, ok := response["timestamp"]; !ok {
|
||||
t.Error("Expected timestamp in response")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCreateRoom_NoMatrixService tests room creation without Matrix
|
||||
func TestCreateRoom_NoMatrixService_Returns503(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
handler := NewCommunicationHandlers(nil, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.POST("/api/v1/communication/rooms", handler.CreateRoom)
|
||||
|
||||
body := `{"type": "class_info", "class_name": "5b"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/communication/rooms", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected status 503, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if response["error"] != "Matrix service not configured" {
|
||||
t.Errorf("Unexpected error message: %s", response["error"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestCreateRoom_InvalidBody tests room creation with invalid body
|
||||
func TestCreateRoom_InvalidBody_Returns400(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
handler := NewCommunicationHandlers(nil, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.POST("/api/v1/communication/rooms", handler.CreateRoom)
|
||||
|
||||
req, _ := http.NewRequest("POST", "/api/v1/communication/rooms", bytes.NewBufferString("{invalid"))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Service unavailable check happens first, so we get 503
|
||||
// This is expected behavior - service check before body validation
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected status 503, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestInviteUser_NoMatrixService tests invite without Matrix
|
||||
func TestInviteUser_NoMatrixService_Returns503(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
handler := NewCommunicationHandlers(nil, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.POST("/api/v1/communication/rooms/invite", handler.InviteUser)
|
||||
|
||||
body := `{"room_id": "!abc:server", "user_id": "@user:server"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/communication/rooms/invite", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected status 503, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSendMessage_NoMatrixService tests message sending without Matrix
|
||||
func TestSendMessage_NoMatrixService_Returns503(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
handler := NewCommunicationHandlers(nil, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.POST("/api/v1/communication/messages", handler.SendMessage)
|
||||
|
||||
body := `{"room_id": "!abc:server", "message": "Hello"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/communication/messages", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected status 503, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSendNotification_NoMatrixService tests notification without Matrix
|
||||
func TestSendNotification_NoMatrixService_Returns503(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
handler := NewCommunicationHandlers(nil, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.POST("/api/v1/communication/notifications", handler.SendNotification)
|
||||
|
||||
body := `{"room_id": "!abc:server", "type": "absence", "student_name": "Max"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/communication/notifications", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected status 503, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCreateMeeting_NoJitsiService tests meeting creation without Jitsi
|
||||
func TestCreateMeeting_NoJitsiService_Returns503(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
handler := NewCommunicationHandlers(nil, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.POST("/api/v1/communication/meetings", handler.CreateMeeting)
|
||||
|
||||
body := `{"type": "quick", "display_name": "Teacher"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/communication/meetings", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected status 503, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if response["error"] != "Jitsi service not configured" {
|
||||
t.Errorf("Unexpected error message: %s", response["error"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetEmbedURL_NoJitsiService tests embed URL without Jitsi
|
||||
func TestGetEmbedURL_NoJitsiService_Returns503(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
handler := NewCommunicationHandlers(nil, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.POST("/api/v1/communication/meetings/embed", handler.GetEmbedURL)
|
||||
|
||||
body := `{"room_name": "test-room", "display_name": "User"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/communication/meetings/embed", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected status 503, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetJitsiInfo_NoJitsiService tests Jitsi info without service
|
||||
func TestGetJitsiInfo_NoJitsiService_Returns503(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
handler := NewCommunicationHandlers(nil, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.GET("/api/v1/communication/jitsi/info", handler.GetJitsiInfo)
|
||||
|
||||
req, _ := http.NewRequest("GET", "/api/v1/communication/jitsi/info", nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected status 503, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegisterMatrixUser_NoMatrixService tests user registration without Matrix
|
||||
func TestRegisterMatrixUser_NoMatrixService_Returns503(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
handler := NewCommunicationHandlers(nil, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.POST("/api/v1/communication/admin/matrix/users", handler.RegisterMatrixUser)
|
||||
|
||||
body := `{"username": "testuser", "display_name": "Test User"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/communication/admin/matrix/users", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected status 503, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetAdminStats_NoServices tests admin stats without services
|
||||
func TestGetAdminStats_NoServices_ReturnsDisabledStats(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
handler := NewCommunicationHandlers(nil, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.GET("/api/v1/communication/admin/stats", handler.GetAdminStats)
|
||||
|
||||
req, _ := http.NewRequest("GET", "/api/v1/communication/admin/stats", nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response CommunicationStats
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if response.Matrix.Enabled {
|
||||
t.Error("Expected matrix.enabled to be false")
|
||||
}
|
||||
|
||||
if response.Jitsi.Enabled {
|
||||
t.Error("Expected jitsi.enabled to be false")
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrToString tests the helper function
|
||||
func TestErrToString_NilError_ReturnsEmpty(t *testing.T) {
|
||||
result := errToString(nil)
|
||||
if result != "" {
|
||||
t.Errorf("Expected empty string, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrToString_WithError_ReturnsMessage tests error string conversion
|
||||
func TestErrToString_WithError_ReturnsMessage(t *testing.T) {
|
||||
err := &testError{"test error message"}
|
||||
result := errToString(err)
|
||||
if result != "test error message" {
|
||||
t.Errorf("Expected 'test error message', got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// testError is a simple error implementation for testing
|
||||
type testError struct {
|
||||
message string
|
||||
}
|
||||
|
||||
func (e *testError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
// TestCreateRoomRequest_Types tests different room types validation
|
||||
func TestCreateRoom_InvalidType_Returns400(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Since we don't have Matrix service, we get 503 first
|
||||
// This test documents expected behavior when Matrix IS available
|
||||
handler := NewCommunicationHandlers(nil, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.POST("/api/v1/communication/rooms", handler.CreateRoom)
|
||||
|
||||
body := `{"type": "invalid_type", "class_name": "5b"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/communication/rooms", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Without Matrix service, we get 503 before type validation
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected status 503, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCreateMeeting_InvalidType tests invalid meeting type
|
||||
func TestCreateMeeting_InvalidType_Returns503(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
handler := NewCommunicationHandlers(nil, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.POST("/api/v1/communication/meetings", handler.CreateMeeting)
|
||||
|
||||
body := `{"type": "invalid", "display_name": "User"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/communication/meetings", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Without Jitsi service, we get 503 before type validation
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected status 503, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSendNotification_InvalidType tests invalid notification type
|
||||
func TestSendNotification_InvalidType_Returns503(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
handler := NewCommunicationHandlers(nil, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.POST("/api/v1/communication/notifications", handler.SendNotification)
|
||||
|
||||
body := `{"room_id": "!abc:server", "type": "invalid", "student_name": "Max"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/communication/notifications", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Without Matrix service, we get 503 before type validation
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected status 503, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewCommunicationHandlers tests constructor
|
||||
func TestNewCommunicationHandlers_WithNilServices_CreatesHandler(t *testing.T) {
|
||||
handler := NewCommunicationHandlers(nil, nil)
|
||||
|
||||
if handler == nil {
|
||||
t.Fatal("Expected handler to be created")
|
||||
}
|
||||
|
||||
if handler.matrixService != nil {
|
||||
t.Error("Expected matrixService to be nil")
|
||||
}
|
||||
|
||||
if handler.jitsiService != nil {
|
||||
t.Error("Expected jitsiService to be nil")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/middleware"
|
||||
"github.com/breakpilot/consent-service/internal/services"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// DeadlineHandler handles deadline-related requests
|
||||
type DeadlineHandler struct {
|
||||
deadlineService *services.DeadlineService
|
||||
}
|
||||
|
||||
// NewDeadlineHandler creates a new deadline handler
|
||||
func NewDeadlineHandler(deadlineService *services.DeadlineService) *DeadlineHandler {
|
||||
return &DeadlineHandler{
|
||||
deadlineService: deadlineService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetPendingDeadlines returns all pending consent deadlines for the current user
|
||||
func (h *DeadlineHandler) GetPendingDeadlines(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID == uuid.Nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
|
||||
return
|
||||
}
|
||||
|
||||
deadlines, err := h.deadlineService.GetPendingDeadlines(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch deadlines"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"deadlines": deadlines,
|
||||
"count": len(deadlines),
|
||||
})
|
||||
}
|
||||
|
||||
// GetSuspensionStatus returns the current suspension status for a user
|
||||
func (h *DeadlineHandler) GetSuspensionStatus(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID == uuid.Nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
|
||||
return
|
||||
}
|
||||
|
||||
suspended, err := h.deadlineService.IsUserSuspended(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check suspension status"})
|
||||
return
|
||||
}
|
||||
|
||||
response := gin.H{
|
||||
"suspended": suspended,
|
||||
}
|
||||
|
||||
if suspended {
|
||||
suspension, err := h.deadlineService.GetAccountSuspension(c.Request.Context(), userID)
|
||||
if err == nil && suspension != nil {
|
||||
response["reason"] = suspension.Reason
|
||||
response["suspended_at"] = suspension.SuspendedAt
|
||||
response["details"] = suspension.Details
|
||||
}
|
||||
|
||||
deadlines, err := h.deadlineService.GetPendingDeadlines(c.Request.Context(), userID)
|
||||
if err == nil {
|
||||
response["pending_deadlines"] = deadlines
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// TriggerDeadlineProcessing manually triggers deadline processing (admin only)
|
||||
func (h *DeadlineHandler) TriggerDeadlineProcessing(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin access required"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.deadlineService.ProcessDailyDeadlines(c.Request.Context()); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to process deadlines"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Deadline processing completed"})
|
||||
}
|
||||
@@ -0,0 +1,948 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/middleware"
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
"github.com/breakpilot/consent-service/internal/services"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// DSRHandler handles Data Subject Request HTTP endpoints
|
||||
type DSRHandler struct {
|
||||
dsrService *services.DSRService
|
||||
}
|
||||
|
||||
// NewDSRHandler creates a new DSR handler
|
||||
func NewDSRHandler(dsrService *services.DSRService) *DSRHandler {
|
||||
return &DSRHandler{
|
||||
dsrService: dsrService,
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// USER ENDPOINTS
|
||||
// ========================================
|
||||
|
||||
// CreateDSR creates a new data subject request (user-facing)
|
||||
func (h *DSRHandler) CreateDSR(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID == uuid.Nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
|
||||
return
|
||||
}
|
||||
|
||||
var req models.CreateDSRRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user email if not provided
|
||||
if req.RequesterEmail == "" {
|
||||
var email string
|
||||
ctx := context.Background()
|
||||
h.dsrService.GetPool().QueryRow(ctx, "SELECT email FROM users WHERE id = $1", userID).Scan(&email)
|
||||
req.RequesterEmail = email
|
||||
}
|
||||
|
||||
// Set source as API
|
||||
req.Source = "api"
|
||||
|
||||
dsr, err := h.dsrService.CreateRequest(c.Request.Context(), req, &userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"message": "Ihre Anfrage wurde erfolgreich eingereicht",
|
||||
"request_number": dsr.RequestNumber,
|
||||
"dsr": dsr,
|
||||
})
|
||||
}
|
||||
|
||||
// GetMyDSRs returns DSRs for the current user
|
||||
func (h *DSRHandler) GetMyDSRs(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID == uuid.Nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
|
||||
return
|
||||
}
|
||||
|
||||
dsrs, err := h.dsrService.ListByUser(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch requests"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"requests": dsrs})
|
||||
}
|
||||
|
||||
// GetMyDSR returns a specific DSR for the current user
|
||||
func (h *DSRHandler) GetMyDSR(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID == uuid.Nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
|
||||
return
|
||||
}
|
||||
|
||||
dsrID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
dsr, err := h.dsrService.GetByID(c.Request.Context(), dsrID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Request not found"})
|
||||
return
|
||||
}
|
||||
|
||||
// Verify ownership
|
||||
if dsr.UserID == nil || *dsr.UserID != userID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Access denied"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, dsr)
|
||||
}
|
||||
|
||||
// CancelMyDSR cancels a user's own DSR
|
||||
func (h *DSRHandler) CancelMyDSR(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID == uuid.Nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
|
||||
return
|
||||
}
|
||||
|
||||
dsrID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
err = h.dsrService.CancelRequest(c.Request.Context(), dsrID, userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Anfrage wurde storniert"})
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// ADMIN ENDPOINTS
|
||||
// ========================================
|
||||
|
||||
// AdminListDSR returns all DSRs with filters (admin only)
|
||||
func (h *DSRHandler) AdminListDSR(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse pagination
|
||||
limit := 20
|
||||
offset := 0
|
||||
if l := c.Query("limit"); l != "" {
|
||||
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
|
||||
limit = parsed
|
||||
}
|
||||
}
|
||||
if o := c.Query("offset"); o != "" {
|
||||
if parsed, err := strconv.Atoi(o); err == nil && parsed >= 0 {
|
||||
offset = parsed
|
||||
}
|
||||
}
|
||||
|
||||
// Parse filters
|
||||
filters := models.DSRListFilters{}
|
||||
if status := c.Query("status"); status != "" {
|
||||
filters.Status = &status
|
||||
}
|
||||
if reqType := c.Query("request_type"); reqType != "" {
|
||||
filters.RequestType = &reqType
|
||||
}
|
||||
if assignedTo := c.Query("assigned_to"); assignedTo != "" {
|
||||
filters.AssignedTo = &assignedTo
|
||||
}
|
||||
if priority := c.Query("priority"); priority != "" {
|
||||
filters.Priority = &priority
|
||||
}
|
||||
if c.Query("overdue_only") == "true" {
|
||||
filters.OverdueOnly = true
|
||||
}
|
||||
if search := c.Query("search"); search != "" {
|
||||
filters.Search = &search
|
||||
}
|
||||
if from := c.Query("from_date"); from != "" {
|
||||
if t, err := time.Parse("2006-01-02", from); err == nil {
|
||||
filters.FromDate = &t
|
||||
}
|
||||
}
|
||||
if to := c.Query("to_date"); to != "" {
|
||||
if t, err := time.Parse("2006-01-02", to); err == nil {
|
||||
filters.ToDate = &t
|
||||
}
|
||||
}
|
||||
|
||||
dsrs, total, err := h.dsrService.List(c.Request.Context(), filters, limit, offset)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch requests"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"requests": dsrs,
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
}
|
||||
|
||||
// AdminGetDSR returns a specific DSR (admin only)
|
||||
func (h *DSRHandler) AdminGetDSR(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
|
||||
return
|
||||
}
|
||||
|
||||
dsrID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
dsr, err := h.dsrService.GetByID(c.Request.Context(), dsrID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Request not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, dsr)
|
||||
}
|
||||
|
||||
// AdminCreateDSR creates a DSR manually (admin only)
|
||||
func (h *DSRHandler) AdminCreateDSR(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
|
||||
return
|
||||
}
|
||||
|
||||
userID, _ := middleware.GetUserID(c)
|
||||
|
||||
var req models.CreateDSRRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Set source as admin_panel
|
||||
if req.Source == "" {
|
||||
req.Source = "admin_panel"
|
||||
}
|
||||
|
||||
dsr, err := h.dsrService.CreateRequest(c.Request.Context(), req, &userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"message": "Anfrage wurde erstellt",
|
||||
"request_number": dsr.RequestNumber,
|
||||
"dsr": dsr,
|
||||
})
|
||||
}
|
||||
|
||||
// AdminUpdateDSR updates a DSR (admin only)
|
||||
func (h *DSRHandler) AdminUpdateDSR(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
|
||||
return
|
||||
}
|
||||
|
||||
dsrID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
userID, _ := middleware.GetUserID(c)
|
||||
|
||||
var req models.UpdateDSRRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Update status if provided
|
||||
if req.Status != nil {
|
||||
err = h.dsrService.UpdateStatus(ctx, dsrID, models.DSRStatus(*req.Status), "", &userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Update processing notes
|
||||
if req.ProcessingNotes != nil {
|
||||
h.dsrService.GetPool().Exec(ctx, `
|
||||
UPDATE data_subject_requests SET processing_notes = $1, updated_at = NOW() WHERE id = $2
|
||||
`, *req.ProcessingNotes, dsrID)
|
||||
}
|
||||
|
||||
// Update priority
|
||||
if req.Priority != nil {
|
||||
h.dsrService.GetPool().Exec(ctx, `
|
||||
UPDATE data_subject_requests SET priority = $1, updated_at = NOW() WHERE id = $2
|
||||
`, *req.Priority, dsrID)
|
||||
}
|
||||
|
||||
// Get updated DSR
|
||||
dsr, _ := h.dsrService.GetByID(ctx, dsrID)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "Anfrage wurde aktualisiert",
|
||||
"dsr": dsr,
|
||||
})
|
||||
}
|
||||
|
||||
// AdminGetDSRStats returns dashboard statistics
|
||||
func (h *DSRHandler) AdminGetDSRStats(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dsrService.GetDashboardStats(c.Request.Context())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch statistics"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// AdminVerifyIdentity verifies the identity of a requester
|
||||
func (h *DSRHandler) AdminVerifyIdentity(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
|
||||
return
|
||||
}
|
||||
|
||||
dsrID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
userID, _ := middleware.GetUserID(c)
|
||||
|
||||
var req models.VerifyDSRIdentityRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
err = h.dsrService.VerifyIdentity(c.Request.Context(), dsrID, req.Method, userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Identität wurde verifiziert"})
|
||||
}
|
||||
|
||||
// AdminAssignDSR assigns a DSR to a user
|
||||
func (h *DSRHandler) AdminAssignDSR(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
|
||||
return
|
||||
}
|
||||
|
||||
dsrID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
userID, _ := middleware.GetUserID(c)
|
||||
|
||||
var req struct {
|
||||
AssigneeID string `json:"assignee_id" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
assigneeID, err := uuid.Parse(req.AssigneeID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid assignee ID"})
|
||||
return
|
||||
}
|
||||
|
||||
err = h.dsrService.AssignRequest(c.Request.Context(), dsrID, assigneeID, userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Anfrage wurde zugewiesen"})
|
||||
}
|
||||
|
||||
// AdminExtendDSRDeadline extends the deadline for a DSR
|
||||
func (h *DSRHandler) AdminExtendDSRDeadline(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
|
||||
return
|
||||
}
|
||||
|
||||
dsrID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
userID, _ := middleware.GetUserID(c)
|
||||
|
||||
var req models.ExtendDSRDeadlineRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
err = h.dsrService.ExtendDeadline(c.Request.Context(), dsrID, req.Reason, req.Days, userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Frist wurde verlängert"})
|
||||
}
|
||||
|
||||
// AdminCompleteDSR marks a DSR as completed
|
||||
func (h *DSRHandler) AdminCompleteDSR(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
|
||||
return
|
||||
}
|
||||
|
||||
dsrID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
userID, _ := middleware.GetUserID(c)
|
||||
|
||||
var req models.CompleteDSRRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
err = h.dsrService.CompleteRequest(c.Request.Context(), dsrID, req.ResultSummary, req.ResultData, userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Anfrage wurde abgeschlossen"})
|
||||
}
|
||||
|
||||
// AdminRejectDSR rejects a DSR
|
||||
func (h *DSRHandler) AdminRejectDSR(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
|
||||
return
|
||||
}
|
||||
|
||||
dsrID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
userID, _ := middleware.GetUserID(c)
|
||||
|
||||
var req models.RejectDSRRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
err = h.dsrService.RejectRequest(c.Request.Context(), dsrID, req.Reason, req.LegalBasis, userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Anfrage wurde abgelehnt"})
|
||||
}
|
||||
|
||||
// AdminGetDSRHistory returns the status history for a DSR
|
||||
func (h *DSRHandler) AdminGetDSRHistory(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
|
||||
return
|
||||
}
|
||||
|
||||
dsrID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
history, err := h.dsrService.GetStatusHistory(c.Request.Context(), dsrID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch history"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"history": history})
|
||||
}
|
||||
|
||||
// AdminGetDSRCommunications returns communications for a DSR
|
||||
func (h *DSRHandler) AdminGetDSRCommunications(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
|
||||
return
|
||||
}
|
||||
|
||||
dsrID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
comms, err := h.dsrService.GetCommunications(c.Request.Context(), dsrID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch communications"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"communications": comms})
|
||||
}
|
||||
|
||||
// AdminSendDSRCommunication sends a communication for a DSR
|
||||
func (h *DSRHandler) AdminSendDSRCommunication(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
|
||||
return
|
||||
}
|
||||
|
||||
dsrID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
userID, _ := middleware.GetUserID(c)
|
||||
|
||||
var req models.SendDSRCommunicationRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
err = h.dsrService.SendCommunication(c.Request.Context(), dsrID, req, userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Kommunikation wurde gesendet"})
|
||||
}
|
||||
|
||||
// AdminUpdateDSRStatus updates the status of a DSR
|
||||
func (h *DSRHandler) AdminUpdateDSRStatus(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
|
||||
return
|
||||
}
|
||||
|
||||
dsrID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
userID, _ := middleware.GetUserID(c)
|
||||
|
||||
var req struct {
|
||||
Status string `json:"status" binding:"required"`
|
||||
Comment string `json:"comment"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
err = h.dsrService.UpdateStatus(c.Request.Context(), dsrID, models.DSRStatus(req.Status), req.Comment, &userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Status wurde aktualisiert"})
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// EXCEPTION CHECKS (Art. 17)
|
||||
// ========================================
|
||||
|
||||
// AdminGetExceptionChecks returns exception checks for an erasure DSR
|
||||
func (h *DSRHandler) AdminGetExceptionChecks(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
|
||||
return
|
||||
}
|
||||
|
||||
dsrID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
checks, err := h.dsrService.GetExceptionChecks(c.Request.Context(), dsrID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch exception checks"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"exception_checks": checks})
|
||||
}
|
||||
|
||||
// AdminInitExceptionChecks initializes exception checks for an erasure DSR
|
||||
func (h *DSRHandler) AdminInitExceptionChecks(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
|
||||
return
|
||||
}
|
||||
|
||||
dsrID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
err = h.dsrService.InitErasureExceptionChecks(c.Request.Context(), dsrID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to initialize exception checks"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Ausnahmeprüfungen wurden initialisiert"})
|
||||
}
|
||||
|
||||
// AdminUpdateExceptionCheck updates a single exception check
|
||||
func (h *DSRHandler) AdminUpdateExceptionCheck(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
|
||||
return
|
||||
}
|
||||
|
||||
checkID, err := uuid.Parse(c.Param("checkId"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid check ID"})
|
||||
return
|
||||
}
|
||||
|
||||
userID, _ := middleware.GetUserID(c)
|
||||
|
||||
var req struct {
|
||||
Applies bool `json:"applies"`
|
||||
Notes *string `json:"notes"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
err = h.dsrService.UpdateExceptionCheck(c.Request.Context(), checkID, req.Applies, req.Notes, userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update exception check"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Ausnahmeprüfung wurde aktualisiert"})
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// TEMPLATE ENDPOINTS
|
||||
// ========================================
|
||||
|
||||
// AdminGetDSRTemplates returns all DSR templates
|
||||
func (h *DSRHandler) AdminGetDSRTemplates(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
rows, err := h.dsrService.GetPool().Query(ctx, `
|
||||
SELECT id, template_type, name, description, request_types, is_active, sort_order, created_at, updated_at
|
||||
FROM dsr_templates ORDER BY sort_order, name
|
||||
`)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch templates"})
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var templates []map[string]interface{}
|
||||
for rows.Next() {
|
||||
var id uuid.UUID
|
||||
var templateType, name string
|
||||
var description *string
|
||||
var requestTypes []byte
|
||||
var isActive bool
|
||||
var sortOrder int
|
||||
var createdAt, updatedAt time.Time
|
||||
|
||||
err := rows.Scan(&id, &templateType, &name, &description, &requestTypes, &isActive, &sortOrder, &createdAt, &updatedAt)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
templates = append(templates, map[string]interface{}{
|
||||
"id": id,
|
||||
"template_type": templateType,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"request_types": string(requestTypes),
|
||||
"is_active": isActive,
|
||||
"sort_order": sortOrder,
|
||||
"created_at": createdAt,
|
||||
"updated_at": updatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"templates": templates})
|
||||
}
|
||||
|
||||
// AdminGetDSRTemplateVersions returns versions for a template
|
||||
func (h *DSRHandler) AdminGetDSRTemplateVersions(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
|
||||
return
|
||||
}
|
||||
|
||||
templateID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid template ID"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
rows, err := h.dsrService.GetPool().Query(ctx, `
|
||||
SELECT id, template_id, version, language, subject, body_html, body_text,
|
||||
status, published_at, created_by, approved_by, approved_at, created_at, updated_at
|
||||
FROM dsr_template_versions WHERE template_id = $1 ORDER BY created_at DESC
|
||||
`, templateID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch versions"})
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var versions []map[string]interface{}
|
||||
for rows.Next() {
|
||||
var id, tempID uuid.UUID
|
||||
var version, language, subject, bodyHTML, bodyText, status string
|
||||
var publishedAt, approvedAt *time.Time
|
||||
var createdBy, approvedBy *uuid.UUID
|
||||
var createdAt, updatedAt time.Time
|
||||
|
||||
err := rows.Scan(&id, &tempID, &version, &language, &subject, &bodyHTML, &bodyText,
|
||||
&status, &publishedAt, &createdBy, &approvedBy, &approvedAt, &createdAt, &updatedAt)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
versions = append(versions, map[string]interface{}{
|
||||
"id": id,
|
||||
"template_id": tempID,
|
||||
"version": version,
|
||||
"language": language,
|
||||
"subject": subject,
|
||||
"body_html": bodyHTML,
|
||||
"body_text": bodyText,
|
||||
"status": status,
|
||||
"published_at": publishedAt,
|
||||
"created_by": createdBy,
|
||||
"approved_by": approvedBy,
|
||||
"approved_at": approvedAt,
|
||||
"created_at": createdAt,
|
||||
"updated_at": updatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"versions": versions})
|
||||
}
|
||||
|
||||
// AdminCreateDSRTemplateVersion creates a new template version
|
||||
func (h *DSRHandler) AdminCreateDSRTemplateVersion(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
|
||||
return
|
||||
}
|
||||
|
||||
templateID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid template ID"})
|
||||
return
|
||||
}
|
||||
|
||||
userID, _ := middleware.GetUserID(c)
|
||||
|
||||
var req struct {
|
||||
Version string `json:"version" binding:"required"`
|
||||
Language string `json:"language"`
|
||||
Subject string `json:"subject" binding:"required"`
|
||||
BodyHTML string `json:"body_html" binding:"required"`
|
||||
BodyText string `json:"body_text"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Language == "" {
|
||||
req.Language = "de"
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
var versionID uuid.UUID
|
||||
err = h.dsrService.GetPool().QueryRow(ctx, `
|
||||
INSERT INTO dsr_template_versions (template_id, version, language, subject, body_html, body_text, created_by)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING id
|
||||
`, templateID, req.Version, req.Language, req.Subject, req.BodyHTML, req.BodyText, userID).Scan(&versionID)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create version"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"message": "Version wurde erstellt",
|
||||
"id": versionID,
|
||||
})
|
||||
}
|
||||
|
||||
// AdminPublishDSRTemplateVersion publishes a template version
|
||||
func (h *DSRHandler) AdminPublishDSRTemplateVersion(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
|
||||
return
|
||||
}
|
||||
|
||||
versionID, err := uuid.Parse(c.Param("versionId"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid version ID"})
|
||||
return
|
||||
}
|
||||
|
||||
userID, _ := middleware.GetUserID(c)
|
||||
|
||||
ctx := c.Request.Context()
|
||||
_, err = h.dsrService.GetPool().Exec(ctx, `
|
||||
UPDATE dsr_template_versions
|
||||
SET status = 'published', published_at = NOW(), approved_by = $1, approved_at = NOW(), updated_at = NOW()
|
||||
WHERE id = $2
|
||||
`, userID, versionID)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to publish version"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Version wurde veröffentlicht"})
|
||||
}
|
||||
|
||||
// AdminGetPublishedDSRTemplates returns all published templates for selection
|
||||
func (h *DSRHandler) AdminGetPublishedDSRTemplates(c *gin.Context) {
|
||||
if !middleware.IsAdmin(c) && !middleware.IsDSB(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin or DSB access required"})
|
||||
return
|
||||
}
|
||||
|
||||
requestType := c.Query("request_type")
|
||||
language := c.DefaultQuery("language", "de")
|
||||
|
||||
ctx := c.Request.Context()
|
||||
query := `
|
||||
SELECT t.id, t.template_type, t.name, t.description,
|
||||
v.id as version_id, v.version, v.subject, v.body_html, v.body_text
|
||||
FROM dsr_templates t
|
||||
JOIN dsr_template_versions v ON t.id = v.template_id
|
||||
WHERE t.is_active = TRUE AND v.status = 'published' AND v.language = $1
|
||||
`
|
||||
args := []interface{}{language}
|
||||
|
||||
if requestType != "" {
|
||||
query += ` AND t.request_types @> $2::jsonb`
|
||||
args = append(args, `["`+requestType+`"]`)
|
||||
}
|
||||
|
||||
query += " ORDER BY t.sort_order, t.name"
|
||||
|
||||
rows, err := h.dsrService.GetPool().Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch templates"})
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var templates []map[string]interface{}
|
||||
for rows.Next() {
|
||||
var templateID, versionID uuid.UUID
|
||||
var templateType, name, version, subject, bodyHTML, bodyText string
|
||||
var description *string
|
||||
|
||||
err := rows.Scan(&templateID, &templateType, &name, &description, &versionID, &version, &subject, &bodyHTML, &bodyText)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
templates = append(templates, map[string]interface{}{
|
||||
"template_id": templateID,
|
||||
"template_type": templateType,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"version_id": versionID,
|
||||
"version": version,
|
||||
"subject": subject,
|
||||
"body_html": bodyHTML,
|
||||
"body_text": bodyText,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"templates": templates})
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// DEADLINE PROCESSING
|
||||
// ========================================
|
||||
|
||||
// ProcessDeadlines triggers deadline checking (called by scheduler)
|
||||
func (h *DSRHandler) ProcessDeadlines(c *gin.Context) {
|
||||
err := h.dsrService.ProcessDeadlines(c.Request.Context())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to process deadlines"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Deadline processing completed"})
|
||||
}
|
||||
@@ -0,0 +1,448 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// TestCreateDSR_InvalidBody tests create DSR with invalid body
|
||||
func TestCreateDSR_InvalidBody_Returns400(t *testing.T) {
|
||||
router := gin.New()
|
||||
|
||||
// Mock handler that mimics the actual behavior for invalid body
|
||||
router.POST("/api/v1/dsr", func(c *gin.Context) {
|
||||
var req models.CreateDSRRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body", "details": err.Error()})
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
// Invalid JSON
|
||||
req, _ := http.NewRequest("POST", "/api/v1/dsr", bytes.NewBufferString("{invalid json"))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCreateDSR_MissingType tests create DSR with missing type
|
||||
func TestCreateDSR_MissingType_Returns400(t *testing.T) {
|
||||
router := gin.New()
|
||||
|
||||
router.POST("/api/v1/dsr", func(c *gin.Context) {
|
||||
var req models.CreateDSRRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
if req.RequestType == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "request_type is required"})
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
body := `{"requester_email": "test@example.com"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/dsr", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCreateDSR_InvalidType tests create DSR with invalid type
|
||||
func TestCreateDSR_InvalidType_Returns400(t *testing.T) {
|
||||
router := gin.New()
|
||||
|
||||
router.POST("/api/v1/dsr", func(c *gin.Context) {
|
||||
var req models.CreateDSRRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
if !models.IsValidDSRRequestType(req.RequestType) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request_type"})
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
body := `{"request_type": "invalid_type", "requester_email": "test@example.com"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/dsr", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminListDSR_Unauthorized_Returns401 tests admin list without auth
|
||||
func TestAdminListDSR_Unauthorized_Returns401(t *testing.T) {
|
||||
router := gin.New()
|
||||
|
||||
// Simplified auth check
|
||||
router.GET("/api/v1/admin/dsr", func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization required"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"requests": []interface{}{}})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/api/v1/admin/dsr", nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminListDSR_ValidRequest tests admin list with valid auth
|
||||
func TestAdminListDSR_ValidRequest_Returns200(t *testing.T) {
|
||||
router := gin.New()
|
||||
|
||||
router.GET("/api/v1/admin/dsr", func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization required"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"requests": []interface{}{},
|
||||
"total": 0,
|
||||
"limit": 20,
|
||||
"offset": 0,
|
||||
})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/api/v1/admin/dsr", nil)
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &response)
|
||||
|
||||
if _, ok := response["requests"]; !ok {
|
||||
t.Error("Response should contain 'requests' field")
|
||||
}
|
||||
if _, ok := response["total"]; !ok {
|
||||
t.Error("Response should contain 'total' field")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminGetDSRStats_ValidRequest tests admin stats endpoint
|
||||
func TestAdminGetDSRStats_ValidRequest_Returns200(t *testing.T) {
|
||||
router := gin.New()
|
||||
|
||||
router.GET("/api/v1/admin/dsr/stats", func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization required"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"total_requests": 0,
|
||||
"pending_requests": 0,
|
||||
"overdue_requests": 0,
|
||||
"completed_this_month": 0,
|
||||
"average_processing_days": 0,
|
||||
"by_type": map[string]int{},
|
||||
"by_status": map[string]int{},
|
||||
})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/api/v1/admin/dsr/stats", nil)
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &response)
|
||||
|
||||
expectedFields := []string{"total_requests", "pending_requests", "overdue_requests", "by_type", "by_status"}
|
||||
for _, field := range expectedFields {
|
||||
if _, ok := response[field]; !ok {
|
||||
t.Errorf("Response should contain '%s' field", field)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminUpdateDSR_InvalidStatus_Returns400 tests admin update with invalid status
|
||||
func TestAdminUpdateDSR_InvalidStatus_Returns400(t *testing.T) {
|
||||
router := gin.New()
|
||||
|
||||
router.PUT("/api/v1/admin/dsr/:id", func(c *gin.Context) {
|
||||
var req models.UpdateDSRRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
if req.Status != nil && !models.IsValidDSRStatus(*req.Status) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid status"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Updated"})
|
||||
})
|
||||
|
||||
body := `{"status": "invalid_status"}`
|
||||
req, _ := http.NewRequest("PUT", "/api/v1/admin/dsr/123", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminVerifyIdentity_ValidRequest_Returns200 tests identity verification
|
||||
func TestAdminVerifyIdentity_ValidRequest_Returns200(t *testing.T) {
|
||||
router := gin.New()
|
||||
|
||||
router.POST("/api/v1/admin/dsr/:id/verify-identity", func(c *gin.Context) {
|
||||
var req models.VerifyDSRIdentityRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
if req.Method == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "method is required"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Identität verifiziert"})
|
||||
})
|
||||
|
||||
body := `{"method": "id_card"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/dsr/123/verify-identity", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminExtendDeadline_MissingReason_Returns400 tests extend deadline without reason
|
||||
func TestAdminExtendDeadline_MissingReason_Returns400(t *testing.T) {
|
||||
router := gin.New()
|
||||
|
||||
router.POST("/api/v1/admin/dsr/:id/extend", func(c *gin.Context) {
|
||||
var req models.ExtendDSRDeadlineRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
if req.Reason == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "reason is required"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Deadline extended"})
|
||||
})
|
||||
|
||||
body := `{"days": 30}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/dsr/123/extend", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminCompleteDSR_ValidRequest_Returns200 tests complete DSR
|
||||
func TestAdminCompleteDSR_ValidRequest_Returns200(t *testing.T) {
|
||||
router := gin.New()
|
||||
|
||||
router.POST("/api/v1/admin/dsr/:id/complete", func(c *gin.Context) {
|
||||
var req models.CompleteDSRRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Anfrage erfolgreich abgeschlossen"})
|
||||
})
|
||||
|
||||
body := `{"result_summary": "Alle Daten wurden bereitgestellt"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/dsr/123/complete", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminRejectDSR_MissingLegalBasis_Returns400 tests reject DSR without legal basis
|
||||
func TestAdminRejectDSR_MissingLegalBasis_Returns400(t *testing.T) {
|
||||
router := gin.New()
|
||||
|
||||
router.POST("/api/v1/admin/dsr/:id/reject", func(c *gin.Context) {
|
||||
var req models.RejectDSRRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
if req.LegalBasis == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "legal_basis is required"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Rejected"})
|
||||
})
|
||||
|
||||
body := `{"reason": "Some reason"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/dsr/123/reject", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminRejectDSR_ValidRequest_Returns200 tests reject DSR with valid data
|
||||
func TestAdminRejectDSR_ValidRequest_Returns200(t *testing.T) {
|
||||
router := gin.New()
|
||||
|
||||
router.POST("/api/v1/admin/dsr/:id/reject", func(c *gin.Context) {
|
||||
var req models.RejectDSRRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
if req.LegalBasis == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "legal_basis is required"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Anfrage abgelehnt"})
|
||||
})
|
||||
|
||||
body := `{"reason": "Daten benötigt für Rechtsstreit", "legal_basis": "Art. 17(3)e"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/dsr/123/reject", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetDSRTemplates_Returns200 tests templates endpoint
|
||||
func TestGetDSRTemplates_Returns200(t *testing.T) {
|
||||
router := gin.New()
|
||||
|
||||
router.GET("/api/v1/admin/dsr-templates", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"templates": []map[string]interface{}{
|
||||
{
|
||||
"id": "uuid-1",
|
||||
"template_type": "dsr_receipt_access",
|
||||
"name": "Eingangsbestätigung (Art. 15)",
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/api/v1/admin/dsr-templates", nil)
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &response)
|
||||
|
||||
if _, ok := response["templates"]; !ok {
|
||||
t.Error("Response should contain 'templates' field")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRequestTypeValidation tests all valid request types
|
||||
func TestRequestTypeValidation(t *testing.T) {
|
||||
validTypes := []string{"access", "rectification", "erasure", "restriction", "portability"}
|
||||
|
||||
for _, reqType := range validTypes {
|
||||
if !models.IsValidDSRRequestType(reqType) {
|
||||
t.Errorf("Expected %s to be a valid request type", reqType)
|
||||
}
|
||||
}
|
||||
|
||||
invalidTypes := []string{"invalid", "delete", "copy", ""}
|
||||
for _, reqType := range invalidTypes {
|
||||
if models.IsValidDSRRequestType(reqType) {
|
||||
t.Errorf("Expected %s to be an invalid request type", reqType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestStatusValidation tests all valid statuses
|
||||
func TestStatusValidation(t *testing.T) {
|
||||
validStatuses := []string{"intake", "identity_verification", "processing", "completed", "rejected", "cancelled"}
|
||||
|
||||
for _, status := range validStatuses {
|
||||
if !models.IsValidDSRStatus(status) {
|
||||
t.Errorf("Expected %s to be a valid status", status)
|
||||
}
|
||||
}
|
||||
|
||||
invalidStatuses := []string{"invalid", "pending", "done", ""}
|
||||
for _, status := range invalidStatuses {
|
||||
if models.IsValidDSRStatus(status) {
|
||||
t.Errorf("Expected %s to be an invalid status", status)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,528 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
"github.com/breakpilot/consent-service/internal/services"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// EmailTemplateHandler handles email template operations
|
||||
type EmailTemplateHandler struct {
|
||||
service *services.EmailTemplateService
|
||||
}
|
||||
|
||||
// NewEmailTemplateHandler creates a new email template handler
|
||||
func NewEmailTemplateHandler(service *services.EmailTemplateService) *EmailTemplateHandler {
|
||||
return &EmailTemplateHandler{service: service}
|
||||
}
|
||||
|
||||
// GetAllTemplateTypes returns all available email template types with their variables
|
||||
// GET /api/v1/admin/email-templates/types
|
||||
func (h *EmailTemplateHandler) GetAllTemplateTypes(c *gin.Context) {
|
||||
types := h.service.GetAllTemplateTypes()
|
||||
c.JSON(http.StatusOK, gin.H{"types": types})
|
||||
}
|
||||
|
||||
// GetAllTemplates returns all email templates with their latest published versions
|
||||
// GET /api/v1/admin/email-templates
|
||||
func (h *EmailTemplateHandler) GetAllTemplates(c *gin.Context) {
|
||||
templates, err := h.service.GetAllTemplates(c.Request.Context())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"templates": templates})
|
||||
}
|
||||
|
||||
// GetTemplate returns a single template by ID
|
||||
// GET /api/v1/admin/email-templates/:id
|
||||
func (h *EmailTemplateHandler) GetTemplate(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid template ID"})
|
||||
return
|
||||
}
|
||||
|
||||
template, err := h.service.GetTemplateByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "template not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, template)
|
||||
}
|
||||
|
||||
// CreateTemplate creates a new email template type
|
||||
// POST /api/v1/admin/email-templates
|
||||
func (h *EmailTemplateHandler) CreateTemplate(c *gin.Context) {
|
||||
var req models.CreateEmailTemplateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
template, err := h.service.CreateEmailTemplate(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusCreated, template)
|
||||
}
|
||||
|
||||
// GetTemplateVersions returns all versions for a template
|
||||
// GET /api/v1/admin/email-templates/:id/versions
|
||||
func (h *EmailTemplateHandler) GetTemplateVersions(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid template ID"})
|
||||
return
|
||||
}
|
||||
|
||||
versions, err := h.service.GetVersionsByTemplateID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"versions": versions})
|
||||
}
|
||||
|
||||
// GetVersion returns a single version by ID
|
||||
// GET /api/v1/admin/email-template-versions/:id
|
||||
func (h *EmailTemplateHandler) GetVersion(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"})
|
||||
return
|
||||
}
|
||||
|
||||
version, err := h.service.GetVersionByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "version not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, version)
|
||||
}
|
||||
|
||||
// CreateVersion creates a new version of an email template
|
||||
// POST /api/v1/admin/email-template-versions
|
||||
func (h *EmailTemplateHandler) CreateVersion(c *gin.Context) {
|
||||
var req models.CreateEmailTemplateVersionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user ID from context
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "user not authenticated"})
|
||||
return
|
||||
}
|
||||
uid, _ := uuid.Parse(userID.(string))
|
||||
|
||||
version, err := h.service.CreateTemplateVersion(c.Request.Context(), &req, uid)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusCreated, version)
|
||||
}
|
||||
|
||||
// UpdateVersion updates a version
|
||||
// PUT /api/v1/admin/email-template-versions/:id
|
||||
func (h *EmailTemplateHandler) UpdateVersion(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var req models.UpdateEmailTemplateVersionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.service.UpdateVersion(c.Request.Context(), id, &req); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "version updated"})
|
||||
}
|
||||
|
||||
// SubmitForReview submits a version for review
|
||||
// POST /api/v1/admin/email-template-versions/:id/submit
|
||||
func (h *EmailTemplateHandler) SubmitForReview(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Comment *string `json:"comment"`
|
||||
}
|
||||
c.ShouldBindJSON(&req)
|
||||
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "user not authenticated"})
|
||||
return
|
||||
}
|
||||
uid, _ := uuid.Parse(userID.(string))
|
||||
|
||||
if err := h.service.SubmitForReview(c.Request.Context(), id, uid, req.Comment); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "version submitted for review"})
|
||||
}
|
||||
|
||||
// ApproveVersion approves a version (DSB only)
|
||||
// POST /api/v1/admin/email-template-versions/:id/approve
|
||||
func (h *EmailTemplateHandler) ApproveVersion(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"})
|
||||
return
|
||||
}
|
||||
|
||||
// Check role
|
||||
role, exists := c.Get("user_role")
|
||||
if !exists || (role != "data_protection_officer" && role != "admin" && role != "super_admin") {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "insufficient permissions"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Comment *string `json:"comment"`
|
||||
ScheduledPublishAt *string `json:"scheduled_publish_at"`
|
||||
}
|
||||
c.ShouldBindJSON(&req)
|
||||
|
||||
userID, _ := c.Get("user_id")
|
||||
uid, _ := uuid.Parse(userID.(string))
|
||||
|
||||
var scheduledAt *time.Time
|
||||
if req.ScheduledPublishAt != nil {
|
||||
t, err := time.Parse(time.RFC3339, *req.ScheduledPublishAt)
|
||||
if err == nil {
|
||||
scheduledAt = &t
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.service.ApproveVersion(c.Request.Context(), id, uid, req.Comment, scheduledAt); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "version approved"})
|
||||
}
|
||||
|
||||
// RejectVersion rejects a version
|
||||
// POST /api/v1/admin/email-template-versions/:id/reject
|
||||
func (h *EmailTemplateHandler) RejectVersion(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"})
|
||||
return
|
||||
}
|
||||
|
||||
role, exists := c.Get("user_role")
|
||||
if !exists || (role != "data_protection_officer" && role != "admin" && role != "super_admin") {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "insufficient permissions"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Comment string `json:"comment" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "comment is required"})
|
||||
return
|
||||
}
|
||||
|
||||
userID, _ := c.Get("user_id")
|
||||
uid, _ := uuid.Parse(userID.(string))
|
||||
|
||||
if err := h.service.RejectVersion(c.Request.Context(), id, uid, req.Comment); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "version rejected"})
|
||||
}
|
||||
|
||||
// PublishVersion publishes an approved version
|
||||
// POST /api/v1/admin/email-template-versions/:id/publish
|
||||
func (h *EmailTemplateHandler) PublishVersion(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"})
|
||||
return
|
||||
}
|
||||
|
||||
role, exists := c.Get("user_role")
|
||||
if !exists || (role != "data_protection_officer" && role != "admin" && role != "super_admin") {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "insufficient permissions"})
|
||||
return
|
||||
}
|
||||
|
||||
userID, _ := c.Get("user_id")
|
||||
uid, _ := uuid.Parse(userID.(string))
|
||||
|
||||
if err := h.service.PublishVersion(c.Request.Context(), id, uid); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "version published"})
|
||||
}
|
||||
|
||||
// GetApprovals returns approval history for a version
|
||||
// GET /api/v1/admin/email-template-versions/:id/approvals
|
||||
func (h *EmailTemplateHandler) GetApprovals(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"})
|
||||
return
|
||||
}
|
||||
|
||||
approvals, err := h.service.GetApprovals(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"approvals": approvals})
|
||||
}
|
||||
|
||||
// PreviewVersion renders a preview of an email template version
|
||||
// POST /api/v1/admin/email-template-versions/:id/preview
|
||||
func (h *EmailTemplateHandler) PreviewVersion(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Variables map[string]string `json:"variables"`
|
||||
}
|
||||
c.ShouldBindJSON(&req)
|
||||
|
||||
version, err := h.service.GetVersionByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "version not found"})
|
||||
return
|
||||
}
|
||||
|
||||
// Use default test values if not provided
|
||||
if req.Variables == nil {
|
||||
req.Variables = map[string]string{
|
||||
"user_name": "Max Mustermann",
|
||||
"user_email": "max@example.com",
|
||||
"login_url": "https://breakpilot.app/login",
|
||||
"support_email": "support@breakpilot.app",
|
||||
"verification_url": "https://breakpilot.app/verify?token=abc123",
|
||||
"verification_code": "123456",
|
||||
"expires_in": "24 Stunden",
|
||||
"reset_url": "https://breakpilot.app/reset?token=xyz789",
|
||||
"reset_code": "RESET123",
|
||||
"ip_address": "192.168.1.1",
|
||||
"device_info": "Chrome auf Windows 11",
|
||||
"changed_at": time.Now().Format("02.01.2006 15:04"),
|
||||
"enabled_at": time.Now().Format("02.01.2006 15:04"),
|
||||
"disabled_at": time.Now().Format("02.01.2006 15:04"),
|
||||
"support_url": "https://breakpilot.app/support",
|
||||
"security_url": "https://breakpilot.app/account/security",
|
||||
"login_time": time.Now().Format("02.01.2006 15:04"),
|
||||
"location": "Berlin, Deutschland",
|
||||
"activity_type": "Mehrere fehlgeschlagene Login-Versuche",
|
||||
"activity_time": time.Now().Format("02.01.2006 15:04"),
|
||||
"locked_at": time.Now().Format("02.01.2006 15:04"),
|
||||
"reason": "Zu viele fehlgeschlagene Login-Versuche",
|
||||
"unlock_time": time.Now().Add(30 * time.Minute).Format("02.01.2006 15:04"),
|
||||
"unlocked_at": time.Now().Format("02.01.2006 15:04"),
|
||||
"requested_at": time.Now().Format("02.01.2006"),
|
||||
"deletion_date": time.Now().AddDate(0, 0, 30).Format("02.01.2006"),
|
||||
"cancel_url": "https://breakpilot.app/cancel-deletion?token=cancel123",
|
||||
"data_info": "Benutzerdaten, Zustimmungshistorie, Audit-Logs",
|
||||
"deleted_at": time.Now().Format("02.01.2006"),
|
||||
"feedback_url": "https://breakpilot.app/feedback",
|
||||
"download_url": "https://breakpilot.app/export/download?token=export123",
|
||||
"file_size": "2.3 MB",
|
||||
"old_email": "alt@example.com",
|
||||
"new_email": "neu@example.com",
|
||||
"document_name": "Datenschutzerklärung",
|
||||
"document_type": "privacy",
|
||||
"version": "2.0.0",
|
||||
"consent_url": "https://breakpilot.app/consent",
|
||||
"deadline": time.Now().AddDate(0, 0, 14).Format("02.01.2006"),
|
||||
"days_left": "7",
|
||||
"hours_left": "24 Stunden",
|
||||
"consequences": "Ohne Ihre Zustimmung wird Ihr Konto suspendiert.",
|
||||
"suspended_at": time.Now().Format("02.01.2006 15:04"),
|
||||
"documents": "- Datenschutzerklärung v2.0.0\n- AGB v1.5.0",
|
||||
}
|
||||
}
|
||||
|
||||
preview, err := h.service.RenderTemplate(version, req.Variables)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, preview)
|
||||
}
|
||||
|
||||
// SendTestEmail sends a test email
|
||||
// POST /api/v1/admin/email-template-versions/:id/send-test
|
||||
func (h *EmailTemplateHandler) SendTestEmail(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid version ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var req models.SendTestEmailRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
req.VersionID = idStr
|
||||
|
||||
version, err := h.service.GetVersionByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "version not found"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get template to find type
|
||||
template, err := h.service.GetTemplateByID(c.Request.Context(), version.TemplateID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "template not found"})
|
||||
return
|
||||
}
|
||||
|
||||
userID, _ := c.Get("user_id")
|
||||
uid, _ := uuid.Parse(userID.(string))
|
||||
|
||||
// Send test email
|
||||
if err := h.service.SendEmail(c.Request.Context(), template.Type, version.Language, req.Recipient, req.Variables, &uid); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "test email sent"})
|
||||
}
|
||||
|
||||
// GetSettings returns global email settings
|
||||
// GET /api/v1/admin/email-templates/settings
|
||||
func (h *EmailTemplateHandler) GetSettings(c *gin.Context) {
|
||||
settings, err := h.service.GetSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
// Return default settings if none exist
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"company_name": "BreakPilot",
|
||||
"sender_name": "BreakPilot",
|
||||
"sender_email": "noreply@breakpilot.app",
|
||||
"primary_color": "#2563eb",
|
||||
"secondary_color": "#64748b",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, settings)
|
||||
}
|
||||
|
||||
// UpdateSettings updates global email settings
|
||||
// PUT /api/v1/admin/email-templates/settings
|
||||
func (h *EmailTemplateHandler) UpdateSettings(c *gin.Context) {
|
||||
var req models.UpdateEmailTemplateSettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
userID, _ := c.Get("user_id")
|
||||
uid, _ := uuid.Parse(userID.(string))
|
||||
|
||||
if err := h.service.UpdateSettings(c.Request.Context(), &req, uid); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "settings updated"})
|
||||
}
|
||||
|
||||
// GetEmailStats returns email statistics
|
||||
// GET /api/v1/admin/email-templates/stats
|
||||
func (h *EmailTemplateHandler) GetEmailStats(c *gin.Context) {
|
||||
stats, err := h.service.GetEmailStats(c.Request.Context())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// GetSendLogs returns email send logs
|
||||
// GET /api/v1/admin/email-templates/logs
|
||||
func (h *EmailTemplateHandler) GetSendLogs(c *gin.Context) {
|
||||
limitStr := c.DefaultQuery("limit", "50")
|
||||
offsetStr := c.DefaultQuery("offset", "0")
|
||||
|
||||
limit, _ := strconv.Atoi(limitStr)
|
||||
offset, _ := strconv.Atoi(offsetStr)
|
||||
|
||||
if limit > 100 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
logs, total, err := h.service.GetSendLogs(c.Request.Context(), limit, offset)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"logs": logs, "total": total})
|
||||
}
|
||||
|
||||
// GetDefaultContent returns default template content for a type
|
||||
// GET /api/v1/admin/email-templates/default/:type
|
||||
func (h *EmailTemplateHandler) GetDefaultContent(c *gin.Context) {
|
||||
templateType := c.Param("type")
|
||||
language := c.DefaultQuery("language", "de")
|
||||
|
||||
subject, bodyHTML, bodyText := h.service.GetDefaultTemplateContent(templateType, language)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"subject": subject,
|
||||
"body_html": bodyHTML,
|
||||
"body_text": bodyText,
|
||||
})
|
||||
}
|
||||
|
||||
// InitializeTemplates initializes default email templates
|
||||
// POST /api/v1/admin/email-templates/initialize
|
||||
func (h *EmailTemplateHandler) InitializeTemplates(c *gin.Context) {
|
||||
role, exists := c.Get("user_role")
|
||||
if !exists || (role != "admin" && role != "super_admin") {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "insufficient permissions"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.service.InitDefaultTemplates(c.Request.Context()); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "default templates initialized"})
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,805 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// setupTestRouter creates a test router with handlers
|
||||
// Note: For full integration tests, use a test database
|
||||
func setupTestRouter() *gin.Engine {
|
||||
router := gin.New()
|
||||
return router
|
||||
}
|
||||
|
||||
// TestHealthEndpoint tests the health check endpoint
|
||||
func TestHealthEndpoint(t *testing.T) {
|
||||
router := setupTestRouter()
|
||||
|
||||
// Add health endpoint
|
||||
router.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "healthy",
|
||||
"service": "consent-service",
|
||||
"version": "1.0.0",
|
||||
})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &response)
|
||||
|
||||
if response["status"] != "healthy" {
|
||||
t.Errorf("Expected status 'healthy', got %v", response["status"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestUnauthorizedAccess tests that protected endpoints require auth
|
||||
func TestUnauthorizedAccess(t *testing.T) {
|
||||
router := setupTestRouter()
|
||||
|
||||
// Add a protected endpoint
|
||||
router.GET("/api/v1/consent/my", func(c *gin.Context) {
|
||||
auth := c.GetHeader("Authorization")
|
||||
if auth == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization required"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"consents": []interface{}{}})
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
authorization string
|
||||
expectedStatus int
|
||||
}{
|
||||
{"no auth header", "", http.StatusUnauthorized},
|
||||
{"empty bearer", "Bearer ", http.StatusOK}, // Would be invalid in real middleware
|
||||
{"valid format", "Bearer test-token", http.StatusOK},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/api/v1/consent/my", nil)
|
||||
if tt.authorization != "" {
|
||||
req.Header.Set("Authorization", tt.authorization)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCreateConsentRequest tests consent creation request validation
|
||||
func TestCreateConsentRequest(t *testing.T) {
|
||||
type ConsentRequest struct {
|
||||
DocumentType string `json:"document_type"`
|
||||
VersionID string `json:"version_id"`
|
||||
Consented bool `json:"consented"`
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
request ConsentRequest
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "valid consent",
|
||||
request: ConsentRequest{
|
||||
DocumentType: "terms",
|
||||
VersionID: "123e4567-e89b-12d3-a456-426614174000",
|
||||
Consented: true,
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "missing document type",
|
||||
request: ConsentRequest{
|
||||
VersionID: "123e4567-e89b-12d3-a456-426614174000",
|
||||
Consented: true,
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "missing version ID",
|
||||
request: ConsentRequest{
|
||||
DocumentType: "terms",
|
||||
Consented: true,
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.request.DocumentType != "" && tt.request.VersionID != ""
|
||||
if isValid != tt.expectValid {
|
||||
t.Errorf("Expected valid=%v, got %v", tt.expectValid, isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentTypeValidation tests valid document types
|
||||
func TestDocumentTypeValidation(t *testing.T) {
|
||||
validTypes := map[string]bool{
|
||||
"terms": true,
|
||||
"privacy": true,
|
||||
"cookies": true,
|
||||
"community_guidelines": true,
|
||||
"imprint": true,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
docType string
|
||||
expected bool
|
||||
}{
|
||||
{"terms", true},
|
||||
{"privacy", true},
|
||||
{"cookies", true},
|
||||
{"community_guidelines", true},
|
||||
{"imprint", true},
|
||||
{"invalid", false},
|
||||
{"", false},
|
||||
{"Terms", false}, // case sensitive
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.docType, func(t *testing.T) {
|
||||
_, isValid := validTypes[tt.docType]
|
||||
if isValid != tt.expected {
|
||||
t.Errorf("Expected %s valid=%v, got %v", tt.docType, tt.expected, isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestVersionStatusTransitions tests valid status transitions
|
||||
func TestVersionStatusTransitions(t *testing.T) {
|
||||
validTransitions := map[string][]string{
|
||||
"draft": {"review"},
|
||||
"review": {"approved", "rejected"},
|
||||
"approved": {"scheduled", "published"},
|
||||
"scheduled": {"published"},
|
||||
"published": {"archived"},
|
||||
"rejected": {"draft"},
|
||||
"archived": {}, // terminal state
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
fromStatus string
|
||||
toStatus string
|
||||
expected bool
|
||||
}{
|
||||
{"draft", "review", true},
|
||||
{"draft", "published", false},
|
||||
{"review", "approved", true},
|
||||
{"review", "rejected", true},
|
||||
{"review", "published", false},
|
||||
{"approved", "published", true},
|
||||
{"approved", "scheduled", true},
|
||||
{"published", "archived", true},
|
||||
{"published", "draft", false},
|
||||
{"archived", "draft", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.fromStatus+"->"+tt.toStatus, func(t *testing.T) {
|
||||
allowed := false
|
||||
if transitions, ok := validTransitions[tt.fromStatus]; ok {
|
||||
for _, t := range transitions {
|
||||
if t == tt.toStatus {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if allowed != tt.expected {
|
||||
t.Errorf("Transition %s->%s: expected %v, got %v",
|
||||
tt.fromStatus, tt.toStatus, tt.expected, allowed)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRolePermissions tests role-based access control
|
||||
func TestRolePermissions(t *testing.T) {
|
||||
permissions := map[string]map[string]bool{
|
||||
"user": {
|
||||
"view_documents": true,
|
||||
"give_consent": true,
|
||||
"view_own_data": true,
|
||||
"request_deletion": true,
|
||||
"create_document": false,
|
||||
"publish_version": false,
|
||||
"approve_version": false,
|
||||
},
|
||||
"admin": {
|
||||
"view_documents": true,
|
||||
"give_consent": true,
|
||||
"view_own_data": true,
|
||||
"create_document": true,
|
||||
"edit_version": true,
|
||||
"publish_version": true,
|
||||
"approve_version": false, // Only DSB
|
||||
},
|
||||
"data_protection_officer": {
|
||||
"view_documents": true,
|
||||
"create_document": true,
|
||||
"edit_version": true,
|
||||
"approve_version": true,
|
||||
"publish_version": true,
|
||||
"view_audit_log": true,
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
role string
|
||||
action string
|
||||
shouldHave bool
|
||||
}{
|
||||
{"user", "view_documents", true},
|
||||
{"user", "create_document", false},
|
||||
{"admin", "create_document", true},
|
||||
{"admin", "approve_version", false},
|
||||
{"data_protection_officer", "approve_version", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.role+":"+tt.action, func(t *testing.T) {
|
||||
rolePerms, ok := permissions[tt.role]
|
||||
if !ok {
|
||||
t.Fatalf("Unknown role: %s", tt.role)
|
||||
}
|
||||
|
||||
hasPermission := rolePerms[tt.action]
|
||||
if hasPermission != tt.shouldHave {
|
||||
t.Errorf("Role %s action %s: expected %v, got %v",
|
||||
tt.role, tt.action, tt.shouldHave, hasPermission)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJSONResponseFormat tests that responses have correct format
|
||||
func TestJSONResponseFormat(t *testing.T) {
|
||||
router := setupTestRouter()
|
||||
|
||||
router.GET("/api/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"id": "123",
|
||||
"name": "Test",
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/api/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
contentType := w.Header().Get("Content-Type")
|
||||
if contentType != "application/json; charset=utf-8" {
|
||||
t.Errorf("Expected Content-Type 'application/json; charset=utf-8', got %s", contentType)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
if err != nil {
|
||||
t.Fatalf("Response should be valid JSON: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorResponseFormat tests error response format
|
||||
func TestErrorResponseFormat(t *testing.T) {
|
||||
router := setupTestRouter()
|
||||
|
||||
router.GET("/api/error", func(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "Bad Request",
|
||||
"message": "Invalid input",
|
||||
})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/api/error", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &response)
|
||||
|
||||
if response["error"] == nil {
|
||||
t.Error("Error response should contain 'error' field")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCookieCategoryValidation tests cookie category validation
|
||||
func TestCookieCategoryValidation(t *testing.T) {
|
||||
mandatoryCategories := []string{"necessary"}
|
||||
optionalCategories := []string{"functional", "analytics", "marketing"}
|
||||
|
||||
// Necessary should always be consented
|
||||
for _, cat := range mandatoryCategories {
|
||||
t.Run("mandatory_"+cat, func(t *testing.T) {
|
||||
// Business rule: mandatory categories cannot be declined
|
||||
isMandatory := true
|
||||
if !isMandatory {
|
||||
t.Errorf("Category %s should be mandatory", cat)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Optional categories can be toggled
|
||||
for _, cat := range optionalCategories {
|
||||
t.Run("optional_"+cat, func(t *testing.T) {
|
||||
isMandatory := false
|
||||
if isMandatory {
|
||||
t.Errorf("Category %s should not be mandatory", cat)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPaginationParams tests pagination parameter handling
|
||||
func TestPaginationParams(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
page int
|
||||
perPage int
|
||||
expPage int
|
||||
expLimit int
|
||||
}{
|
||||
{"defaults", 0, 0, 1, 50},
|
||||
{"page 1", 1, 10, 1, 10},
|
||||
{"page 5", 5, 20, 5, 20},
|
||||
{"negative page", -1, 10, 1, 10}, // should default
|
||||
{"too large per_page", 1, 500, 1, 100}, // should cap
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
page := tt.page
|
||||
perPage := tt.perPage
|
||||
|
||||
// Apply defaults and limits
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if perPage < 1 {
|
||||
perPage = 50
|
||||
}
|
||||
if perPage > 100 {
|
||||
perPage = 100
|
||||
}
|
||||
|
||||
if page != tt.expPage {
|
||||
t.Errorf("Expected page %d, got %d", tt.expPage, page)
|
||||
}
|
||||
if perPage != tt.expLimit {
|
||||
t.Errorf("Expected perPage %d, got %d", tt.expLimit, perPage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIPAddressExtraction tests IP address extraction from requests
|
||||
func TestIPAddressExtraction(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
xForwarded string
|
||||
remoteAddr string
|
||||
expected string
|
||||
}{
|
||||
{"direct connection", "", "192.168.1.1:1234", "192.168.1.1"},
|
||||
{"behind proxy", "10.0.0.1", "192.168.1.1:1234", "10.0.0.1"},
|
||||
{"multiple proxies", "10.0.0.1, 10.0.0.2", "192.168.1.1:1234", "10.0.0.1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := setupTestRouter()
|
||||
var extractedIP string
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
if xf := c.GetHeader("X-Forwarded-For"); xf != "" {
|
||||
// Take first IP from list
|
||||
for i, ch := range xf {
|
||||
if ch == ',' {
|
||||
extractedIP = xf[:i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if extractedIP == "" {
|
||||
extractedIP = xf
|
||||
}
|
||||
} else {
|
||||
// Extract IP from RemoteAddr
|
||||
addr := c.Request.RemoteAddr
|
||||
for i := len(addr) - 1; i >= 0; i-- {
|
||||
if addr[i] == ':' {
|
||||
extractedIP = addr[:i]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ip": extractedIP})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
if tt.xForwarded != "" {
|
||||
req.Header.Set("X-Forwarded-For", tt.xForwarded)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if extractedIP != tt.expected {
|
||||
t.Errorf("Expected IP %s, got %s", tt.expected, extractedIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRequestBodySizeLimit tests that large requests are rejected
|
||||
func TestRequestBodySizeLimit(t *testing.T) {
|
||||
router := setupTestRouter()
|
||||
|
||||
// Simulate a body size limit check
|
||||
maxBodySize := int64(1024 * 1024) // 1MB
|
||||
|
||||
router.POST("/api/upload", func(c *gin.Context) {
|
||||
if c.Request.ContentLength > maxBodySize {
|
||||
c.JSON(http.StatusRequestEntityTooLarge, gin.H{
|
||||
"error": "Request body too large",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
contentLength int64
|
||||
expectedStatus int
|
||||
}{
|
||||
{"small body", 1000, http.StatusOK},
|
||||
{"medium body", 500000, http.StatusOK},
|
||||
{"exactly at limit", maxBodySize, http.StatusOK},
|
||||
{"over limit", maxBodySize + 1, http.StatusRequestEntityTooLarge},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
body := bytes.NewReader(make([]byte, 0))
|
||||
req, _ := http.NewRequest("POST", "/api/upload", body)
|
||||
req.ContentLength = tt.contentLength
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// EXTENDED HANDLER TESTS
|
||||
// ========================================
|
||||
|
||||
// TestAuthHandlers tests authentication endpoints
|
||||
func TestAuthHandlers(t *testing.T) {
|
||||
router := setupTestRouter()
|
||||
|
||||
// Register endpoint
|
||||
router.POST("/api/v1/auth/register", func(c *gin.Context) {
|
||||
var req struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusCreated, gin.H{"message": "User registered"})
|
||||
})
|
||||
|
||||
// Login endpoint
|
||||
router.POST("/api/v1/auth/login", func(c *gin.Context) {
|
||||
var req struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"access_token": "token123"})
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
endpoint string
|
||||
method string
|
||||
body interface{}
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "register - valid",
|
||||
endpoint: "/api/v1/auth/register",
|
||||
method: "POST",
|
||||
body: map[string]string{"email": "test@example.com", "password": "password123"},
|
||||
expectedStatus: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "login - valid",
|
||||
endpoint: "/api/v1/auth/login",
|
||||
method: "POST",
|
||||
body: map[string]string{"email": "test@example.com", "password": "password123"},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
jsonBody, _ := json.Marshal(tt.body)
|
||||
req, _ := http.NewRequest(tt.method, tt.endpoint, bytes.NewBuffer(jsonBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentHandlers tests document endpoints
|
||||
func TestDocumentHandlers(t *testing.T) {
|
||||
router := setupTestRouter()
|
||||
|
||||
// GET documents
|
||||
router.GET("/api/v1/documents", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"documents": []interface{}{}})
|
||||
})
|
||||
|
||||
// GET document by type
|
||||
router.GET("/api/v1/documents/:type", func(c *gin.Context) {
|
||||
docType := c.Param("type")
|
||||
if docType == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid type"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"id": "123", "type": docType})
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
endpoint string
|
||||
expectedStatus int
|
||||
}{
|
||||
{"get all documents", "/api/v1/documents", http.StatusOK},
|
||||
{"get terms", "/api/v1/documents/terms", http.StatusOK},
|
||||
{"get privacy", "/api/v1/documents/privacy", http.StatusOK},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", tt.endpoint, nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConsentHandlers tests consent endpoints
|
||||
func TestConsentHandlers(t *testing.T) {
|
||||
router := setupTestRouter()
|
||||
|
||||
// Create consent
|
||||
router.POST("/api/v1/consent", func(c *gin.Context) {
|
||||
var req struct {
|
||||
VersionID string `json:"version_id"`
|
||||
Consented bool `json:"consented"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusCreated, gin.H{"message": "Consent saved"})
|
||||
})
|
||||
|
||||
// Check consent
|
||||
router.GET("/api/v1/consent/check/:type", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"has_consent": true, "needs_update": false})
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
endpoint string
|
||||
method string
|
||||
body interface{}
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "create consent",
|
||||
endpoint: "/api/v1/consent",
|
||||
method: "POST",
|
||||
body: map[string]interface{}{"version_id": "123", "consented": true},
|
||||
expectedStatus: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "check consent",
|
||||
endpoint: "/api/v1/consent/check/terms",
|
||||
method: "GET",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var req *http.Request
|
||||
if tt.body != nil {
|
||||
jsonBody, _ := json.Marshal(tt.body)
|
||||
req, _ = http.NewRequest(tt.method, tt.endpoint, bytes.NewBuffer(jsonBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
} else {
|
||||
req, _ = http.NewRequest(tt.method, tt.endpoint, nil)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminHandlers tests admin endpoints
|
||||
func TestAdminHandlers(t *testing.T) {
|
||||
router := setupTestRouter()
|
||||
|
||||
// Create document (admin only)
|
||||
router.POST("/api/v1/admin/documents", func(c *gin.Context) {
|
||||
auth := c.GetHeader("Authorization")
|
||||
if auth != "Bearer admin-token" {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin only"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusCreated, gin.H{"message": "Document created"})
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
expectedStatus int
|
||||
}{
|
||||
{"admin token", "Bearer admin-token", http.StatusCreated},
|
||||
{"user token", "Bearer user-token", http.StatusForbidden},
|
||||
{"no token", "", http.StatusForbidden},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
body := map[string]string{"type": "terms", "name": "Test"}
|
||||
jsonBody, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/documents", bytes.NewBuffer(jsonBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if tt.token != "" {
|
||||
req.Header.Set("Authorization", tt.token)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCORSHeaders tests CORS headers
|
||||
func TestCORSHeaders(t *testing.T) {
|
||||
router := setupTestRouter()
|
||||
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
c.Next()
|
||||
})
|
||||
|
||||
router.GET("/api/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "test"})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/api/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Header().Get("Access-Control-Allow-Origin") != "*" {
|
||||
t.Error("CORS headers not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRateLimiting tests rate limiting logic
|
||||
func TestRateLimiting(t *testing.T) {
|
||||
requests := 0
|
||||
limit := 5
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
requests++
|
||||
if requests > limit {
|
||||
// Would return 429 Too Many Requests
|
||||
if requests <= limit {
|
||||
t.Error("Rate limit not enforced")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmailTemplateHandlers tests email template endpoints
|
||||
func TestEmailTemplateHandlers(t *testing.T) {
|
||||
router := setupTestRouter()
|
||||
|
||||
router.GET("/api/v1/admin/email-templates", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"templates": []interface{}{}})
|
||||
})
|
||||
|
||||
router.POST("/api/v1/admin/email-templates/test", func(c *gin.Context) {
|
||||
var req struct {
|
||||
Recipient string `json:"recipient"`
|
||||
VersionID string `json:"version_id"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Test email sent"})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/api/v1/admin/email-templates", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,203 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/middleware"
|
||||
"github.com/breakpilot/consent-service/internal/services"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// NotificationHandler handles notification-related requests
|
||||
type NotificationHandler struct {
|
||||
notificationService *services.NotificationService
|
||||
}
|
||||
|
||||
// NewNotificationHandler creates a new notification handler
|
||||
func NewNotificationHandler(notificationService *services.NotificationService) *NotificationHandler {
|
||||
return &NotificationHandler{
|
||||
notificationService: notificationService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetNotifications returns notifications for the current user
|
||||
func (h *NotificationHandler) GetNotifications(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID == uuid.Nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse query parameters
|
||||
limit := 20
|
||||
offset := 0
|
||||
unreadOnly := false
|
||||
|
||||
if l := c.Query("limit"); l != "" {
|
||||
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
|
||||
limit = parsed
|
||||
}
|
||||
}
|
||||
if o := c.Query("offset"); o != "" {
|
||||
if parsed, err := strconv.Atoi(o); err == nil && parsed >= 0 {
|
||||
offset = parsed
|
||||
}
|
||||
}
|
||||
if u := c.Query("unread_only"); u == "true" {
|
||||
unreadOnly = true
|
||||
}
|
||||
|
||||
notifications, total, err := h.notificationService.GetUserNotifications(c.Request.Context(), userID, limit, offset, unreadOnly)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch notifications"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"notifications": notifications,
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
}
|
||||
|
||||
// GetUnreadCount returns the count of unread notifications
|
||||
func (h *NotificationHandler) GetUnreadCount(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID == uuid.Nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
|
||||
return
|
||||
}
|
||||
|
||||
count, err := h.notificationService.GetUnreadCount(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get unread count"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"unread_count": count})
|
||||
}
|
||||
|
||||
// MarkAsRead marks a notification as read
|
||||
func (h *NotificationHandler) MarkAsRead(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID == uuid.Nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
|
||||
return
|
||||
}
|
||||
|
||||
notificationID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid notification ID"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.notificationService.MarkAsRead(c.Request.Context(), userID, notificationID); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Notification not found or already read"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Notification marked as read"})
|
||||
}
|
||||
|
||||
// MarkAllAsRead marks all notifications as read
|
||||
func (h *NotificationHandler) MarkAllAsRead(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID == uuid.Nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.notificationService.MarkAllAsRead(c.Request.Context(), userID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to mark notifications as read"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "All notifications marked as read"})
|
||||
}
|
||||
|
||||
// DeleteNotification deletes a notification
|
||||
func (h *NotificationHandler) DeleteNotification(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID == uuid.Nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
|
||||
return
|
||||
}
|
||||
|
||||
notificationID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid notification ID"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.notificationService.DeleteNotification(c.Request.Context(), userID, notificationID); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Notification not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Notification deleted"})
|
||||
}
|
||||
|
||||
// GetPreferences returns notification preferences for the user
|
||||
func (h *NotificationHandler) GetPreferences(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID == uuid.Nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
|
||||
return
|
||||
}
|
||||
|
||||
prefs, err := h.notificationService.GetPreferences(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get preferences"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, prefs)
|
||||
}
|
||||
|
||||
// UpdatePreferences updates notification preferences for the user
|
||||
func (h *NotificationHandler) UpdatePreferences(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID == uuid.Nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
EmailEnabled *bool `json:"email_enabled"`
|
||||
PushEnabled *bool `json:"push_enabled"`
|
||||
InAppEnabled *bool `json:"in_app_enabled"`
|
||||
ReminderFrequency *string `json:"reminder_frequency"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get current preferences
|
||||
prefs, _ := h.notificationService.GetPreferences(c.Request.Context(), userID)
|
||||
|
||||
// Update only provided fields
|
||||
if req.EmailEnabled != nil {
|
||||
prefs.EmailEnabled = *req.EmailEnabled
|
||||
}
|
||||
if req.PushEnabled != nil {
|
||||
prefs.PushEnabled = *req.PushEnabled
|
||||
}
|
||||
if req.InAppEnabled != nil {
|
||||
prefs.InAppEnabled = *req.InAppEnabled
|
||||
}
|
||||
if req.ReminderFrequency != nil {
|
||||
prefs.ReminderFrequency = *req.ReminderFrequency
|
||||
}
|
||||
|
||||
if err := h.notificationService.UpdatePreferences(c.Request.Context(), userID, prefs); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update preferences"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Preferences updated", "preferences": prefs})
|
||||
}
|
||||
@@ -0,0 +1,743 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/middleware"
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
"github.com/breakpilot/consent-service/internal/services"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// OAuthHandler handles OAuth 2.0 endpoints
|
||||
type OAuthHandler struct {
|
||||
oauthService *services.OAuthService
|
||||
totpService *services.TOTPService
|
||||
authService *services.AuthService
|
||||
}
|
||||
|
||||
// NewOAuthHandler creates a new OAuthHandler
|
||||
func NewOAuthHandler(oauthService *services.OAuthService, totpService *services.TOTPService, authService *services.AuthService) *OAuthHandler {
|
||||
return &OAuthHandler{
|
||||
oauthService: oauthService,
|
||||
totpService: totpService,
|
||||
authService: authService,
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// OAuth 2.0 Authorization Code Flow
|
||||
// ========================================
|
||||
|
||||
// Authorize handles the OAuth 2.0 authorization request
|
||||
// GET /oauth/authorize
|
||||
func (h *OAuthHandler) Authorize(c *gin.Context) {
|
||||
responseType := c.Query("response_type")
|
||||
clientID := c.Query("client_id")
|
||||
redirectURI := c.Query("redirect_uri")
|
||||
scope := c.Query("scope")
|
||||
state := c.Query("state")
|
||||
codeChallenge := c.Query("code_challenge")
|
||||
codeChallengeMethod := c.Query("code_challenge_method")
|
||||
|
||||
// Validate response_type
|
||||
if responseType != "code" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "unsupported_response_type",
|
||||
"error_description": "Only 'code' response_type is supported",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate client
|
||||
ctx := context.Background()
|
||||
client, err := h.oauthService.ValidateClient(ctx, clientID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_client",
|
||||
"error_description": "Unknown or invalid client_id",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate redirect_uri
|
||||
if err := h.oauthService.ValidateRedirectURI(client, redirectURI); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Invalid redirect_uri",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate scopes
|
||||
scopes, err := h.oauthService.ValidateScopes(client, scope)
|
||||
if err != nil {
|
||||
redirectWithError(c, redirectURI, "invalid_scope", "One or more requested scopes are invalid", state)
|
||||
return
|
||||
}
|
||||
|
||||
// For public clients, PKCE is required
|
||||
if client.IsPublic && codeChallenge == "" {
|
||||
redirectWithError(c, redirectURI, "invalid_request", "PKCE code_challenge is required for public clients", state)
|
||||
return
|
||||
}
|
||||
|
||||
// Get authenticated user
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID == uuid.Nil {
|
||||
// User not authenticated - redirect to login
|
||||
// Store authorization request in session and redirect to login
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "login_required",
|
||||
"error_description": "User must be authenticated to authorize",
|
||||
"login_url": "/auth/login",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Generate authorization code
|
||||
code, err := h.oauthService.GenerateAuthorizationCode(
|
||||
ctx, client, userID, redirectURI, scopes, codeChallenge, codeChallengeMethod,
|
||||
)
|
||||
if err != nil {
|
||||
redirectWithError(c, redirectURI, "server_error", "Failed to generate authorization code", state)
|
||||
return
|
||||
}
|
||||
|
||||
// Redirect with code
|
||||
redirectURL := redirectURI + "?code=" + code
|
||||
if state != "" {
|
||||
redirectURL += "&state=" + state
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusFound, redirectURL)
|
||||
}
|
||||
|
||||
// Token handles the OAuth 2.0 token request
|
||||
// POST /oauth/token
|
||||
func (h *OAuthHandler) Token(c *gin.Context) {
|
||||
grantType := c.PostForm("grant_type")
|
||||
|
||||
switch grantType {
|
||||
case "authorization_code":
|
||||
h.tokenAuthorizationCode(c)
|
||||
case "refresh_token":
|
||||
h.tokenRefreshToken(c)
|
||||
default:
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "unsupported_grant_type",
|
||||
"error_description": "Only 'authorization_code' and 'refresh_token' grant types are supported",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// tokenAuthorizationCode handles the authorization_code grant
|
||||
func (h *OAuthHandler) tokenAuthorizationCode(c *gin.Context) {
|
||||
code := c.PostForm("code")
|
||||
clientID := c.PostForm("client_id")
|
||||
redirectURI := c.PostForm("redirect_uri")
|
||||
codeVerifier := c.PostForm("code_verifier")
|
||||
|
||||
if code == "" || clientID == "" || redirectURI == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Missing required parameters: code, client_id, redirect_uri",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate client
|
||||
ctx := context.Background()
|
||||
client, err := h.oauthService.ValidateClient(ctx, clientID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_client",
|
||||
"error_description": "Unknown or invalid client_id",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// For confidential clients, validate client_secret
|
||||
if !client.IsPublic {
|
||||
clientSecret := c.PostForm("client_secret")
|
||||
if err := h.oauthService.ValidateClientSecret(client, clientSecret); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "invalid_client",
|
||||
"error_description": "Invalid client credentials",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Exchange authorization code for tokens
|
||||
tokenResponse, err := h.oauthService.ExchangeAuthorizationCode(ctx, code, clientID, redirectURI, codeVerifier)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case services.ErrCodeExpired:
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_grant",
|
||||
"error_description": "Authorization code has expired",
|
||||
})
|
||||
case services.ErrCodeUsed:
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_grant",
|
||||
"error_description": "Authorization code has already been used",
|
||||
})
|
||||
case services.ErrPKCEVerifyFailed:
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_grant",
|
||||
"error_description": "PKCE verification failed",
|
||||
})
|
||||
default:
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_grant",
|
||||
"error_description": "Invalid authorization code",
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, tokenResponse)
|
||||
}
|
||||
|
||||
// tokenRefreshToken handles the refresh_token grant
|
||||
func (h *OAuthHandler) tokenRefreshToken(c *gin.Context) {
|
||||
refreshToken := c.PostForm("refresh_token")
|
||||
clientID := c.PostForm("client_id")
|
||||
scope := c.PostForm("scope")
|
||||
|
||||
if refreshToken == "" || clientID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Missing required parameters: refresh_token, client_id",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Refresh access token
|
||||
tokenResponse, err := h.oauthService.RefreshAccessToken(ctx, refreshToken, clientID, scope)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case services.ErrInvalidScope:
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_scope",
|
||||
"error_description": "Requested scope exceeds original grant",
|
||||
})
|
||||
default:
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_grant",
|
||||
"error_description": "Invalid or expired refresh token",
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, tokenResponse)
|
||||
}
|
||||
|
||||
// Revoke handles token revocation
|
||||
// POST /oauth/revoke
|
||||
func (h *OAuthHandler) Revoke(c *gin.Context) {
|
||||
token := c.PostForm("token")
|
||||
tokenTypeHint := c.PostForm("token_type_hint")
|
||||
|
||||
if token == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Missing token parameter",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_ = h.oauthService.RevokeToken(ctx, token, tokenTypeHint)
|
||||
|
||||
// RFC 7009: Always return 200 OK
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
// Introspect handles token introspection (for resource servers)
|
||||
// POST /oauth/introspect
|
||||
func (h *OAuthHandler) Introspect(c *gin.Context) {
|
||||
token := c.PostForm("token")
|
||||
|
||||
if token == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Missing token parameter",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
claims, err := h.oauthService.ValidateAccessToken(ctx, token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"active": false})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"active": true,
|
||||
"sub": (*claims)["sub"],
|
||||
"client_id": (*claims)["client_id"],
|
||||
"scope": (*claims)["scope"],
|
||||
"exp": (*claims)["exp"],
|
||||
"iat": (*claims)["iat"],
|
||||
"iss": (*claims)["iss"],
|
||||
})
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// 2FA (TOTP) Endpoints
|
||||
// ========================================
|
||||
|
||||
// Setup2FA initiates 2FA setup
|
||||
// POST /auth/2fa/setup
|
||||
func (h *OAuthHandler) Setup2FA(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID == uuid.Nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user email
|
||||
ctx := context.Background()
|
||||
user, err := h.authService.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"})
|
||||
return
|
||||
}
|
||||
|
||||
// Setup 2FA
|
||||
response, err := h.totpService.Setup2FA(ctx, userID, user.Email)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case services.ErrTOTPAlreadyEnabled:
|
||||
c.JSON(http.StatusConflict, gin.H{"error": "2FA is already enabled for this account"})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to setup 2FA"})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// Verify2FASetup verifies the 2FA setup with a code
|
||||
// POST /auth/2fa/verify-setup
|
||||
func (h *OAuthHandler) Verify2FASetup(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID == uuid.Nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"})
|
||||
return
|
||||
}
|
||||
|
||||
var req models.Verify2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
err = h.totpService.Verify2FASetup(ctx, userID, req.Code)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case services.ErrTOTPAlreadyEnabled:
|
||||
c.JSON(http.StatusConflict, gin.H{"error": "2FA is already enabled"})
|
||||
case services.ErrTOTPInvalidCode:
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid 2FA code"})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to verify 2FA setup"})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "2FA enabled successfully"})
|
||||
}
|
||||
|
||||
// Verify2FAChallenge verifies a 2FA challenge during login
|
||||
// POST /auth/2fa/verify
|
||||
func (h *OAuthHandler) Verify2FAChallenge(c *gin.Context) {
|
||||
var req models.Verify2FAChallengeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
var userID *uuid.UUID
|
||||
var err error
|
||||
|
||||
if req.RecoveryCode != "" {
|
||||
// Verify with recovery code
|
||||
userID, err = h.totpService.VerifyChallengeWithRecoveryCode(ctx, req.ChallengeID, req.RecoveryCode)
|
||||
} else {
|
||||
// Verify with TOTP code
|
||||
userID, err = h.totpService.VerifyChallenge(ctx, req.ChallengeID, req.Code)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
switch err {
|
||||
case services.ErrTOTPChallengeExpired:
|
||||
c.JSON(http.StatusGone, gin.H{"error": "2FA challenge has expired"})
|
||||
case services.ErrTOTPInvalidCode:
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid 2FA code"})
|
||||
case services.ErrRecoveryCodeInvalid:
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid recovery code"})
|
||||
default:
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "2FA verification failed"})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Get user and generate tokens
|
||||
user, err := h.authService.GetUserByID(ctx, *userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"})
|
||||
return
|
||||
}
|
||||
|
||||
// Generate access token
|
||||
accessToken, err := h.authService.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
|
||||
return
|
||||
}
|
||||
|
||||
// Generate refresh token
|
||||
refreshToken, refreshTokenHash, err := h.authService.GenerateRefreshToken()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate refresh token"})
|
||||
return
|
||||
}
|
||||
|
||||
// Store session
|
||||
ipAddress := middleware.GetClientIP(c)
|
||||
userAgent := middleware.GetUserAgent(c)
|
||||
|
||||
// We need direct DB access for this, or we need to add a method to AuthService
|
||||
// For now, we'll return the tokens and let the caller handle session storage
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"access_token": accessToken,
|
||||
"refresh_token": refreshToken,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
"user": map[string]interface{}{
|
||||
"id": user.ID,
|
||||
"email": user.Email,
|
||||
"name": user.Name,
|
||||
"role": user.Role,
|
||||
},
|
||||
"_session_hash": refreshTokenHash,
|
||||
"_ip": ipAddress,
|
||||
"_user_agent": userAgent,
|
||||
})
|
||||
}
|
||||
|
||||
// Disable2FA disables 2FA for the current user
|
||||
// POST /auth/2fa/disable
|
||||
func (h *OAuthHandler) Disable2FA(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID == uuid.Nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"})
|
||||
return
|
||||
}
|
||||
|
||||
var req models.Verify2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
err = h.totpService.Disable2FA(ctx, userID, req.Code)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case services.ErrTOTPNotEnabled:
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "2FA is not enabled"})
|
||||
case services.ErrTOTPInvalidCode:
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid 2FA code"})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to disable 2FA"})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "2FA disabled successfully"})
|
||||
}
|
||||
|
||||
// Get2FAStatus returns the 2FA status for the current user
|
||||
// GET /auth/2fa/status
|
||||
func (h *OAuthHandler) Get2FAStatus(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID == uuid.Nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
status, err := h.totpService.GetStatus(ctx, userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get 2FA status"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, status)
|
||||
}
|
||||
|
||||
// RegenerateRecoveryCodes generates new recovery codes
|
||||
// POST /auth/2fa/recovery-codes
|
||||
func (h *OAuthHandler) RegenerateRecoveryCodes(c *gin.Context) {
|
||||
userID, err := middleware.GetUserID(c)
|
||||
if err != nil || userID == uuid.Nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"})
|
||||
return
|
||||
}
|
||||
|
||||
var req models.Verify2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
codes, err := h.totpService.RegenerateRecoveryCodes(ctx, userID, req.Code)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case services.ErrTOTPNotEnabled:
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "2FA is not enabled"})
|
||||
case services.ErrTOTPInvalidCode:
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid 2FA code"})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to regenerate recovery codes"})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"recovery_codes": codes})
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Enhanced Login with 2FA
|
||||
// ========================================
|
||||
|
||||
// LoginWith2FA handles login with optional 2FA
|
||||
// POST /auth/login
|
||||
func (h *OAuthHandler) LoginWith2FA(c *gin.Context) {
|
||||
var req models.LoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
ipAddress := middleware.GetClientIP(c)
|
||||
userAgent := middleware.GetUserAgent(c)
|
||||
|
||||
// Attempt login
|
||||
response, err := h.authService.Login(ctx, &req, ipAddress, userAgent)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case services.ErrInvalidCredentials:
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid email or password"})
|
||||
case services.ErrAccountLocked:
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Account is temporarily locked"})
|
||||
case services.ErrAccountSuspended:
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Account is suspended"})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Login failed"})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Check if 2FA is enabled
|
||||
twoFactorEnabled, _ := h.totpService.IsTwoFactorEnabled(ctx, response.User.ID)
|
||||
|
||||
if twoFactorEnabled {
|
||||
// Create 2FA challenge
|
||||
challengeID, err := h.totpService.CreateChallenge(ctx, response.User.ID, ipAddress, userAgent)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create 2FA challenge"})
|
||||
return
|
||||
}
|
||||
|
||||
// Return 2FA required response
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"requires_2fa": true,
|
||||
"challenge_id": challengeID,
|
||||
"message": "2FA verification required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// No 2FA required, return tokens
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"requires_2fa": false,
|
||||
"access_token": response.AccessToken,
|
||||
"refresh_token": response.RefreshToken,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": response.ExpiresIn,
|
||||
"user": map[string]interface{}{
|
||||
"id": response.User.ID,
|
||||
"email": response.User.Email,
|
||||
"name": response.User.Name,
|
||||
"role": response.User.Role,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Registration with mandatory 2FA setup
|
||||
// ========================================
|
||||
|
||||
// RegisterWith2FA handles registration with mandatory 2FA setup
|
||||
// POST /auth/register
|
||||
func (h *OAuthHandler) RegisterWith2FA(c *gin.Context) {
|
||||
var req models.RegisterRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Validate password strength
|
||||
if len(req.Password) < 8 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Password must be at least 8 characters"})
|
||||
return
|
||||
}
|
||||
|
||||
// Register user
|
||||
user, verificationToken, err := h.authService.Register(ctx, &req)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case services.ErrUserExists:
|
||||
c.JSON(http.StatusConflict, gin.H{"error": "A user with this email already exists"})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Registration failed"})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Setup 2FA immediately
|
||||
twoFAResponse, err := h.totpService.Setup2FA(ctx, user.ID, user.Email)
|
||||
if err != nil {
|
||||
// Non-fatal - user can set up 2FA later, but log it
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"message": "Registration successful. Please verify your email.",
|
||||
"user_id": user.ID,
|
||||
"verification_token": verificationToken, // In production, this would be sent via email
|
||||
"two_factor_setup": nil,
|
||||
"two_factor_error": "Failed to initialize 2FA. Please set it up in your account settings.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"message": "Registration successful. Please verify your email and complete 2FA setup.",
|
||||
"user_id": user.ID,
|
||||
"verification_token": verificationToken, // In production, this would be sent via email
|
||||
"two_factor_setup": map[string]interface{}{
|
||||
"secret": twoFAResponse.Secret,
|
||||
"qr_code": twoFAResponse.QRCodeDataURL,
|
||||
"recovery_codes": twoFAResponse.RecoveryCodes,
|
||||
"setup_required": true,
|
||||
"setup_endpoint": "/auth/2fa/verify-setup",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// OAuth Client Management (Admin)
|
||||
// ========================================
|
||||
|
||||
// AdminCreateClient creates a new OAuth client
|
||||
// POST /admin/oauth/clients
|
||||
func (h *OAuthHandler) AdminCreateClient(c *gin.Context) {
|
||||
var req struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
RedirectURIs []string `json:"redirect_uris" binding:"required"`
|
||||
Scopes []string `json:"scopes"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
IsPublic bool `json:"is_public"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
userID, _ := middleware.GetUserID(c)
|
||||
|
||||
// Default scopes
|
||||
if len(req.Scopes) == 0 {
|
||||
req.Scopes = []string{"openid", "profile", "email"}
|
||||
}
|
||||
|
||||
// Default grant types
|
||||
if len(req.GrantTypes) == 0 {
|
||||
req.GrantTypes = []string{"authorization_code", "refresh_token"}
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
client, clientSecret, err := h.oauthService.CreateClient(
|
||||
ctx, req.Name, req.Description, req.RedirectURIs, req.Scopes, req.GrantTypes, req.IsPublic, &userID,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create client"})
|
||||
return
|
||||
}
|
||||
|
||||
response := gin.H{
|
||||
"client_id": client.ClientID,
|
||||
"name": client.Name,
|
||||
"redirect_uris": client.RedirectURIs,
|
||||
"scopes": client.Scopes,
|
||||
"grant_types": client.GrantTypes,
|
||||
"is_public": client.IsPublic,
|
||||
}
|
||||
|
||||
// Only show client_secret once for confidential clients
|
||||
if !client.IsPublic && clientSecret != "" {
|
||||
response["client_secret"] = clientSecret
|
||||
response["client_secret_warning"] = "Store this secret securely. It will not be shown again."
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, response)
|
||||
}
|
||||
|
||||
// AdminGetClients lists all OAuth clients
|
||||
// GET /admin/oauth/clients
|
||||
func (h *OAuthHandler) AdminGetClients(c *gin.Context) {
|
||||
// This would need a new method in OAuthService
|
||||
// For now, return a placeholder
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"clients": []interface{}{},
|
||||
"message": "Client listing not yet implemented",
|
||||
})
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Helper Functions
|
||||
// ========================================
|
||||
|
||||
func redirectWithError(c *gin.Context, redirectURI, errorCode, errorDescription, state string) {
|
||||
separator := "?"
|
||||
if strings.Contains(redirectURI, "?") {
|
||||
separator = "&"
|
||||
}
|
||||
|
||||
redirectURL := redirectURI + separator + "error=" + errorCode + "&error_description=" + errorDescription
|
||||
if state != "" {
|
||||
redirectURL += "&state=" + state
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusFound, redirectURL)
|
||||
}
|
||||
@@ -0,0 +1,933 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
"github.com/breakpilot/consent-service/internal/services"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// SchoolHandlers contains all school-related HTTP handlers
|
||||
type SchoolHandlers struct {
|
||||
schoolService *services.SchoolService
|
||||
attendanceService *services.AttendanceService
|
||||
gradeService *services.GradeService
|
||||
}
|
||||
|
||||
// NewSchoolHandlers creates new school handlers
|
||||
func NewSchoolHandlers(schoolService *services.SchoolService, attendanceService *services.AttendanceService, gradeService *services.GradeService) *SchoolHandlers {
|
||||
return &SchoolHandlers{
|
||||
schoolService: schoolService,
|
||||
attendanceService: attendanceService,
|
||||
gradeService: gradeService,
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// School Handlers
|
||||
// ========================================
|
||||
|
||||
// CreateSchool creates a new school
|
||||
// POST /api/v1/schools
|
||||
func (h *SchoolHandlers) CreateSchool(c *gin.Context) {
|
||||
var req models.CreateSchoolRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
school, err := h.schoolService.CreateSchool(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, school)
|
||||
}
|
||||
|
||||
// GetSchool retrieves a school by ID
|
||||
// GET /api/v1/schools/:id
|
||||
func (h *SchoolHandlers) GetSchool(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school ID"})
|
||||
return
|
||||
}
|
||||
|
||||
school, err := h.schoolService.GetSchool(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, school)
|
||||
}
|
||||
|
||||
// ListSchools lists all schools
|
||||
// GET /api/v1/schools
|
||||
func (h *SchoolHandlers) ListSchools(c *gin.Context) {
|
||||
schools, err := h.schoolService.ListSchools(c.Request.Context())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, schools)
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// School Year Handlers
|
||||
// ========================================
|
||||
|
||||
// CreateSchoolYear creates a new school year
|
||||
// POST /api/v1/schools/:id/years
|
||||
func (h *SchoolHandlers) CreateSchoolYear(c *gin.Context) {
|
||||
schoolIDStr := c.Param("id")
|
||||
schoolID, err := uuid.Parse(schoolIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
StartDate string `json:"start_date" binding:"required"`
|
||||
EndDate string `json:"end_date" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
startDate, err := time.Parse("2006-01-02", req.StartDate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid start date format"})
|
||||
return
|
||||
}
|
||||
|
||||
endDate, err := time.Parse("2006-01-02", req.EndDate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid end date format"})
|
||||
return
|
||||
}
|
||||
|
||||
schoolYear, err := h.schoolService.CreateSchoolYear(c.Request.Context(), schoolID, req.Name, startDate, endDate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, schoolYear)
|
||||
}
|
||||
|
||||
// SetCurrentSchoolYear sets a school year as current
|
||||
// PUT /api/v1/schools/:id/years/:yearId/current
|
||||
func (h *SchoolHandlers) SetCurrentSchoolYear(c *gin.Context) {
|
||||
schoolIDStr := c.Param("id")
|
||||
schoolID, err := uuid.Parse(schoolIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school ID"})
|
||||
return
|
||||
}
|
||||
|
||||
yearIDStr := c.Param("yearId")
|
||||
yearID, err := uuid.Parse(yearIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school year ID"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.schoolService.SetCurrentSchoolYear(c.Request.Context(), schoolID, yearID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "school year set as current"})
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Class Handlers
|
||||
// ========================================
|
||||
|
||||
// CreateClass creates a new class
|
||||
// POST /api/v1/schools/:id/classes
|
||||
func (h *SchoolHandlers) CreateClass(c *gin.Context) {
|
||||
schoolIDStr := c.Param("id")
|
||||
schoolID, err := uuid.Parse(schoolIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var req models.CreateClassRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
class, err := h.schoolService.CreateClass(c.Request.Context(), schoolID, req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, class)
|
||||
}
|
||||
|
||||
// GetClass retrieves a class by ID
|
||||
// GET /api/v1/classes/:id
|
||||
func (h *SchoolHandlers) GetClass(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid class ID"})
|
||||
return
|
||||
}
|
||||
|
||||
class, err := h.schoolService.GetClass(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, class)
|
||||
}
|
||||
|
||||
// ListClasses lists all classes for a school in a school year
|
||||
// GET /api/v1/schools/:id/classes?school_year_id=...
|
||||
func (h *SchoolHandlers) ListClasses(c *gin.Context) {
|
||||
schoolIDStr := c.Param("id")
|
||||
schoolID, err := uuid.Parse(schoolIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school ID"})
|
||||
return
|
||||
}
|
||||
|
||||
schoolYearIDStr := c.Query("school_year_id")
|
||||
if schoolYearIDStr == "" {
|
||||
// Get current school year
|
||||
schoolYear, err := h.schoolService.GetCurrentSchoolYear(c.Request.Context(), schoolID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "no current school year set"})
|
||||
return
|
||||
}
|
||||
schoolYearIDStr = schoolYear.ID.String()
|
||||
}
|
||||
|
||||
schoolYearID, err := uuid.Parse(schoolYearIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school year ID"})
|
||||
return
|
||||
}
|
||||
|
||||
classes, err := h.schoolService.ListClasses(c.Request.Context(), schoolID, schoolYearID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, classes)
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Student Handlers
|
||||
// ========================================
|
||||
|
||||
// CreateStudent creates a new student
|
||||
// POST /api/v1/schools/:id/students
|
||||
func (h *SchoolHandlers) CreateStudent(c *gin.Context) {
|
||||
schoolIDStr := c.Param("id")
|
||||
schoolID, err := uuid.Parse(schoolIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var req models.CreateStudentRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
student, err := h.schoolService.CreateStudent(c.Request.Context(), schoolID, req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, student)
|
||||
}
|
||||
|
||||
// GetStudent retrieves a student by ID
|
||||
// GET /api/v1/students/:id
|
||||
func (h *SchoolHandlers) GetStudent(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid student ID"})
|
||||
return
|
||||
}
|
||||
|
||||
student, err := h.schoolService.GetStudent(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, student)
|
||||
}
|
||||
|
||||
// ListStudentsByClass lists all students in a class
|
||||
// GET /api/v1/classes/:id/students
|
||||
func (h *SchoolHandlers) ListStudentsByClass(c *gin.Context) {
|
||||
classIDStr := c.Param("id")
|
||||
classID, err := uuid.Parse(classIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid class ID"})
|
||||
return
|
||||
}
|
||||
|
||||
students, err := h.schoolService.ListStudentsByClass(c.Request.Context(), classID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, students)
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Subject Handlers
|
||||
// ========================================
|
||||
|
||||
// CreateSubject creates a new subject
|
||||
// POST /api/v1/schools/:id/subjects
|
||||
func (h *SchoolHandlers) CreateSubject(c *gin.Context) {
|
||||
schoolIDStr := c.Param("id")
|
||||
schoolID, err := uuid.Parse(schoolIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
ShortName string `json:"short_name" binding:"required"`
|
||||
Color *string `json:"color"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
subject, err := h.schoolService.CreateSubject(c.Request.Context(), schoolID, req.Name, req.ShortName, req.Color)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, subject)
|
||||
}
|
||||
|
||||
// ListSubjects lists all subjects for a school
|
||||
// GET /api/v1/schools/:id/subjects
|
||||
func (h *SchoolHandlers) ListSubjects(c *gin.Context) {
|
||||
schoolIDStr := c.Param("id")
|
||||
schoolID, err := uuid.Parse(schoolIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school ID"})
|
||||
return
|
||||
}
|
||||
|
||||
subjects, err := h.schoolService.ListSubjects(c.Request.Context(), schoolID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, subjects)
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Attendance Handlers
|
||||
// ========================================
|
||||
|
||||
// RecordAttendance records attendance for a student
|
||||
// POST /api/v1/attendance
|
||||
func (h *SchoolHandlers) RecordAttendance(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "user not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var req models.RecordAttendanceRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
record, err := h.attendanceService.RecordAttendance(c.Request.Context(), req, userID.(uuid.UUID))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, record)
|
||||
}
|
||||
|
||||
// RecordBulkAttendance records attendance for multiple students
|
||||
// POST /api/v1/classes/:id/attendance
|
||||
func (h *SchoolHandlers) RecordBulkAttendance(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "user not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
classIDStr := c.Param("id")
|
||||
classID, err := uuid.Parse(classIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid class ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Date string `json:"date" binding:"required"`
|
||||
SlotID string `json:"slot_id" binding:"required"`
|
||||
Records []struct {
|
||||
StudentID string `json:"student_id"`
|
||||
Status string `json:"status"`
|
||||
Note *string `json:"note"`
|
||||
} `json:"records" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
slotID, err := uuid.Parse(req.SlotID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid slot ID"})
|
||||
return
|
||||
}
|
||||
|
||||
// Convert to the expected type (without JSON tags)
|
||||
records := make([]struct {
|
||||
StudentID string
|
||||
Status string
|
||||
Note *string
|
||||
}, len(req.Records))
|
||||
for i, r := range req.Records {
|
||||
records[i] = struct {
|
||||
StudentID string
|
||||
Status string
|
||||
Note *string
|
||||
}{
|
||||
StudentID: r.StudentID,
|
||||
Status: r.Status,
|
||||
Note: r.Note,
|
||||
}
|
||||
}
|
||||
|
||||
err = h.attendanceService.RecordBulkAttendance(c.Request.Context(), classID, req.Date, slotID, records, userID.(uuid.UUID))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "attendance recorded"})
|
||||
}
|
||||
|
||||
// GetClassAttendance gets attendance for a class on a specific date
|
||||
// GET /api/v1/classes/:id/attendance?date=...
|
||||
func (h *SchoolHandlers) GetClassAttendance(c *gin.Context) {
|
||||
classIDStr := c.Param("id")
|
||||
classID, err := uuid.Parse(classIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid class ID"})
|
||||
return
|
||||
}
|
||||
|
||||
date := c.Query("date")
|
||||
if date == "" {
|
||||
date = time.Now().Format("2006-01-02")
|
||||
}
|
||||
|
||||
overview, err := h.attendanceService.GetAttendanceByClass(c.Request.Context(), classID, date)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, overview)
|
||||
}
|
||||
|
||||
// GetStudentAttendance gets attendance history for a student
|
||||
// GET /api/v1/students/:id/attendance?start_date=...&end_date=...
|
||||
func (h *SchoolHandlers) GetStudentAttendance(c *gin.Context) {
|
||||
studentIDStr := c.Param("id")
|
||||
studentID, err := uuid.Parse(studentIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid student ID"})
|
||||
return
|
||||
}
|
||||
|
||||
startDateStr := c.Query("start_date")
|
||||
endDateStr := c.Query("end_date")
|
||||
|
||||
var startDate, endDate time.Time
|
||||
if startDateStr == "" {
|
||||
startDate = time.Now().AddDate(0, -1, 0) // Last month
|
||||
} else {
|
||||
startDate, _ = time.Parse("2006-01-02", startDateStr)
|
||||
}
|
||||
|
||||
if endDateStr == "" {
|
||||
endDate = time.Now()
|
||||
} else {
|
||||
endDate, _ = time.Parse("2006-01-02", endDateStr)
|
||||
}
|
||||
|
||||
records, err := h.attendanceService.GetStudentAttendance(c.Request.Context(), studentID, startDate, endDate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, records)
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Absence Report Handlers
|
||||
// ========================================
|
||||
|
||||
// ReportAbsence allows parents to report absence
|
||||
// POST /api/v1/absence/report
|
||||
func (h *SchoolHandlers) ReportAbsence(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "user not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var req models.ReportAbsenceRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
report, err := h.attendanceService.ReportAbsence(c.Request.Context(), req, userID.(uuid.UUID))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, report)
|
||||
}
|
||||
|
||||
// ConfirmAbsence allows teachers to confirm absence
|
||||
// PUT /api/v1/absence/:id/confirm
|
||||
func (h *SchoolHandlers) ConfirmAbsence(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "user not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
reportIDStr := c.Param("id")
|
||||
reportID, err := uuid.Parse(reportIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid report ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Status string `json:"status" binding:"required"` // "excused" or "unexcused"
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
err = h.attendanceService.ConfirmAbsence(c.Request.Context(), reportID, userID.(uuid.UUID), req.Status)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "absence confirmed"})
|
||||
}
|
||||
|
||||
// GetPendingAbsenceReports gets pending absence reports for a class
|
||||
// GET /api/v1/classes/:id/absence/pending
|
||||
func (h *SchoolHandlers) GetPendingAbsenceReports(c *gin.Context) {
|
||||
classIDStr := c.Param("id")
|
||||
classID, err := uuid.Parse(classIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid class ID"})
|
||||
return
|
||||
}
|
||||
|
||||
reports, err := h.attendanceService.GetPendingAbsenceReports(c.Request.Context(), classID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, reports)
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Grade Handlers
|
||||
// ========================================
|
||||
|
||||
// CreateGrade creates a new grade
|
||||
// POST /api/v1/grades
|
||||
func (h *SchoolHandlers) CreateGrade(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "user not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var req models.CreateGradeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Get teacher ID from user ID
|
||||
teacher, err := h.schoolService.GetTeacherByUserID(c.Request.Context(), userID.(uuid.UUID))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "user is not a teacher"})
|
||||
return
|
||||
}
|
||||
|
||||
grade, err := h.gradeService.CreateGrade(c.Request.Context(), req, teacher.ID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, grade)
|
||||
}
|
||||
|
||||
// GetStudentGrades gets all grades for a student
|
||||
// GET /api/v1/students/:id/grades?school_year_id=...
|
||||
func (h *SchoolHandlers) GetStudentGrades(c *gin.Context) {
|
||||
studentIDStr := c.Param("id")
|
||||
studentID, err := uuid.Parse(studentIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid student ID"})
|
||||
return
|
||||
}
|
||||
|
||||
schoolYearIDStr := c.Query("school_year_id")
|
||||
if schoolYearIDStr == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "school_year_id is required"})
|
||||
return
|
||||
}
|
||||
|
||||
schoolYearID, err := uuid.Parse(schoolYearIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school year ID"})
|
||||
return
|
||||
}
|
||||
|
||||
grades, err := h.gradeService.GetStudentGrades(c.Request.Context(), studentID, schoolYearID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, grades)
|
||||
}
|
||||
|
||||
// GetClassGrades gets grades for all students in a class for a subject (Notenspiegel)
|
||||
// GET /api/v1/classes/:id/grades/:subjectId?school_year_id=...&semester=...
|
||||
func (h *SchoolHandlers) GetClassGrades(c *gin.Context) {
|
||||
classIDStr := c.Param("id")
|
||||
classID, err := uuid.Parse(classIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid class ID"})
|
||||
return
|
||||
}
|
||||
|
||||
subjectIDStr := c.Param("subjectId")
|
||||
subjectID, err := uuid.Parse(subjectIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid subject ID"})
|
||||
return
|
||||
}
|
||||
|
||||
schoolYearIDStr := c.Query("school_year_id")
|
||||
if schoolYearIDStr == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "school_year_id is required"})
|
||||
return
|
||||
}
|
||||
|
||||
schoolYearID, err := uuid.Parse(schoolYearIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school year ID"})
|
||||
return
|
||||
}
|
||||
|
||||
semesterStr := c.DefaultQuery("semester", "1")
|
||||
var semester int
|
||||
if semesterStr == "1" {
|
||||
semester = 1
|
||||
} else {
|
||||
semester = 2
|
||||
}
|
||||
|
||||
overviews, err := h.gradeService.GetClassGradesBySubject(c.Request.Context(), classID, subjectID, schoolYearID, semester)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, overviews)
|
||||
}
|
||||
|
||||
// GetGradeStatistics gets grade statistics for a class/subject
|
||||
// GET /api/v1/classes/:id/grades/:subjectId/stats?school_year_id=...&semester=...
|
||||
func (h *SchoolHandlers) GetGradeStatistics(c *gin.Context) {
|
||||
classIDStr := c.Param("id")
|
||||
classID, err := uuid.Parse(classIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid class ID"})
|
||||
return
|
||||
}
|
||||
|
||||
subjectIDStr := c.Param("subjectId")
|
||||
subjectID, err := uuid.Parse(subjectIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid subject ID"})
|
||||
return
|
||||
}
|
||||
|
||||
schoolYearIDStr := c.Query("school_year_id")
|
||||
if schoolYearIDStr == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "school_year_id is required"})
|
||||
return
|
||||
}
|
||||
|
||||
schoolYearID, err := uuid.Parse(schoolYearIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school year ID"})
|
||||
return
|
||||
}
|
||||
|
||||
semesterStr := c.DefaultQuery("semester", "1")
|
||||
var semester int
|
||||
if semesterStr == "1" {
|
||||
semester = 1
|
||||
} else {
|
||||
semester = 2
|
||||
}
|
||||
|
||||
stats, err := h.gradeService.GetSubjectGradeStatistics(c.Request.Context(), classID, subjectID, schoolYearID, semester)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Parent Onboarding Handlers
|
||||
// ========================================
|
||||
|
||||
// GenerateOnboardingToken generates a QR code token for parent onboarding
|
||||
// POST /api/v1/onboarding/tokens
|
||||
func (h *SchoolHandlers) GenerateOnboardingToken(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "user not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
SchoolID string `json:"school_id" binding:"required"`
|
||||
ClassID string `json:"class_id" binding:"required"`
|
||||
StudentID string `json:"student_id" binding:"required"`
|
||||
Role string `json:"role"` // "parent" or "parent_representative"
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
schoolID, err := uuid.Parse(req.SchoolID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid school ID"})
|
||||
return
|
||||
}
|
||||
|
||||
classID, err := uuid.Parse(req.ClassID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid class ID"})
|
||||
return
|
||||
}
|
||||
|
||||
studentID, err := uuid.Parse(req.StudentID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid student ID"})
|
||||
return
|
||||
}
|
||||
|
||||
role := req.Role
|
||||
if role == "" {
|
||||
role = "parent"
|
||||
}
|
||||
|
||||
token, err := h.schoolService.GenerateParentOnboardingToken(c.Request.Context(), schoolID, classID, studentID, userID.(uuid.UUID), role)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Generate QR code URL
|
||||
qrURL := "/onboard-parent?token=" + token.Token
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"token": token.Token,
|
||||
"qr_url": qrURL,
|
||||
"expires_at": token.ExpiresAt,
|
||||
})
|
||||
}
|
||||
|
||||
// ValidateOnboardingToken validates an onboarding token
|
||||
// GET /api/v1/onboarding/validate?token=...
|
||||
func (h *SchoolHandlers) ValidateOnboardingToken(c *gin.Context) {
|
||||
token := c.Query("token")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "token is required"})
|
||||
return
|
||||
}
|
||||
|
||||
onboardingToken, err := h.schoolService.ValidateOnboardingToken(c.Request.Context(), token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "invalid or expired token"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get student and school info
|
||||
student, err := h.schoolService.GetStudent(c.Request.Context(), onboardingToken.StudentID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
class, err := h.schoolService.GetClass(c.Request.Context(), onboardingToken.ClassID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
school, err := h.schoolService.GetSchool(c.Request.Context(), onboardingToken.SchoolID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"valid": true,
|
||||
"role": onboardingToken.Role,
|
||||
"student_name": student.FirstName + " " + student.LastName,
|
||||
"class_name": class.Name,
|
||||
"school_name": school.Name,
|
||||
"expires_at": onboardingToken.ExpiresAt,
|
||||
})
|
||||
}
|
||||
|
||||
// RedeemOnboardingToken redeems a token and creates parent account
|
||||
// POST /api/v1/onboarding/redeem
|
||||
func (h *SchoolHandlers) RedeemOnboardingToken(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "user not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Token string `json:"token" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
err := h.schoolService.RedeemOnboardingToken(c.Request.Context(), req.Token, userID.(uuid.UUID))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "token redeemed successfully"})
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Register Routes
|
||||
// ========================================
|
||||
|
||||
// RegisterRoutes registers all school-related routes
|
||||
func (h *SchoolHandlers) RegisterRoutes(r *gin.RouterGroup, authMiddleware gin.HandlerFunc) {
|
||||
// Public routes (for onboarding)
|
||||
r.GET("/onboarding/validate", h.ValidateOnboardingToken)
|
||||
|
||||
// Protected routes
|
||||
protected := r.Group("")
|
||||
protected.Use(authMiddleware)
|
||||
|
||||
// Schools
|
||||
protected.POST("/schools", h.CreateSchool)
|
||||
protected.GET("/schools", h.ListSchools)
|
||||
protected.GET("/schools/:id", h.GetSchool)
|
||||
protected.POST("/schools/:id/years", h.CreateSchoolYear)
|
||||
protected.PUT("/schools/:id/years/:yearId/current", h.SetCurrentSchoolYear)
|
||||
protected.POST("/schools/:id/classes", h.CreateClass)
|
||||
protected.GET("/schools/:id/classes", h.ListClasses)
|
||||
protected.POST("/schools/:id/students", h.CreateStudent)
|
||||
protected.POST("/schools/:id/subjects", h.CreateSubject)
|
||||
protected.GET("/schools/:id/subjects", h.ListSubjects)
|
||||
|
||||
// Classes
|
||||
protected.GET("/classes/:id", h.GetClass)
|
||||
protected.GET("/classes/:id/students", h.ListStudentsByClass)
|
||||
protected.GET("/classes/:id/attendance", h.GetClassAttendance)
|
||||
protected.POST("/classes/:id/attendance", h.RecordBulkAttendance)
|
||||
protected.GET("/classes/:id/absence/pending", h.GetPendingAbsenceReports)
|
||||
protected.GET("/classes/:id/grades/:subjectId", h.GetClassGrades)
|
||||
protected.GET("/classes/:id/grades/:subjectId/stats", h.GetGradeStatistics)
|
||||
|
||||
// Students
|
||||
protected.GET("/students/:id", h.GetStudent)
|
||||
protected.GET("/students/:id/attendance", h.GetStudentAttendance)
|
||||
protected.GET("/students/:id/grades", h.GetStudentGrades)
|
||||
|
||||
// Attendance & Absence
|
||||
protected.POST("/attendance", h.RecordAttendance)
|
||||
protected.POST("/absence/report", h.ReportAbsence)
|
||||
protected.PUT("/absence/:id/confirm", h.ConfirmAbsence)
|
||||
|
||||
// Grades
|
||||
protected.POST("/grades", h.CreateGrade)
|
||||
|
||||
// Onboarding
|
||||
protected.POST("/onboarding/tokens", h.GenerateOnboardingToken)
|
||||
protected.POST("/onboarding/redeem", h.RedeemOnboardingToken)
|
||||
}
|
||||
@@ -0,0 +1,247 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// InputGateConfig holds configuration for input validation.
|
||||
type InputGateConfig struct {
|
||||
// Maximum request body size (default: 10MB)
|
||||
MaxBodySize int64
|
||||
|
||||
// Maximum file upload size (default: 50MB)
|
||||
MaxFileSize int64
|
||||
|
||||
// Allowed content types
|
||||
AllowedContentTypes map[string]bool
|
||||
|
||||
// Allowed file types for uploads
|
||||
AllowedFileTypes map[string]bool
|
||||
|
||||
// Blocked file extensions
|
||||
BlockedExtensions map[string]bool
|
||||
|
||||
// Paths that allow larger uploads
|
||||
LargeUploadPaths []string
|
||||
|
||||
// Paths excluded from validation
|
||||
ExcludedPaths []string
|
||||
|
||||
// Enable strict content type checking
|
||||
StrictContentType bool
|
||||
}
|
||||
|
||||
// DefaultInputGateConfig returns sensible default configuration.
|
||||
func DefaultInputGateConfig() InputGateConfig {
|
||||
maxSize := int64(10 * 1024 * 1024) // 10MB
|
||||
if envSize := os.Getenv("MAX_REQUEST_BODY_SIZE"); envSize != "" {
|
||||
if size, err := strconv.ParseInt(envSize, 10, 64); err == nil {
|
||||
maxSize = size
|
||||
}
|
||||
}
|
||||
|
||||
return InputGateConfig{
|
||||
MaxBodySize: maxSize,
|
||||
MaxFileSize: 50 * 1024 * 1024, // 50MB
|
||||
AllowedContentTypes: map[string]bool{
|
||||
"application/json": true,
|
||||
"application/x-www-form-urlencoded": true,
|
||||
"multipart/form-data": true,
|
||||
"text/plain": true,
|
||||
},
|
||||
AllowedFileTypes: map[string]bool{
|
||||
"image/jpeg": true,
|
||||
"image/png": true,
|
||||
"image/gif": true,
|
||||
"image/webp": true,
|
||||
"application/pdf": true,
|
||||
"text/csv": true,
|
||||
"application/msword": true,
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": true,
|
||||
"application/vnd.ms-excel": true,
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": true,
|
||||
},
|
||||
BlockedExtensions: map[string]bool{
|
||||
".exe": true, ".bat": true, ".cmd": true, ".com": true, ".msi": true,
|
||||
".dll": true, ".scr": true, ".pif": true, ".vbs": true, ".js": true,
|
||||
".jar": true, ".sh": true, ".ps1": true, ".app": true,
|
||||
},
|
||||
LargeUploadPaths: []string{
|
||||
"/api/v1/files/upload",
|
||||
"/api/v1/documents/upload",
|
||||
"/api/v1/attachments",
|
||||
},
|
||||
ExcludedPaths: []string{
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/api/v1/health",
|
||||
},
|
||||
StrictContentType: true,
|
||||
}
|
||||
}
|
||||
|
||||
// isExcludedPath checks if path is excluded from validation.
|
||||
func (c *InputGateConfig) isExcludedPath(path string) bool {
|
||||
for _, excluded := range c.ExcludedPaths {
|
||||
if path == excluded {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isLargeUploadPath checks if path allows larger uploads.
|
||||
func (c *InputGateConfig) isLargeUploadPath(path string) bool {
|
||||
for _, uploadPath := range c.LargeUploadPaths {
|
||||
if strings.HasPrefix(path, uploadPath) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getMaxSize returns the maximum allowed body size for the path.
|
||||
func (c *InputGateConfig) getMaxSize(path string) int64 {
|
||||
if c.isLargeUploadPath(path) {
|
||||
return c.MaxFileSize
|
||||
}
|
||||
return c.MaxBodySize
|
||||
}
|
||||
|
||||
// validateContentType validates the content type.
|
||||
func (c *InputGateConfig) validateContentType(contentType string) (bool, string) {
|
||||
if contentType == "" {
|
||||
return true, ""
|
||||
}
|
||||
|
||||
// Extract base content type (remove charset, boundary, etc.)
|
||||
baseType := strings.Split(contentType, ";")[0]
|
||||
baseType = strings.TrimSpace(strings.ToLower(baseType))
|
||||
|
||||
if !c.AllowedContentTypes[baseType] {
|
||||
return false, "Content-Type '" + baseType + "' is not allowed"
|
||||
}
|
||||
|
||||
return true, ""
|
||||
}
|
||||
|
||||
// hasBlockedExtension checks if filename has a blocked extension.
|
||||
func (c *InputGateConfig) hasBlockedExtension(filename string) bool {
|
||||
if filename == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
lowerFilename := strings.ToLower(filename)
|
||||
for ext := range c.BlockedExtensions {
|
||||
if strings.HasSuffix(lowerFilename, ext) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// InputGate returns a middleware that validates incoming request bodies.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// r.Use(middleware.InputGate())
|
||||
//
|
||||
// // Or with custom config:
|
||||
// config := middleware.DefaultInputGateConfig()
|
||||
// config.MaxBodySize = 5 * 1024 * 1024 // 5MB
|
||||
// r.Use(middleware.InputGateWithConfig(config))
|
||||
func InputGate() gin.HandlerFunc {
|
||||
return InputGateWithConfig(DefaultInputGateConfig())
|
||||
}
|
||||
|
||||
// InputGateWithConfig returns an input gate middleware with custom configuration.
|
||||
func InputGateWithConfig(config InputGateConfig) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Skip excluded paths
|
||||
if config.isExcludedPath(c.Request.URL.Path) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Skip validation for GET, HEAD, OPTIONS requests
|
||||
method := c.Request.Method
|
||||
if method == "GET" || method == "HEAD" || method == "OPTIONS" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Validate content type for requests with body
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
if config.StrictContentType {
|
||||
valid, errMsg := config.validateContentType(contentType)
|
||||
if !valid {
|
||||
c.AbortWithStatusJSON(http.StatusUnsupportedMediaType, gin.H{
|
||||
"error": "unsupported_media_type",
|
||||
"message": errMsg,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check Content-Length header
|
||||
contentLength := c.GetHeader("Content-Length")
|
||||
if contentLength != "" {
|
||||
length, err := strconv.ParseInt(contentLength, 10, 64)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_content_length",
|
||||
"message": "Invalid Content-Length header",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
maxSize := config.getMaxSize(c.Request.URL.Path)
|
||||
if length > maxSize {
|
||||
c.AbortWithStatusJSON(http.StatusRequestEntityTooLarge, gin.H{
|
||||
"error": "payload_too_large",
|
||||
"message": "Request body exceeds maximum size",
|
||||
"max_size": maxSize,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Set max multipart memory for file uploads
|
||||
if strings.Contains(contentType, "multipart/form-data") {
|
||||
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, config.MaxFileSize)
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateFileUpload validates a file upload.
|
||||
// Use this in upload handlers for detailed validation.
|
||||
func ValidateFileUpload(filename, contentType string, size int64, config *InputGateConfig) (bool, string) {
|
||||
if config == nil {
|
||||
defaultConfig := DefaultInputGateConfig()
|
||||
config = &defaultConfig
|
||||
}
|
||||
|
||||
// Check size
|
||||
if size > config.MaxFileSize {
|
||||
return false, "File size exceeds maximum allowed"
|
||||
}
|
||||
|
||||
// Check extension
|
||||
if config.hasBlockedExtension(filename) {
|
||||
return false, "File extension is not allowed"
|
||||
}
|
||||
|
||||
// Check content type
|
||||
if contentType != "" && !config.AllowedFileTypes[contentType] {
|
||||
return false, "File type '" + contentType + "' is not allowed"
|
||||
}
|
||||
|
||||
return true, ""
|
||||
}
|
||||
@@ -0,0 +1,421 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestInputGate_AllowsGETRequest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(InputGate())
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for GET request, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputGate_AllowsHEADRequest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(InputGate())
|
||||
|
||||
router.HEAD("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodHead, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for HEAD request, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputGate_AllowsOPTIONSRequest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(InputGate())
|
||||
|
||||
router.OPTIONS("/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for OPTIONS request, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputGate_AllowsValidJSONRequest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(InputGate())
|
||||
|
||||
router.POST("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
body := bytes.NewBufferString(`{"key": "value"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Content-Length", "16")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for valid JSON, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputGate_RejectsInvalidContentType(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultInputGateConfig()
|
||||
config.StrictContentType = true
|
||||
router.Use(InputGateWithConfig(config))
|
||||
|
||||
router.POST("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
body := bytes.NewBufferString(`data`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", body)
|
||||
req.Header.Set("Content-Type", "application/xml")
|
||||
req.Header.Set("Content-Length", "4")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnsupportedMediaType {
|
||||
t.Errorf("Expected status 415 for invalid content type, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputGate_AllowsEmptyContentType(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(InputGate())
|
||||
|
||||
router.POST("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
body := bytes.NewBufferString(`data`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", body)
|
||||
// No Content-Type header
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for empty content type, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputGate_RejectsOversizedRequest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultInputGateConfig()
|
||||
config.MaxBodySize = 100 // 100 bytes
|
||||
router.Use(InputGateWithConfig(config))
|
||||
|
||||
router.POST("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
// Create a body larger than 100 bytes
|
||||
largeBody := strings.Repeat("x", 200)
|
||||
body := bytes.NewBufferString(largeBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Content-Length", "200")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusRequestEntityTooLarge {
|
||||
t.Errorf("Expected status 413 for oversized request, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputGate_AllowsLargeUploadPath(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultInputGateConfig()
|
||||
config.MaxBodySize = 100 // 100 bytes
|
||||
config.MaxFileSize = 1000 // 1000 bytes
|
||||
config.LargeUploadPaths = []string{"/api/v1/files/upload"}
|
||||
router.Use(InputGateWithConfig(config))
|
||||
|
||||
router.POST("/api/v1/files/upload", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
// Create a body larger than MaxBodySize but smaller than MaxFileSize
|
||||
largeBody := strings.Repeat("x", 500)
|
||||
body := bytes.NewBufferString(largeBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/files/upload", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Content-Length", "500")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for large upload path, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputGate_ExcludedPaths(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultInputGateConfig()
|
||||
config.MaxBodySize = 10 // Very small
|
||||
config.ExcludedPaths = []string{"/health"}
|
||||
router.Use(InputGateWithConfig(config))
|
||||
|
||||
router.POST("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "healthy"})
|
||||
})
|
||||
|
||||
// Send oversized body to excluded path
|
||||
largeBody := strings.Repeat("x", 100)
|
||||
body := bytes.NewBufferString(largeBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/health", body)
|
||||
req.Header.Set("Content-Length", "100")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Should pass because path is excluded
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for excluded path, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputGate_RejectsInvalidContentLength(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(InputGate())
|
||||
|
||||
router.POST("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
body := bytes.NewBufferString(`data`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Content-Length", "invalid")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status 400 for invalid content length, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFileUpload_BlockedExtension(t *testing.T) {
|
||||
tests := []struct {
|
||||
filename string
|
||||
contentType string
|
||||
blocked bool
|
||||
}{
|
||||
{"malware.exe", "application/octet-stream", true},
|
||||
{"script.bat", "application/octet-stream", true},
|
||||
{"hack.cmd", "application/octet-stream", true},
|
||||
{"shell.sh", "application/octet-stream", true},
|
||||
{"powershell.ps1", "application/octet-stream", true},
|
||||
{"document.pdf", "application/pdf", false},
|
||||
{"image.jpg", "image/jpeg", false},
|
||||
{"data.csv", "text/csv", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
valid, errMsg := ValidateFileUpload(tt.filename, tt.contentType, 100, nil)
|
||||
if tt.blocked && valid {
|
||||
t.Errorf("File %s should be blocked", tt.filename)
|
||||
}
|
||||
if !tt.blocked && !valid {
|
||||
t.Errorf("File %s should not be blocked, error: %s", tt.filename, errMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFileUpload_OversizedFile(t *testing.T) {
|
||||
config := DefaultInputGateConfig()
|
||||
config.MaxFileSize = 1000 // 1KB
|
||||
|
||||
valid, errMsg := ValidateFileUpload("test.pdf", "application/pdf", 2000, &config)
|
||||
|
||||
if valid {
|
||||
t.Error("Should reject oversized file")
|
||||
}
|
||||
if !strings.Contains(errMsg, "size") {
|
||||
t.Errorf("Error message should mention size, got: %s", errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFileUpload_ValidFile(t *testing.T) {
|
||||
config := DefaultInputGateConfig()
|
||||
|
||||
valid, errMsg := ValidateFileUpload("document.pdf", "application/pdf", 1000, &config)
|
||||
|
||||
if !valid {
|
||||
t.Errorf("Should accept valid file, got error: %s", errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFileUpload_InvalidContentType(t *testing.T) {
|
||||
config := DefaultInputGateConfig()
|
||||
|
||||
valid, errMsg := ValidateFileUpload("file.xyz", "application/x-unknown", 100, &config)
|
||||
|
||||
if valid {
|
||||
t.Error("Should reject unknown file type")
|
||||
}
|
||||
if !strings.Contains(errMsg, "not allowed") {
|
||||
t.Errorf("Error message should mention not allowed, got: %s", errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFileUpload_NilConfig(t *testing.T) {
|
||||
// Should use default config when nil is passed
|
||||
valid, _ := ValidateFileUpload("document.pdf", "application/pdf", 1000, nil)
|
||||
|
||||
if !valid {
|
||||
t.Error("Should accept valid file with nil config (uses defaults)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasBlockedExtension(t *testing.T) {
|
||||
config := DefaultInputGateConfig()
|
||||
|
||||
tests := []struct {
|
||||
filename string
|
||||
blocked bool
|
||||
}{
|
||||
{"test.exe", true},
|
||||
{"TEST.EXE", true}, // Case insensitive
|
||||
{"script.BAT", true},
|
||||
{"app.APP", true},
|
||||
{"document.pdf", false},
|
||||
{"image.png", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := config.hasBlockedExtension(tt.filename)
|
||||
if result != tt.blocked {
|
||||
t.Errorf("File %s: expected blocked=%v, got %v", tt.filename, tt.blocked, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateContentType(t *testing.T) {
|
||||
config := DefaultInputGateConfig()
|
||||
|
||||
tests := []struct {
|
||||
contentType string
|
||||
valid bool
|
||||
}{
|
||||
{"application/json", true},
|
||||
{"application/json; charset=utf-8", true},
|
||||
{"APPLICATION/JSON", true}, // Case insensitive
|
||||
{"multipart/form-data; boundary=----WebKitFormBoundary", true},
|
||||
{"text/plain", true},
|
||||
{"application/xml", false},
|
||||
{"text/html", false},
|
||||
{"", true}, // Empty is allowed
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
valid, _ := config.validateContentType(tt.contentType)
|
||||
if valid != tt.valid {
|
||||
t.Errorf("Content-Type %q: expected valid=%v, got %v", tt.contentType, tt.valid, valid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLargeUploadPath(t *testing.T) {
|
||||
config := DefaultInputGateConfig()
|
||||
config.LargeUploadPaths = []string{"/api/v1/files/upload", "/api/v1/documents"}
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
isLarge bool
|
||||
}{
|
||||
{"/api/v1/files/upload", true},
|
||||
{"/api/v1/files/upload/batch", true}, // Prefix match
|
||||
{"/api/v1/documents", true},
|
||||
{"/api/v1/documents/1/attachments", true},
|
||||
{"/api/v1/users", false},
|
||||
{"/health", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := config.isLargeUploadPath(tt.path)
|
||||
if result != tt.isLarge {
|
||||
t.Errorf("Path %s: expected isLarge=%v, got %v", tt.path, tt.isLarge, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMaxSize(t *testing.T) {
|
||||
config := DefaultInputGateConfig()
|
||||
config.MaxBodySize = 100
|
||||
config.MaxFileSize = 1000
|
||||
config.LargeUploadPaths = []string{"/api/v1/files/upload"}
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
expected int64
|
||||
}{
|
||||
{"/api/test", 100},
|
||||
{"/api/v1/files/upload", 1000},
|
||||
{"/health", 100},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := config.getMaxSize(tt.path)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Path %s: expected maxSize=%d, got %d", tt.path, tt.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputGate_DefaultMiddleware(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(InputGate())
|
||||
|
||||
router.POST("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
body := bytes.NewBufferString(`{"key": "value"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,379 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// UserClaims represents the JWT claims for a user
|
||||
type UserClaims struct {
|
||||
UserID string `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// CORS returns a CORS middleware
|
||||
func CORS() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
|
||||
// Allow localhost for development
|
||||
allowedOrigins := []string{
|
||||
"http://localhost:3000",
|
||||
"http://localhost:8000",
|
||||
"http://localhost:8080",
|
||||
"https://breakpilot.app",
|
||||
}
|
||||
|
||||
allowed := false
|
||||
for _, o := range allowedOrigins {
|
||||
if origin == o {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if allowed {
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Authorization, X-Requested-With")
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
c.Header("Access-Control-Max-Age", "86400")
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequestLogger logs each request
|
||||
func RequestLogger() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
method := c.Request.Method
|
||||
|
||||
c.Next()
|
||||
|
||||
latency := time.Since(start)
|
||||
status := c.Writer.Status()
|
||||
|
||||
// Log only in development or for errors
|
||||
if status >= 400 {
|
||||
gin.DefaultWriter.Write([]byte(
|
||||
method + " " + path + " " +
|
||||
string(rune(status)) + " " +
|
||||
latency.String() + "\n",
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimiter implements a simple in-memory rate limiter
|
||||
// Configurable via RATE_LIMIT_PER_MINUTE env var (default: 500)
|
||||
func RateLimiter() gin.HandlerFunc {
|
||||
type client struct {
|
||||
count int
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
clients = make(map[string]*client)
|
||||
)
|
||||
|
||||
// Clean up old entries periodically
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(time.Minute)
|
||||
mu.Lock()
|
||||
for ip, c := range clients {
|
||||
if time.Since(c.lastSeen) > time.Minute {
|
||||
delete(clients, ip)
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
|
||||
// Skip rate limiting for Docker internal network (172.x.x.x) and localhost
|
||||
// This prevents issues when multiple services share the same internal IP
|
||||
if strings.HasPrefix(ip, "172.") || ip == "127.0.0.1" || ip == "::1" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if _, exists := clients[ip]; !exists {
|
||||
clients[ip] = &client{}
|
||||
}
|
||||
|
||||
cli := clients[ip]
|
||||
|
||||
// Reset count if more than a minute has passed
|
||||
if time.Since(cli.lastSeen) > time.Minute {
|
||||
cli.count = 0
|
||||
}
|
||||
|
||||
cli.count++
|
||||
cli.lastSeen = time.Now()
|
||||
|
||||
// Allow 500 requests per minute (increased for admin panels with many API calls)
|
||||
if cli.count > 500 {
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "rate_limit_exceeded",
|
||||
"message": "Too many requests. Please try again later.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// AuthMiddleware validates JWT tokens
|
||||
func AuthMiddleware(jwtSecret string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "missing_authorization",
|
||||
"message": "Authorization header is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Extract token from "Bearer <token>"
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "invalid_authorization",
|
||||
"message": "Authorization header must be in format: Bearer <token>",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := parts[1]
|
||||
|
||||
// Parse and validate token
|
||||
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(jwtSecret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "invalid_token",
|
||||
"message": "Invalid or expired token",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(*UserClaims); ok && token.Valid {
|
||||
// Set user info in context
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("email", claims.Email)
|
||||
c.Set("role", claims.Role)
|
||||
c.Next()
|
||||
} else {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "invalid_claims",
|
||||
"message": "Invalid token claims",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AdminOnly ensures only admin users can access the route
|
||||
func AdminOnly() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
role, exists := c.Get("role")
|
||||
if !exists {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "unauthorized",
|
||||
"message": "User role not found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
roleStr, ok := role.(string)
|
||||
if !ok || (roleStr != "admin" && roleStr != "super_admin" && roleStr != "data_protection_officer") {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
||||
"error": "forbidden",
|
||||
"message": "Admin access required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// DSBOnly ensures only Data Protection Officers can access the route
|
||||
// Used for critical operations like publishing legal documents (four-eyes principle)
|
||||
func DSBOnly() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
role, exists := c.Get("role")
|
||||
if !exists {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "unauthorized",
|
||||
"message": "User role not found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
roleStr, ok := role.(string)
|
||||
if !ok || (roleStr != "data_protection_officer" && roleStr != "super_admin") {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
||||
"error": "forbidden",
|
||||
"message": "Only Data Protection Officers can perform this action",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsAdmin checks if the user has admin role
|
||||
func IsAdmin(c *gin.Context) bool {
|
||||
role, exists := c.Get("role")
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
roleStr, ok := role.(string)
|
||||
return ok && (roleStr == "admin" || roleStr == "super_admin" || roleStr == "data_protection_officer")
|
||||
}
|
||||
|
||||
// IsDSB checks if the user has DSB role
|
||||
func IsDSB(c *gin.Context) bool {
|
||||
role, exists := c.Get("role")
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
roleStr, ok := role.(string)
|
||||
return ok && (roleStr == "data_protection_officer" || roleStr == "super_admin")
|
||||
}
|
||||
|
||||
// GetUserID extracts the user ID from the context
|
||||
func GetUserID(c *gin.Context) (uuid.UUID, error) {
|
||||
userIDStr, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
return uuid.Nil, nil
|
||||
}
|
||||
|
||||
userID, err := uuid.Parse(userIDStr.(string))
|
||||
if err != nil {
|
||||
return uuid.Nil, err
|
||||
}
|
||||
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
// GetClientIP returns the client's IP address
|
||||
func GetClientIP(c *gin.Context) string {
|
||||
// Check X-Forwarded-For header first (for proxied requests)
|
||||
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
|
||||
ips := strings.Split(xff, ",")
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
|
||||
// Check X-Real-IP header
|
||||
if xri := c.GetHeader("X-Real-IP"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
return c.ClientIP()
|
||||
}
|
||||
|
||||
// GetUserAgent returns the client's User-Agent
|
||||
func GetUserAgent(c *gin.Context) string {
|
||||
return c.GetHeader("User-Agent")
|
||||
}
|
||||
|
||||
// SuspensionCheckMiddleware checks if a user is suspended and restricts access
|
||||
// Suspended users can only access consent-related endpoints
|
||||
func SuspensionCheckMiddleware(pool interface{ QueryRow(ctx interface{}, sql string, args ...interface{}) interface{ Scan(dest ...interface{}) error } }) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
userIDStr, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := uuid.Parse(userIDStr.(string))
|
||||
if err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Check user account status
|
||||
var accountStatus string
|
||||
err = pool.QueryRow(c.Request.Context(), `SELECT account_status FROM users WHERE id = $1`, userID).Scan(&accountStatus)
|
||||
if err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if accountStatus == "suspended" {
|
||||
// Check if current path is allowed for suspended users
|
||||
path := c.Request.URL.Path
|
||||
allowedPaths := []string{
|
||||
"/api/v1/consent",
|
||||
"/api/v1/documents",
|
||||
"/api/v1/notifications",
|
||||
"/api/v1/profile",
|
||||
"/api/v1/privacy/my-data",
|
||||
"/api/v1/auth/logout",
|
||||
}
|
||||
|
||||
allowed := false
|
||||
for _, p := range allowedPaths {
|
||||
if strings.HasPrefix(path, p) {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
||||
"error": "account_suspended",
|
||||
"message": "Your account is suspended due to pending consent requirements",
|
||||
"redirect": "/consent/pending",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Set suspended flag in context for handlers to use
|
||||
c.Set("account_suspended", true)
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsSuspended checks if the current user's account is suspended
|
||||
func IsSuspended(c *gin.Context) bool {
|
||||
suspended, exists := c.Get("account_suspended")
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
return suspended.(bool)
|
||||
}
|
||||
@@ -0,0 +1,546 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// Helper to create a valid JWT token for testing
|
||||
func createTestToken(secret string, userID, email, role string, exp time.Time) string {
|
||||
claims := UserClaims{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
Role: role,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(exp),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, _ := token.SignedString([]byte(secret))
|
||||
return tokenString
|
||||
}
|
||||
|
||||
// TestCORS tests the CORS middleware
|
||||
func TestCORS(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(CORS())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
origin string
|
||||
method string
|
||||
expectedStatus int
|
||||
expectAllowedOrigin bool
|
||||
}{
|
||||
{"localhost:3000", "http://localhost:3000", "GET", http.StatusOK, true},
|
||||
{"localhost:8000", "http://localhost:8000", "GET", http.StatusOK, true},
|
||||
{"production", "https://breakpilot.app", "GET", http.StatusOK, true},
|
||||
{"unknown origin", "https://unknown.com", "GET", http.StatusOK, false},
|
||||
{"preflight", "http://localhost:3000", "OPTIONS", http.StatusNoContent, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(tt.method, "/test", nil)
|
||||
req.Header.Set("Origin", tt.origin)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
|
||||
allowedOrigin := w.Header().Get("Access-Control-Allow-Origin")
|
||||
if tt.expectAllowedOrigin && allowedOrigin != tt.origin {
|
||||
t.Errorf("Expected Access-Control-Allow-Origin to be %s, got %s", tt.origin, allowedOrigin)
|
||||
}
|
||||
if !tt.expectAllowedOrigin && allowedOrigin != "" {
|
||||
t.Errorf("Expected no Access-Control-Allow-Origin header, got %s", allowedOrigin)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCORSHeaders tests that CORS headers are set correctly
|
||||
func TestCORSHeaders(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(CORS())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Origin", "http://localhost:3000")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
expectedHeaders := map[string]string{
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Origin, Content-Type, Authorization, X-Requested-With",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Max-Age": "86400",
|
||||
}
|
||||
|
||||
for header, expected := range expectedHeaders {
|
||||
actual := w.Header().Get(header)
|
||||
if actual != expected {
|
||||
t.Errorf("Expected %s to be %s, got %s", header, expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthMiddleware_ValidToken tests authentication with valid token
|
||||
func TestAuthMiddleware_ValidToken(t *testing.T) {
|
||||
secret := "test-secret-key"
|
||||
userID := uuid.New().String()
|
||||
email := "test@example.com"
|
||||
role := "user"
|
||||
|
||||
router := gin.New()
|
||||
router.Use(AuthMiddleware(secret))
|
||||
router.GET("/protected", func(c *gin.Context) {
|
||||
uid, _ := c.Get("user_id")
|
||||
em, _ := c.Get("email")
|
||||
r, _ := c.Get("role")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"user_id": uid,
|
||||
"email": em,
|
||||
"role": r,
|
||||
})
|
||||
})
|
||||
|
||||
token := createTestToken(secret, userID, email, role, time.Now().Add(time.Hour))
|
||||
|
||||
req, _ := http.NewRequest("GET", "/protected", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthMiddleware_MissingHeader tests authentication without header
|
||||
func TestAuthMiddleware_MissingHeader(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(AuthMiddleware("test-secret"))
|
||||
router.GET("/protected", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/protected", nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthMiddleware_InvalidFormat tests authentication with invalid header format
|
||||
func TestAuthMiddleware_InvalidFormat(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(AuthMiddleware("test-secret"))
|
||||
router.GET("/protected", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
}{
|
||||
{"no Bearer prefix", "some-token"},
|
||||
{"Basic auth", "Basic dXNlcjpwYXNz"},
|
||||
{"empty Bearer", "Bearer "},
|
||||
{"multiple spaces", "Bearer token"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/protected", nil)
|
||||
req.Header.Set("Authorization", tt.header)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthMiddleware_ExpiredToken tests authentication with expired token
|
||||
func TestAuthMiddleware_ExpiredToken(t *testing.T) {
|
||||
secret := "test-secret"
|
||||
|
||||
router := gin.New()
|
||||
router.Use(AuthMiddleware(secret))
|
||||
router.GET("/protected", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
// Create expired token
|
||||
token := createTestToken(secret, "user-123", "test@example.com", "user", time.Now().Add(-time.Hour))
|
||||
|
||||
req, _ := http.NewRequest("GET", "/protected", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthMiddleware_WrongSecret tests authentication with wrong secret
|
||||
func TestAuthMiddleware_WrongSecret(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(AuthMiddleware("correct-secret"))
|
||||
router.GET("/protected", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
// Create token with different secret
|
||||
token := createTestToken("wrong-secret", "user-123", "test@example.com", "user", time.Now().Add(time.Hour))
|
||||
|
||||
req, _ := http.NewRequest("GET", "/protected", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminOnly tests the AdminOnly middleware
|
||||
func TestAdminOnly(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
role string
|
||||
expectedStatus int
|
||||
}{
|
||||
{"admin allowed", "admin", http.StatusOK},
|
||||
{"super_admin allowed", "super_admin", http.StatusOK},
|
||||
{"dpo allowed", "data_protection_officer", http.StatusOK},
|
||||
{"user forbidden", "user", http.StatusForbidden},
|
||||
{"empty role forbidden", "", http.StatusForbidden},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set("role", tt.role)
|
||||
c.Next()
|
||||
})
|
||||
router.Use(AdminOnly())
|
||||
router.GET("/admin", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/admin", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminOnly_NoRole tests AdminOnly when role is not set
|
||||
func TestAdminOnly_NoRole(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(AdminOnly())
|
||||
router.GET("/admin", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/admin", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDSBOnly tests the DSBOnly middleware
|
||||
func TestDSBOnly(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
role string
|
||||
expectedStatus int
|
||||
}{
|
||||
{"dpo allowed", "data_protection_officer", http.StatusOK},
|
||||
{"super_admin allowed", "super_admin", http.StatusOK},
|
||||
{"admin forbidden", "admin", http.StatusForbidden},
|
||||
{"user forbidden", "user", http.StatusForbidden},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set("role", tt.role)
|
||||
c.Next()
|
||||
})
|
||||
router.Use(DSBOnly())
|
||||
router.GET("/dsb", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/dsb", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsAdmin tests the IsAdmin helper function
|
||||
func TestIsAdmin(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
role string
|
||||
expected bool
|
||||
}{
|
||||
{"admin", "admin", true},
|
||||
{"super_admin", "super_admin", true},
|
||||
{"dpo", "data_protection_officer", true},
|
||||
{"user", "user", false},
|
||||
{"empty", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
if tt.role != "" {
|
||||
c.Set("role", tt.role)
|
||||
}
|
||||
|
||||
result := IsAdmin(c)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected IsAdmin to be %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsDSB tests the IsDSB helper function
|
||||
func TestIsDSB(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
role string
|
||||
expected bool
|
||||
}{
|
||||
{"dpo", "data_protection_officer", true},
|
||||
{"super_admin", "super_admin", true},
|
||||
{"admin", "admin", false},
|
||||
{"user", "user", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Set("role", tt.role)
|
||||
|
||||
result := IsDSB(c)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected IsDSB to be %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUserID tests the GetUserID helper function
|
||||
func TestGetUserID(t *testing.T) {
|
||||
validUUID := uuid.New()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userID string
|
||||
setUserID bool
|
||||
expectError bool
|
||||
expectedID uuid.UUID
|
||||
}{
|
||||
{"valid UUID", validUUID.String(), true, false, validUUID},
|
||||
{"invalid UUID", "not-a-uuid", true, true, uuid.Nil},
|
||||
{"missing user_id", "", false, false, uuid.Nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
if tt.setUserID {
|
||||
c.Set("user_id", tt.userID)
|
||||
}
|
||||
|
||||
result, err := GetUserID(c)
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
|
||||
if !tt.expectError && result != tt.expectedID {
|
||||
t.Errorf("Expected %v, got %v", tt.expectedID, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetClientIP tests the GetClientIP helper function
|
||||
func TestGetClientIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
xff string
|
||||
xri string
|
||||
clientIP string
|
||||
expectedIP string
|
||||
}{
|
||||
{"X-Forwarded-For", "10.0.0.1", "", "192.168.1.1", "10.0.0.1"},
|
||||
{"X-Forwarded-For multiple", "10.0.0.1, 10.0.0.2", "", "192.168.1.1", "10.0.0.1"},
|
||||
{"X-Real-IP", "", "10.0.0.1", "192.168.1.1", "10.0.0.1"},
|
||||
{"direct", "", "", "192.168.1.1", "192.168.1.1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
c.Request, _ = http.NewRequest("GET", "/", nil)
|
||||
if tt.xff != "" {
|
||||
c.Request.Header.Set("X-Forwarded-For", tt.xff)
|
||||
}
|
||||
if tt.xri != "" {
|
||||
c.Request.Header.Set("X-Real-IP", tt.xri)
|
||||
}
|
||||
c.Request.RemoteAddr = tt.clientIP + ":12345"
|
||||
|
||||
result := GetClientIP(c)
|
||||
|
||||
// Note: gin.ClientIP() might return different values
|
||||
// depending on trusted proxies config
|
||||
if result != tt.expectedIP && result != tt.clientIP {
|
||||
t.Logf("Note: GetClientIP returned %s (expected %s or %s)", result, tt.expectedIP, tt.clientIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUserAgent tests the GetUserAgent helper function
|
||||
func TestGetUserAgent(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request, _ = http.NewRequest("GET", "/", nil)
|
||||
|
||||
expectedUA := "Mozilla/5.0 (Test)"
|
||||
c.Request.Header.Set("User-Agent", expectedUA)
|
||||
|
||||
result := GetUserAgent(c)
|
||||
if result != expectedUA {
|
||||
t.Errorf("Expected %s, got %s", expectedUA, result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsSuspended tests the IsSuspended helper function
|
||||
func TestIsSuspended(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
suspended interface{}
|
||||
setSuspended bool
|
||||
expected bool
|
||||
}{
|
||||
{"suspended true", true, true, true},
|
||||
{"suspended false", false, true, false},
|
||||
{"not set", nil, false, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
if tt.setSuspended {
|
||||
c.Set("account_suspended", tt.suspended)
|
||||
}
|
||||
|
||||
result := IsSuspended(c)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkCORS benchmarks the CORS middleware
|
||||
func BenchmarkCORS(b *testing.B) {
|
||||
router := gin.New()
|
||||
router.Use(CORS())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Origin", "http://localhost:3000")
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkAuthMiddleware benchmarks the auth middleware
|
||||
func BenchmarkAuthMiddleware(b *testing.B) {
|
||||
secret := "test-secret-key"
|
||||
token := createTestToken(secret, uuid.New().String(), "test@example.com", "user", time.Now().Add(time.Hour))
|
||||
|
||||
router := gin.New()
|
||||
router.Use(AuthMiddleware(secret))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,197 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PIIPattern defines a pattern for identifying PII.
|
||||
type PIIPattern struct {
|
||||
Name string
|
||||
Pattern *regexp.Regexp
|
||||
Replacement string
|
||||
}
|
||||
|
||||
// PIIRedactor redacts personally identifiable information from strings.
|
||||
type PIIRedactor struct {
|
||||
patterns []*PIIPattern
|
||||
}
|
||||
|
||||
// Pre-compiled patterns for common PII types
|
||||
var (
|
||||
emailPattern = regexp.MustCompile(`\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b`)
|
||||
ipv4Pattern = regexp.MustCompile(`\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b`)
|
||||
ipv6Pattern = regexp.MustCompile(`\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b`)
|
||||
phonePattern = regexp.MustCompile(`(?:\+49|0049)[\s.-]?\d{2,4}[\s.-]?\d{3,8}|\b0\d{2,4}[\s.-]?\d{3,8}\b`)
|
||||
ibanPattern = regexp.MustCompile(`(?i)\b[A-Z]{2}\d{2}[\s]?(?:\d{4}[\s]?){3,5}\d{1,4}\b`)
|
||||
uuidPattern = regexp.MustCompile(`(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b`)
|
||||
namePattern = regexp.MustCompile(`\b(?:Herr|Frau|Hr\.|Fr\.)\s+[A-ZÄÖÜ][a-zäöüß]+(?:\s+[A-ZÄÖÜ][a-zäöüß]+)?\b`)
|
||||
)
|
||||
|
||||
// DefaultPIIPatterns returns the default set of PII patterns.
|
||||
func DefaultPIIPatterns() []*PIIPattern {
|
||||
return []*PIIPattern{
|
||||
{Name: "email", Pattern: emailPattern, Replacement: "[EMAIL_REDACTED]"},
|
||||
{Name: "ip_v4", Pattern: ipv4Pattern, Replacement: "[IP_REDACTED]"},
|
||||
{Name: "ip_v6", Pattern: ipv6Pattern, Replacement: "[IP_REDACTED]"},
|
||||
{Name: "phone", Pattern: phonePattern, Replacement: "[PHONE_REDACTED]"},
|
||||
}
|
||||
}
|
||||
|
||||
// AllPIIPatterns returns all available PII patterns.
|
||||
func AllPIIPatterns() []*PIIPattern {
|
||||
return []*PIIPattern{
|
||||
{Name: "email", Pattern: emailPattern, Replacement: "[EMAIL_REDACTED]"},
|
||||
{Name: "ip_v4", Pattern: ipv4Pattern, Replacement: "[IP_REDACTED]"},
|
||||
{Name: "ip_v6", Pattern: ipv6Pattern, Replacement: "[IP_REDACTED]"},
|
||||
{Name: "phone", Pattern: phonePattern, Replacement: "[PHONE_REDACTED]"},
|
||||
{Name: "iban", Pattern: ibanPattern, Replacement: "[IBAN_REDACTED]"},
|
||||
{Name: "uuid", Pattern: uuidPattern, Replacement: "[UUID_REDACTED]"},
|
||||
{Name: "name", Pattern: namePattern, Replacement: "[NAME_REDACTED]"},
|
||||
}
|
||||
}
|
||||
|
||||
// NewPIIRedactor creates a new PII redactor with the given patterns.
|
||||
func NewPIIRedactor(patterns []*PIIPattern) *PIIRedactor {
|
||||
if patterns == nil {
|
||||
patterns = DefaultPIIPatterns()
|
||||
}
|
||||
return &PIIRedactor{patterns: patterns}
|
||||
}
|
||||
|
||||
// NewDefaultPIIRedactor creates a PII redactor with default patterns.
|
||||
func NewDefaultPIIRedactor() *PIIRedactor {
|
||||
return NewPIIRedactor(DefaultPIIPatterns())
|
||||
}
|
||||
|
||||
// Redact removes PII from the given text.
|
||||
func (r *PIIRedactor) Redact(text string) string {
|
||||
if text == "" {
|
||||
return text
|
||||
}
|
||||
|
||||
result := text
|
||||
for _, pattern := range r.patterns {
|
||||
result = pattern.Pattern.ReplaceAllString(result, pattern.Replacement)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ContainsPII checks if the text contains any PII.
|
||||
func (r *PIIRedactor) ContainsPII(text string) bool {
|
||||
if text == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, pattern := range r.patterns {
|
||||
if pattern.Pattern.MatchString(text) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// PIIFinding represents a found PII instance.
|
||||
type PIIFinding struct {
|
||||
Type string
|
||||
Match string
|
||||
Start int
|
||||
End int
|
||||
}
|
||||
|
||||
// FindPII finds all PII in the text.
|
||||
func (r *PIIRedactor) FindPII(text string) []PIIFinding {
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var findings []PIIFinding
|
||||
for _, pattern := range r.patterns {
|
||||
matches := pattern.Pattern.FindAllStringIndex(text, -1)
|
||||
for _, match := range matches {
|
||||
findings = append(findings, PIIFinding{
|
||||
Type: pattern.Name,
|
||||
Match: text[match[0]:match[1]],
|
||||
Start: match[0],
|
||||
End: match[1],
|
||||
})
|
||||
}
|
||||
}
|
||||
return findings
|
||||
}
|
||||
|
||||
// Default module-level redactor
|
||||
var defaultRedactor = NewDefaultPIIRedactor()
|
||||
|
||||
// RedactPII is a convenience function that uses the default redactor.
|
||||
func RedactPII(text string) string {
|
||||
return defaultRedactor.Redact(text)
|
||||
}
|
||||
|
||||
// ContainsPIIDefault checks if text contains PII using default patterns.
|
||||
func ContainsPIIDefault(text string) bool {
|
||||
return defaultRedactor.ContainsPII(text)
|
||||
}
|
||||
|
||||
// RedactMap redacts PII from all string values in a map.
|
||||
func RedactMap(data map[string]interface{}) map[string]interface{} {
|
||||
result := make(map[string]interface{})
|
||||
for key, value := range data {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
result[key] = RedactPII(v)
|
||||
case map[string]interface{}:
|
||||
result[key] = RedactMap(v)
|
||||
case []interface{}:
|
||||
result[key] = redactSlice(v)
|
||||
default:
|
||||
result[key] = v
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func redactSlice(data []interface{}) []interface{} {
|
||||
result := make([]interface{}, len(data))
|
||||
for i, value := range data {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
result[i] = RedactPII(v)
|
||||
case map[string]interface{}:
|
||||
result[i] = RedactMap(v)
|
||||
case []interface{}:
|
||||
result[i] = redactSlice(v)
|
||||
default:
|
||||
result[i] = v
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// SafeLogString creates a safe-to-log version of sensitive data.
|
||||
// Use this for logging user-related information.
|
||||
func SafeLogString(format string, args ...interface{}) string {
|
||||
// Convert args to strings and redact
|
||||
safeArgs := make([]interface{}, len(args))
|
||||
for i, arg := range args {
|
||||
switch v := arg.(type) {
|
||||
case string:
|
||||
safeArgs[i] = RedactPII(v)
|
||||
case error:
|
||||
safeArgs[i] = RedactPII(v.Error())
|
||||
default:
|
||||
safeArgs[i] = arg
|
||||
}
|
||||
}
|
||||
|
||||
// Note: We can't use fmt.Sprintf here due to the variadic nature
|
||||
// Instead, we redact the result
|
||||
result := format
|
||||
for _, arg := range safeArgs {
|
||||
if s, ok := arg.(string); ok {
|
||||
result = strings.Replace(result, "%s", s, 1)
|
||||
result = strings.Replace(result, "%v", s, 1)
|
||||
}
|
||||
}
|
||||
return RedactPII(result)
|
||||
}
|
||||
@@ -0,0 +1,228 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPIIRedactor_RedactsEmail(t *testing.T) {
|
||||
redactor := NewDefaultPIIRedactor()
|
||||
|
||||
text := "User test@example.com logged in"
|
||||
result := redactor.Redact(text)
|
||||
|
||||
if result == text {
|
||||
t.Error("Email should have been redacted")
|
||||
}
|
||||
if result != "User [EMAIL_REDACTED] logged in" {
|
||||
t.Errorf("Unexpected result: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPIIRedactor_RedactsIPv4(t *testing.T) {
|
||||
redactor := NewDefaultPIIRedactor()
|
||||
|
||||
text := "Request from 192.168.1.100"
|
||||
result := redactor.Redact(text)
|
||||
|
||||
if result == text {
|
||||
t.Error("IP should have been redacted")
|
||||
}
|
||||
if result != "Request from [IP_REDACTED]" {
|
||||
t.Errorf("Unexpected result: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPIIRedactor_RedactsGermanPhone(t *testing.T) {
|
||||
redactor := NewDefaultPIIRedactor()
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"+49 30 12345678", "[PHONE_REDACTED]"},
|
||||
{"0049 30 12345678", "[PHONE_REDACTED]"},
|
||||
{"030 12345678", "[PHONE_REDACTED]"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := redactor.Redact(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("For input %q: expected %q, got %q", tt.input, tt.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPIIRedactor_RedactsMultiplePII(t *testing.T) {
|
||||
redactor := NewDefaultPIIRedactor()
|
||||
|
||||
text := "User test@example.com from 10.0.0.1"
|
||||
result := redactor.Redact(text)
|
||||
|
||||
if result != "User [EMAIL_REDACTED] from [IP_REDACTED]" {
|
||||
t.Errorf("Unexpected result: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPIIRedactor_PreservesNonPIIText(t *testing.T) {
|
||||
redactor := NewDefaultPIIRedactor()
|
||||
|
||||
text := "User logged in successfully"
|
||||
result := redactor.Redact(text)
|
||||
|
||||
if result != text {
|
||||
t.Errorf("Text should be unchanged: got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPIIRedactor_EmptyString(t *testing.T) {
|
||||
redactor := NewDefaultPIIRedactor()
|
||||
|
||||
result := redactor.Redact("")
|
||||
if result != "" {
|
||||
t.Error("Empty string should remain empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsPII(t *testing.T) {
|
||||
redactor := NewDefaultPIIRedactor()
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{"test@example.com", true},
|
||||
{"192.168.1.1", true},
|
||||
{"+49 30 12345678", true},
|
||||
{"Hello World", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := redactor.ContainsPII(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("For input %q: expected %v, got %v", tt.input, tt.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindPII(t *testing.T) {
|
||||
redactor := NewDefaultPIIRedactor()
|
||||
|
||||
text := "Email: test@example.com, IP: 10.0.0.1"
|
||||
findings := redactor.FindPII(text)
|
||||
|
||||
if len(findings) != 2 {
|
||||
t.Errorf("Expected 2 findings, got %d", len(findings))
|
||||
}
|
||||
|
||||
hasEmail := false
|
||||
hasIP := false
|
||||
for _, f := range findings {
|
||||
if f.Type == "email" {
|
||||
hasEmail = true
|
||||
}
|
||||
if f.Type == "ip_v4" {
|
||||
hasIP = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasEmail {
|
||||
t.Error("Should have found email")
|
||||
}
|
||||
if !hasIP {
|
||||
t.Error("Should have found IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactPII_GlobalFunction(t *testing.T) {
|
||||
text := "User test@example.com logged in"
|
||||
result := RedactPII(text)
|
||||
|
||||
if result == text {
|
||||
t.Error("Email should have been redacted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsPIIDefault(t *testing.T) {
|
||||
if !ContainsPIIDefault("test@example.com") {
|
||||
t.Error("Should detect email as PII")
|
||||
}
|
||||
if ContainsPIIDefault("Hello World") {
|
||||
t.Error("Should not detect non-PII text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactMap(t *testing.T) {
|
||||
data := map[string]interface{}{
|
||||
"email": "test@example.com",
|
||||
"message": "Hello World",
|
||||
"nested": map[string]interface{}{
|
||||
"ip": "192.168.1.1",
|
||||
},
|
||||
}
|
||||
|
||||
result := RedactMap(data)
|
||||
|
||||
if result["email"] != "[EMAIL_REDACTED]" {
|
||||
t.Errorf("Email should be redacted: %v", result["email"])
|
||||
}
|
||||
if result["message"] != "Hello World" {
|
||||
t.Errorf("Non-PII should be unchanged: %v", result["message"])
|
||||
}
|
||||
|
||||
nested := result["nested"].(map[string]interface{})
|
||||
if nested["ip"] != "[IP_REDACTED]" {
|
||||
t.Errorf("Nested IP should be redacted: %v", nested["ip"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllPIIPatterns(t *testing.T) {
|
||||
patterns := AllPIIPatterns()
|
||||
|
||||
if len(patterns) == 0 {
|
||||
t.Error("Should have PII patterns")
|
||||
}
|
||||
|
||||
// Check that we have the expected patterns
|
||||
expectedNames := []string{"email", "ip_v4", "ip_v6", "phone", "iban", "uuid", "name"}
|
||||
nameMap := make(map[string]bool)
|
||||
for _, p := range patterns {
|
||||
nameMap[p.Name] = true
|
||||
}
|
||||
|
||||
for _, name := range expectedNames {
|
||||
if !nameMap[name] {
|
||||
t.Errorf("Missing expected pattern: %s", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultPIIPatterns(t *testing.T) {
|
||||
patterns := DefaultPIIPatterns()
|
||||
|
||||
if len(patterns) != 4 {
|
||||
t.Errorf("Expected 4 default patterns, got %d", len(patterns))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIBANRedaction(t *testing.T) {
|
||||
redactor := NewPIIRedactor(AllPIIPatterns())
|
||||
|
||||
text := "IBAN: DE89 3704 0044 0532 0130 00"
|
||||
result := redactor.Redact(text)
|
||||
|
||||
if result == text {
|
||||
t.Error("IBAN should have been redacted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUUIDRedaction(t *testing.T) {
|
||||
redactor := NewPIIRedactor(AllPIIPatterns())
|
||||
|
||||
text := "User ID: a0000000-0000-0000-0000-000000000001"
|
||||
result := redactor.Redact(text)
|
||||
|
||||
if result == text {
|
||||
t.Error("UUID should have been redacted")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
// RequestIDHeader is the primary header for request IDs
|
||||
RequestIDHeader = "X-Request-ID"
|
||||
// CorrelationIDHeader is an alternative header for distributed tracing
|
||||
CorrelationIDHeader = "X-Correlation-ID"
|
||||
// RequestIDKey is the context key for storing the request ID
|
||||
RequestIDKey = "request_id"
|
||||
)
|
||||
|
||||
// RequestID returns a middleware that generates and propagates request IDs.
|
||||
//
|
||||
// For each incoming request:
|
||||
// 1. Check for existing X-Request-ID or X-Correlation-ID header
|
||||
// 2. If not present, generate a new UUID
|
||||
// 3. Store in Gin context for use by handlers and logging
|
||||
// 4. Add to response headers
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// r.Use(middleware.RequestID())
|
||||
//
|
||||
// func handler(c *gin.Context) {
|
||||
// requestID := middleware.GetRequestID(c)
|
||||
// log.Printf("[%s] Processing request", requestID)
|
||||
// }
|
||||
func RequestID() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Try to get existing request ID from headers
|
||||
requestID := c.GetHeader(RequestIDHeader)
|
||||
if requestID == "" {
|
||||
requestID = c.GetHeader(CorrelationIDHeader)
|
||||
}
|
||||
|
||||
// Generate new ID if not provided
|
||||
if requestID == "" {
|
||||
requestID = uuid.New().String()
|
||||
}
|
||||
|
||||
// Store in context for handlers and logging
|
||||
c.Set(RequestIDKey, requestID)
|
||||
|
||||
// Add to response headers
|
||||
c.Header(RequestIDHeader, requestID)
|
||||
c.Header(CorrelationIDHeader, requestID)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// GetRequestID retrieves the request ID from the Gin context.
|
||||
// Returns empty string if no request ID is set.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// requestID := middleware.GetRequestID(c)
|
||||
func GetRequestID(c *gin.Context) string {
|
||||
if id, exists := c.Get(RequestIDKey); exists {
|
||||
if idStr, ok := id.(string); ok {
|
||||
return idStr
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// RequestIDFromContext is an alias for GetRequestID for API compatibility.
|
||||
func RequestIDFromContext(c *gin.Context) string {
|
||||
return GetRequestID(c)
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestRequestID_GeneratesNewID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(RequestID())
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
requestID := GetRequestID(c)
|
||||
if requestID == "" {
|
||||
t.Error("Expected request ID to be set")
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"request_id": requestID})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Check response header
|
||||
requestID := w.Header().Get(RequestIDHeader)
|
||||
if requestID == "" {
|
||||
t.Error("Expected X-Request-ID header in response")
|
||||
}
|
||||
|
||||
// Check correlation ID header
|
||||
correlationID := w.Header().Get(CorrelationIDHeader)
|
||||
if correlationID == "" {
|
||||
t.Error("Expected X-Correlation-ID header in response")
|
||||
}
|
||||
|
||||
if requestID != correlationID {
|
||||
t.Error("X-Request-ID and X-Correlation-ID should match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestID_PropagatesExistingID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(RequestID())
|
||||
|
||||
customID := "custom-request-id-12345"
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
requestID := GetRequestID(c)
|
||||
if requestID != customID {
|
||||
t.Errorf("Expected request ID %s, got %s", customID, requestID)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"request_id": requestID})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set(RequestIDHeader, customID)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
responseID := w.Header().Get(RequestIDHeader)
|
||||
if responseID != customID {
|
||||
t.Errorf("Expected response header %s, got %s", customID, responseID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestID_PropagatesCorrelationID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(RequestID())
|
||||
|
||||
correlationID := "correlation-id-67890"
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
requestID := GetRequestID(c)
|
||||
if requestID != correlationID {
|
||||
t.Errorf("Expected request ID %s, got %s", correlationID, requestID)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set(CorrelationIDHeader, correlationID)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Both headers should be set with the correlation ID
|
||||
if w.Header().Get(RequestIDHeader) != correlationID {
|
||||
t.Error("X-Request-ID should match X-Correlation-ID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRequestID_ReturnsEmptyWhenNotSet(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// No RequestID middleware
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
requestID := GetRequestID(c)
|
||||
if requestID != "" {
|
||||
t.Errorf("Expected empty request ID, got %s", requestID)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestIDFromContext_IsAliasForGetRequestID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(RequestID())
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
id1 := GetRequestID(c)
|
||||
id2 := RequestIDFromContext(c)
|
||||
if id1 != id2 {
|
||||
t.Errorf("GetRequestID and RequestIDFromContext should return same value")
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,167 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SecurityHeadersConfig holds configuration for security headers.
|
||||
type SecurityHeadersConfig struct {
|
||||
// X-Content-Type-Options
|
||||
ContentTypeOptions string
|
||||
|
||||
// X-Frame-Options
|
||||
FrameOptions string
|
||||
|
||||
// X-XSS-Protection (legacy but useful for older browsers)
|
||||
XSSProtection string
|
||||
|
||||
// Strict-Transport-Security
|
||||
HSTSEnabled bool
|
||||
HSTSMaxAge int
|
||||
HSTSIncludeSubdomains bool
|
||||
HSTSPreload bool
|
||||
|
||||
// Content-Security-Policy
|
||||
CSPEnabled bool
|
||||
CSPPolicy string
|
||||
|
||||
// Referrer-Policy
|
||||
ReferrerPolicy string
|
||||
|
||||
// Permissions-Policy
|
||||
PermissionsPolicy string
|
||||
|
||||
// Cross-Origin headers
|
||||
CrossOriginOpenerPolicy string
|
||||
CrossOriginResourcePolicy string
|
||||
|
||||
// Development mode (relaxes some restrictions)
|
||||
DevelopmentMode bool
|
||||
|
||||
// Excluded paths (e.g., health checks)
|
||||
ExcludedPaths []string
|
||||
}
|
||||
|
||||
// DefaultSecurityHeadersConfig returns sensible default configuration.
|
||||
func DefaultSecurityHeadersConfig() SecurityHeadersConfig {
|
||||
env := os.Getenv("ENVIRONMENT")
|
||||
isDev := env == "" || strings.ToLower(env) == "development" || strings.ToLower(env) == "dev"
|
||||
|
||||
return SecurityHeadersConfig{
|
||||
ContentTypeOptions: "nosniff",
|
||||
FrameOptions: "DENY",
|
||||
XSSProtection: "1; mode=block",
|
||||
HSTSEnabled: true,
|
||||
HSTSMaxAge: 31536000, // 1 year
|
||||
HSTSIncludeSubdomains: true,
|
||||
HSTSPreload: false,
|
||||
CSPEnabled: true,
|
||||
CSPPolicy: getDefaultCSP(isDev),
|
||||
ReferrerPolicy: "strict-origin-when-cross-origin",
|
||||
PermissionsPolicy: "geolocation=(), microphone=(), camera=()",
|
||||
CrossOriginOpenerPolicy: "same-origin",
|
||||
CrossOriginResourcePolicy: "same-origin",
|
||||
DevelopmentMode: isDev,
|
||||
ExcludedPaths: []string{"/health", "/metrics", "/api/v1/health"},
|
||||
}
|
||||
}
|
||||
|
||||
// getDefaultCSP returns a sensible default CSP for the environment.
|
||||
func getDefaultCSP(isDevelopment bool) string {
|
||||
if isDevelopment {
|
||||
return "default-src 'self' localhost:* ws://localhost:*; " +
|
||||
"script-src 'self' 'unsafe-inline' 'unsafe-eval'; " +
|
||||
"style-src 'self' 'unsafe-inline'; " +
|
||||
"img-src 'self' data: https: blob:; " +
|
||||
"font-src 'self' data:; " +
|
||||
"connect-src 'self' localhost:* ws://localhost:* https:; " +
|
||||
"frame-ancestors 'self'"
|
||||
}
|
||||
return "default-src 'self'; " +
|
||||
"script-src 'self' 'unsafe-inline'; " +
|
||||
"style-src 'self' 'unsafe-inline'; " +
|
||||
"img-src 'self' data: https:; " +
|
||||
"font-src 'self' data:; " +
|
||||
"connect-src 'self' https://breakpilot.app https://*.breakpilot.app; " +
|
||||
"frame-ancestors 'none'"
|
||||
}
|
||||
|
||||
// buildHSTSHeader builds the Strict-Transport-Security header value.
|
||||
func (c *SecurityHeadersConfig) buildHSTSHeader() string {
|
||||
parts := []string{"max-age=" + strconv.Itoa(c.HSTSMaxAge)}
|
||||
if c.HSTSIncludeSubdomains {
|
||||
parts = append(parts, "includeSubDomains")
|
||||
}
|
||||
if c.HSTSPreload {
|
||||
parts = append(parts, "preload")
|
||||
}
|
||||
return strings.Join(parts, "; ")
|
||||
}
|
||||
|
||||
// isExcludedPath checks if the path should be excluded from security headers.
|
||||
func (c *SecurityHeadersConfig) isExcludedPath(path string) bool {
|
||||
for _, excluded := range c.ExcludedPaths {
|
||||
if path == excluded {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SecurityHeaders returns a middleware that adds security headers to all responses.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// r.Use(middleware.SecurityHeaders())
|
||||
//
|
||||
// // Or with custom config:
|
||||
// config := middleware.DefaultSecurityHeadersConfig()
|
||||
// config.CSPPolicy = "default-src 'self'"
|
||||
// r.Use(middleware.SecurityHeadersWithConfig(config))
|
||||
func SecurityHeaders() gin.HandlerFunc {
|
||||
return SecurityHeadersWithConfig(DefaultSecurityHeadersConfig())
|
||||
}
|
||||
|
||||
// SecurityHeadersWithConfig returns a security headers middleware with custom configuration.
|
||||
func SecurityHeadersWithConfig(config SecurityHeadersConfig) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Skip for excluded paths
|
||||
if config.isExcludedPath(c.Request.URL.Path) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Always add these headers
|
||||
c.Header("X-Content-Type-Options", config.ContentTypeOptions)
|
||||
c.Header("X-Frame-Options", config.FrameOptions)
|
||||
c.Header("X-XSS-Protection", config.XSSProtection)
|
||||
c.Header("Referrer-Policy", config.ReferrerPolicy)
|
||||
|
||||
// HSTS (only in production or if explicitly enabled)
|
||||
if config.HSTSEnabled && !config.DevelopmentMode {
|
||||
c.Header("Strict-Transport-Security", config.buildHSTSHeader())
|
||||
}
|
||||
|
||||
// Content-Security-Policy
|
||||
if config.CSPEnabled && config.CSPPolicy != "" {
|
||||
c.Header("Content-Security-Policy", config.CSPPolicy)
|
||||
}
|
||||
|
||||
// Permissions-Policy
|
||||
if config.PermissionsPolicy != "" {
|
||||
c.Header("Permissions-Policy", config.PermissionsPolicy)
|
||||
}
|
||||
|
||||
// Cross-Origin headers (only in production)
|
||||
if !config.DevelopmentMode {
|
||||
c.Header("Cross-Origin-Opener-Policy", config.CrossOriginOpenerPolicy)
|
||||
c.Header("Cross-Origin-Resource-Policy", config.CrossOriginResourcePolicy)
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,377 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestSecurityHeaders_AddsBasicHeaders(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultSecurityHeadersConfig()
|
||||
config.DevelopmentMode = true // Skip HSTS and cross-origin headers
|
||||
router.Use(SecurityHeadersWithConfig(config))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Check basic security headers
|
||||
tests := []struct {
|
||||
header string
|
||||
expected string
|
||||
}{
|
||||
{"X-Content-Type-Options", "nosniff"},
|
||||
{"X-Frame-Options", "DENY"},
|
||||
{"X-XSS-Protection", "1; mode=block"},
|
||||
{"Referrer-Policy", "strict-origin-when-cross-origin"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
value := w.Header().Get(tt.header)
|
||||
if value != tt.expected {
|
||||
t.Errorf("Header %s: expected %q, got %q", tt.header, tt.expected, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_HSTSNotAddedInDevelopment(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultSecurityHeadersConfig()
|
||||
config.DevelopmentMode = true
|
||||
config.HSTSEnabled = true
|
||||
router.Use(SecurityHeadersWithConfig(config))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
hstsHeader := w.Header().Get("Strict-Transport-Security")
|
||||
if hstsHeader != "" {
|
||||
t.Errorf("HSTS should not be set in development mode, got: %s", hstsHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_HSTSAddedInProduction(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultSecurityHeadersConfig()
|
||||
config.DevelopmentMode = false
|
||||
config.HSTSEnabled = true
|
||||
config.HSTSMaxAge = 31536000
|
||||
config.HSTSIncludeSubdomains = true
|
||||
router.Use(SecurityHeadersWithConfig(config))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
hstsHeader := w.Header().Get("Strict-Transport-Security")
|
||||
if hstsHeader == "" {
|
||||
t.Error("HSTS should be set in production mode")
|
||||
}
|
||||
|
||||
// Check that it contains max-age
|
||||
if hstsHeader != "max-age=31536000; includeSubDomains" {
|
||||
t.Errorf("Unexpected HSTS value: %s", hstsHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_HSTSWithPreload(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultSecurityHeadersConfig()
|
||||
config.DevelopmentMode = false
|
||||
config.HSTSEnabled = true
|
||||
config.HSTSMaxAge = 31536000
|
||||
config.HSTSIncludeSubdomains = true
|
||||
config.HSTSPreload = true
|
||||
router.Use(SecurityHeadersWithConfig(config))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
hstsHeader := w.Header().Get("Strict-Transport-Security")
|
||||
expected := "max-age=31536000; includeSubDomains; preload"
|
||||
if hstsHeader != expected {
|
||||
t.Errorf("Expected HSTS %q, got %q", expected, hstsHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_CSPHeader(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultSecurityHeadersConfig()
|
||||
config.CSPEnabled = true
|
||||
config.CSPPolicy = "default-src 'self'"
|
||||
router.Use(SecurityHeadersWithConfig(config))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
cspHeader := w.Header().Get("Content-Security-Policy")
|
||||
if cspHeader != "default-src 'self'" {
|
||||
t.Errorf("Expected CSP %q, got %q", "default-src 'self'", cspHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_NoCSPWhenDisabled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultSecurityHeadersConfig()
|
||||
config.CSPEnabled = false
|
||||
router.Use(SecurityHeadersWithConfig(config))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
cspHeader := w.Header().Get("Content-Security-Policy")
|
||||
if cspHeader != "" {
|
||||
t.Errorf("CSP should not be set when disabled, got: %s", cspHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_ExcludedPaths(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultSecurityHeadersConfig()
|
||||
config.ExcludedPaths = []string{"/health", "/metrics"}
|
||||
router.Use(SecurityHeadersWithConfig(config))
|
||||
|
||||
router.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "healthy"})
|
||||
})
|
||||
|
||||
router.GET("/api", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
// Test excluded path
|
||||
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Header().Get("X-Content-Type-Options") != "" {
|
||||
t.Error("Security headers should not be set for excluded paths")
|
||||
}
|
||||
|
||||
// Test non-excluded path
|
||||
req = httptest.NewRequest(http.MethodGet, "/api", nil)
|
||||
w = httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Header().Get("X-Content-Type-Options") != "nosniff" {
|
||||
t.Error("Security headers should be set for non-excluded paths")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_CrossOriginInProduction(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultSecurityHeadersConfig()
|
||||
config.DevelopmentMode = false
|
||||
config.CrossOriginOpenerPolicy = "same-origin"
|
||||
config.CrossOriginResourcePolicy = "same-origin"
|
||||
router.Use(SecurityHeadersWithConfig(config))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
coopHeader := w.Header().Get("Cross-Origin-Opener-Policy")
|
||||
if coopHeader != "same-origin" {
|
||||
t.Errorf("Expected COOP %q, got %q", "same-origin", coopHeader)
|
||||
}
|
||||
|
||||
corpHeader := w.Header().Get("Cross-Origin-Resource-Policy")
|
||||
if corpHeader != "same-origin" {
|
||||
t.Errorf("Expected CORP %q, got %q", "same-origin", corpHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_NoCrossOriginInDevelopment(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultSecurityHeadersConfig()
|
||||
config.DevelopmentMode = true
|
||||
config.CrossOriginOpenerPolicy = "same-origin"
|
||||
config.CrossOriginResourcePolicy = "same-origin"
|
||||
router.Use(SecurityHeadersWithConfig(config))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Header().Get("Cross-Origin-Opener-Policy") != "" {
|
||||
t.Error("COOP should not be set in development mode")
|
||||
}
|
||||
|
||||
if w.Header().Get("Cross-Origin-Resource-Policy") != "" {
|
||||
t.Error("CORP should not be set in development mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_PermissionsPolicy(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
config := DefaultSecurityHeadersConfig()
|
||||
config.PermissionsPolicy = "geolocation=(), microphone=()"
|
||||
router.Use(SecurityHeadersWithConfig(config))
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
ppHeader := w.Header().Get("Permissions-Policy")
|
||||
if ppHeader != "geolocation=(), microphone=()" {
|
||||
t.Errorf("Expected Permissions-Policy %q, got %q", "geolocation=(), microphone=()", ppHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_DefaultMiddleware(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// Use the default middleware function
|
||||
router.Use(SecurityHeaders())
|
||||
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// Should at least have the basic headers
|
||||
if w.Header().Get("X-Content-Type-Options") != "nosniff" {
|
||||
t.Error("Default middleware should set X-Content-Type-Options")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildHSTSHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config SecurityHeadersConfig
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic HSTS",
|
||||
config: SecurityHeadersConfig{
|
||||
HSTSMaxAge: 31536000,
|
||||
HSTSIncludeSubdomains: false,
|
||||
HSTSPreload: false,
|
||||
},
|
||||
expected: "max-age=31536000",
|
||||
},
|
||||
{
|
||||
name: "HSTS with subdomains",
|
||||
config: SecurityHeadersConfig{
|
||||
HSTSMaxAge: 31536000,
|
||||
HSTSIncludeSubdomains: true,
|
||||
HSTSPreload: false,
|
||||
},
|
||||
expected: "max-age=31536000; includeSubDomains",
|
||||
},
|
||||
{
|
||||
name: "HSTS with preload",
|
||||
config: SecurityHeadersConfig{
|
||||
HSTSMaxAge: 31536000,
|
||||
HSTSIncludeSubdomains: true,
|
||||
HSTSPreload: true,
|
||||
},
|
||||
expected: "max-age=31536000; includeSubDomains; preload",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.config.buildHSTSHeader()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %q, got %q", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsExcludedPath(t *testing.T) {
|
||||
config := SecurityHeadersConfig{
|
||||
ExcludedPaths: []string{"/health", "/metrics", "/api/v1/health"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
excluded bool
|
||||
}{
|
||||
{"/health", true},
|
||||
{"/metrics", true},
|
||||
{"/api/v1/health", true},
|
||||
{"/api", false},
|
||||
{"/health/check", false},
|
||||
{"/", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := config.isExcludedPath(tt.path)
|
||||
if result != tt.excluded {
|
||||
t.Errorf("Path %s: expected excluded=%v, got %v", tt.path, tt.excluded, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,505 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/database"
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
"github.com/breakpilot/consent-service/internal/services/matrix"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// AttendanceService handles attendance tracking and notifications
|
||||
type AttendanceService struct {
|
||||
db *database.DB
|
||||
matrix *matrix.MatrixService
|
||||
}
|
||||
|
||||
// NewAttendanceService creates a new attendance service
|
||||
func NewAttendanceService(db *database.DB, matrixService *matrix.MatrixService) *AttendanceService {
|
||||
return &AttendanceService{
|
||||
db: db,
|
||||
matrix: matrixService,
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Attendance Recording
|
||||
// ========================================
|
||||
|
||||
// RecordAttendance records a student's attendance for a specific lesson
|
||||
func (s *AttendanceService) RecordAttendance(ctx context.Context, req models.RecordAttendanceRequest, recordedByUserID uuid.UUID) (*models.AttendanceRecord, error) {
|
||||
studentID, err := uuid.Parse(req.StudentID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid student ID: %w", err)
|
||||
}
|
||||
|
||||
slotID, err := uuid.Parse(req.SlotID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid slot ID: %w", err)
|
||||
}
|
||||
|
||||
date, err := time.Parse("2006-01-02", req.Date)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid date format: %w", err)
|
||||
}
|
||||
|
||||
record := &models.AttendanceRecord{
|
||||
ID: uuid.New(),
|
||||
StudentID: studentID,
|
||||
Date: date,
|
||||
SlotID: slotID,
|
||||
Status: req.Status,
|
||||
RecordedBy: recordedByUserID,
|
||||
Note: req.Note,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO attendance_records (id, student_id, date, slot_id, status, recorded_by, note, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
ON CONFLICT (student_id, date, slot_id)
|
||||
DO UPDATE SET status = EXCLUDED.status, note = EXCLUDED.note, updated_at = EXCLUDED.updated_at
|
||||
RETURNING id`
|
||||
|
||||
err = s.db.Pool.QueryRow(ctx, query,
|
||||
record.ID, record.StudentID, record.Date, record.SlotID,
|
||||
record.Status, record.RecordedBy, record.Note, record.CreatedAt, record.UpdatedAt,
|
||||
).Scan(&record.ID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to record attendance: %w", err)
|
||||
}
|
||||
|
||||
// If student is absent, send notification to parents
|
||||
if record.Status == models.AttendanceAbsent || record.Status == models.AttendancePending {
|
||||
go s.notifyParentsOfAbsence(context.Background(), record)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// RecordBulkAttendance records attendance for multiple students at once
|
||||
func (s *AttendanceService) RecordBulkAttendance(ctx context.Context, classID uuid.UUID, date string, slotID uuid.UUID, records []struct {
|
||||
StudentID string
|
||||
Status string
|
||||
Note *string
|
||||
}, recordedByUserID uuid.UUID) error {
|
||||
parsedDate, err := time.Parse("2006-01-02", date)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid date format: %w", err)
|
||||
}
|
||||
|
||||
for _, rec := range records {
|
||||
studentID, err := uuid.Parse(rec.StudentID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO attendance_records (id, student_id, date, slot_id, status, recorded_by, note, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), NOW())
|
||||
ON CONFLICT (student_id, date, slot_id)
|
||||
DO UPDATE SET status = EXCLUDED.status, note = EXCLUDED.note, updated_at = NOW()`
|
||||
|
||||
_, err = s.db.Pool.Exec(ctx, query,
|
||||
uuid.New(), studentID, parsedDate, slotID, rec.Status, recordedByUserID, rec.Note,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to record attendance for student %s: %w", rec.StudentID, err)
|
||||
}
|
||||
|
||||
// Notify parents if absent
|
||||
if rec.Status == models.AttendanceAbsent || rec.Status == models.AttendancePending {
|
||||
go s.notifyParentsOfAbsenceByStudentID(context.Background(), studentID, parsedDate, slotID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAttendanceByClass gets attendance records for a class on a specific date
|
||||
func (s *AttendanceService) GetAttendanceByClass(ctx context.Context, classID uuid.UUID, date string) (*models.ClassAttendanceOverview, error) {
|
||||
parsedDate, err := time.Parse("2006-01-02", date)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid date format: %w", err)
|
||||
}
|
||||
|
||||
// Get class info
|
||||
classQuery := `SELECT id, school_id, school_year_id, name, grade, section, room, is_active FROM classes WHERE id = $1`
|
||||
class := &models.Class{}
|
||||
err = s.db.Pool.QueryRow(ctx, classQuery, classID).Scan(
|
||||
&class.ID, &class.SchoolID, &class.SchoolYearID, &class.Name,
|
||||
&class.Grade, &class.Section, &class.Room, &class.IsActive,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get class: %w", err)
|
||||
}
|
||||
|
||||
// Get total students
|
||||
var totalStudents int
|
||||
err = s.db.Pool.QueryRow(ctx, `SELECT COUNT(*) FROM students WHERE class_id = $1 AND is_active = true`, classID).Scan(&totalStudents)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to count students: %w", err)
|
||||
}
|
||||
|
||||
// Get attendance records for the date
|
||||
recordsQuery := `
|
||||
SELECT ar.id, ar.student_id, ar.date, ar.slot_id, ar.status, ar.recorded_by, ar.note, ar.created_at, ar.updated_at
|
||||
FROM attendance_records ar
|
||||
JOIN students s ON ar.student_id = s.id
|
||||
WHERE s.class_id = $1 AND ar.date = $2
|
||||
ORDER BY ar.slot_id`
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, recordsQuery, classID, parsedDate)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get attendance records: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var records []models.AttendanceRecord
|
||||
presentCount := 0
|
||||
absentCount := 0
|
||||
lateCount := 0
|
||||
|
||||
seenStudents := make(map[uuid.UUID]bool)
|
||||
|
||||
for rows.Next() {
|
||||
var record models.AttendanceRecord
|
||||
err := rows.Scan(
|
||||
&record.ID, &record.StudentID, &record.Date, &record.SlotID,
|
||||
&record.Status, &record.RecordedBy, &record.Note, &record.CreatedAt, &record.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan attendance record: %w", err)
|
||||
}
|
||||
records = append(records, record)
|
||||
|
||||
// Count unique students for summary (use first slot's status)
|
||||
if !seenStudents[record.StudentID] {
|
||||
seenStudents[record.StudentID] = true
|
||||
switch record.Status {
|
||||
case models.AttendancePresent:
|
||||
presentCount++
|
||||
case models.AttendanceAbsent, models.AttendanceAbsentExcused, models.AttendanceAbsentUnexcused, models.AttendancePending:
|
||||
absentCount++
|
||||
case models.AttendanceLate, models.AttendanceLateExcused:
|
||||
lateCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &models.ClassAttendanceOverview{
|
||||
Class: *class,
|
||||
Date: parsedDate,
|
||||
TotalStudents: totalStudents,
|
||||
PresentCount: presentCount,
|
||||
AbsentCount: absentCount,
|
||||
LateCount: lateCount,
|
||||
Records: records,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetStudentAttendance gets attendance history for a student
|
||||
func (s *AttendanceService) GetStudentAttendance(ctx context.Context, studentID uuid.UUID, startDate, endDate time.Time) ([]models.AttendanceRecord, error) {
|
||||
query := `
|
||||
SELECT id, student_id, timetable_entry_id, date, slot_id, status, recorded_by, note, created_at, updated_at
|
||||
FROM attendance_records
|
||||
WHERE student_id = $1 AND date >= $2 AND date <= $3
|
||||
ORDER BY date DESC, slot_id`
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, query, studentID, startDate, endDate)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get student attendance: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var records []models.AttendanceRecord
|
||||
for rows.Next() {
|
||||
var record models.AttendanceRecord
|
||||
err := rows.Scan(
|
||||
&record.ID, &record.StudentID, &record.TimetableEntryID, &record.Date,
|
||||
&record.SlotID, &record.Status, &record.RecordedBy, &record.Note,
|
||||
&record.CreatedAt, &record.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan attendance record: %w", err)
|
||||
}
|
||||
records = append(records, record)
|
||||
}
|
||||
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Absence Reports (Parent-initiated)
|
||||
// ========================================
|
||||
|
||||
// ReportAbsence allows parents to report a student's absence
|
||||
func (s *AttendanceService) ReportAbsence(ctx context.Context, req models.ReportAbsenceRequest, reportedByUserID uuid.UUID) (*models.AbsenceReport, error) {
|
||||
studentID, err := uuid.Parse(req.StudentID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid student ID: %w", err)
|
||||
}
|
||||
|
||||
startDate, err := time.Parse("2006-01-02", req.StartDate)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid start date format: %w", err)
|
||||
}
|
||||
|
||||
endDate, err := time.Parse("2006-01-02", req.EndDate)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid end date format: %w", err)
|
||||
}
|
||||
|
||||
report := &models.AbsenceReport{
|
||||
ID: uuid.New(),
|
||||
StudentID: studentID,
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
Reason: req.Reason,
|
||||
ReasonCategory: req.ReasonCategory,
|
||||
Status: "reported",
|
||||
ReportedBy: reportedByUserID,
|
||||
ReportedAt: time.Now(),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO absence_reports (id, student_id, start_date, end_date, reason, reason_category, status, reported_by, reported_at, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
RETURNING id`
|
||||
|
||||
err = s.db.Pool.QueryRow(ctx, query,
|
||||
report.ID, report.StudentID, report.StartDate, report.EndDate,
|
||||
report.Reason, report.ReasonCategory, report.Status,
|
||||
report.ReportedBy, report.ReportedAt, report.CreatedAt, report.UpdatedAt,
|
||||
).Scan(&report.ID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create absence report: %w", err)
|
||||
}
|
||||
|
||||
return report, nil
|
||||
}
|
||||
|
||||
// ConfirmAbsence allows teachers to confirm/excuse an absence
|
||||
func (s *AttendanceService) ConfirmAbsence(ctx context.Context, reportID uuid.UUID, confirmedByUserID uuid.UUID, status string) error {
|
||||
query := `
|
||||
UPDATE absence_reports
|
||||
SET status = $1, confirmed_by = $2, confirmed_at = NOW(), updated_at = NOW()
|
||||
WHERE id = $3`
|
||||
|
||||
result, err := s.db.Pool.Exec(ctx, query, status, confirmedByUserID, reportID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to confirm absence: %w", err)
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
return fmt.Errorf("absence report not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAbsenceReports gets absence reports for a student
|
||||
func (s *AttendanceService) GetAbsenceReports(ctx context.Context, studentID uuid.UUID) ([]models.AbsenceReport, error) {
|
||||
query := `
|
||||
SELECT id, student_id, start_date, end_date, reason, reason_category, status, reported_by, reported_at, confirmed_by, confirmed_at, medical_certificate, certificate_uploaded, matrix_notification_sent, email_notification_sent, created_at, updated_at
|
||||
FROM absence_reports
|
||||
WHERE student_id = $1
|
||||
ORDER BY start_date DESC`
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, query, studentID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get absence reports: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var reports []models.AbsenceReport
|
||||
for rows.Next() {
|
||||
var report models.AbsenceReport
|
||||
err := rows.Scan(
|
||||
&report.ID, &report.StudentID, &report.StartDate, &report.EndDate,
|
||||
&report.Reason, &report.ReasonCategory, &report.Status,
|
||||
&report.ReportedBy, &report.ReportedAt, &report.ConfirmedBy, &report.ConfirmedAt,
|
||||
&report.MedicalCertificate, &report.CertificateUploaded,
|
||||
&report.MatrixNotificationSent, &report.EmailNotificationSent,
|
||||
&report.CreatedAt, &report.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan absence report: %w", err)
|
||||
}
|
||||
reports = append(reports, report)
|
||||
}
|
||||
|
||||
return reports, nil
|
||||
}
|
||||
|
||||
// GetPendingAbsenceReports gets all unconfirmed absence reports for a class
|
||||
func (s *AttendanceService) GetPendingAbsenceReports(ctx context.Context, classID uuid.UUID) ([]models.AbsenceReport, error) {
|
||||
query := `
|
||||
SELECT ar.id, ar.student_id, ar.start_date, ar.end_date, ar.reason, ar.reason_category, ar.status, ar.reported_by, ar.reported_at, ar.confirmed_by, ar.confirmed_at, ar.medical_certificate, ar.certificate_uploaded, ar.matrix_notification_sent, ar.email_notification_sent, ar.created_at, ar.updated_at
|
||||
FROM absence_reports ar
|
||||
JOIN students s ON ar.student_id = s.id
|
||||
WHERE s.class_id = $1 AND ar.status = 'reported'
|
||||
ORDER BY ar.start_date DESC`
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, query, classID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get pending absence reports: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var reports []models.AbsenceReport
|
||||
for rows.Next() {
|
||||
var report models.AbsenceReport
|
||||
err := rows.Scan(
|
||||
&report.ID, &report.StudentID, &report.StartDate, &report.EndDate,
|
||||
&report.Reason, &report.ReasonCategory, &report.Status,
|
||||
&report.ReportedBy, &report.ReportedAt, &report.ConfirmedBy, &report.ConfirmedAt,
|
||||
&report.MedicalCertificate, &report.CertificateUploaded,
|
||||
&report.MatrixNotificationSent, &report.EmailNotificationSent,
|
||||
&report.CreatedAt, &report.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan absence report: %w", err)
|
||||
}
|
||||
reports = append(reports, report)
|
||||
}
|
||||
|
||||
return reports, nil
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Attendance Statistics
|
||||
// ========================================
|
||||
|
||||
// GetStudentAttendanceStats gets attendance statistics for a student
|
||||
func (s *AttendanceService) GetStudentAttendanceStats(ctx context.Context, studentID uuid.UUID, schoolYearID uuid.UUID) (map[string]interface{}, error) {
|
||||
query := `
|
||||
SELECT
|
||||
COUNT(*) as total_records,
|
||||
COUNT(CASE WHEN status = 'present' THEN 1 END) as present_count,
|
||||
COUNT(CASE WHEN status IN ('absent', 'excused', 'unexcused', 'pending_confirmation') THEN 1 END) as absent_count,
|
||||
COUNT(CASE WHEN status = 'unexcused' THEN 1 END) as unexcused_count,
|
||||
COUNT(CASE WHEN status IN ('late', 'late_excused') THEN 1 END) as late_count
|
||||
FROM attendance_records ar
|
||||
JOIN timetable_slots ts ON ar.slot_id = ts.id
|
||||
JOIN schools sch ON ts.school_id = sch.id
|
||||
JOIN school_years sy ON sy.school_id = sch.id AND sy.id = $2
|
||||
WHERE ar.student_id = $1 AND ar.date >= sy.start_date AND ar.date <= sy.end_date`
|
||||
|
||||
var totalRecords, presentCount, absentCount, unexcusedCount, lateCount int
|
||||
err := s.db.Pool.QueryRow(ctx, query, studentID, schoolYearID).Scan(
|
||||
&totalRecords, &presentCount, &absentCount, &unexcusedCount, &lateCount,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get attendance stats: %w", err)
|
||||
}
|
||||
|
||||
var attendanceRate float64
|
||||
if totalRecords > 0 {
|
||||
attendanceRate = float64(presentCount) / float64(totalRecords) * 100
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_records": totalRecords,
|
||||
"present_count": presentCount,
|
||||
"absent_count": absentCount,
|
||||
"unexcused_count": unexcusedCount,
|
||||
"late_count": lateCount,
|
||||
"attendance_rate": attendanceRate,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ========================================
|
||||
// Parent Notifications
|
||||
// ========================================
|
||||
|
||||
func (s *AttendanceService) notifyParentsOfAbsence(ctx context.Context, record *models.AttendanceRecord) {
|
||||
if s.matrix == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Get student info
|
||||
var studentFirstName, studentLastName, matrixDMRoom string
|
||||
err := s.db.Pool.QueryRow(ctx, `
|
||||
SELECT first_name, last_name, matrix_dm_room
|
||||
FROM students
|
||||
WHERE id = $1`, record.StudentID).Scan(&studentFirstName, &studentLastName, &matrixDMRoom)
|
||||
if err != nil || matrixDMRoom == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Get slot info
|
||||
var slotNumber int
|
||||
err = s.db.Pool.QueryRow(ctx, `SELECT slot_number FROM timetable_slots WHERE id = $1`, record.SlotID).Scan(&slotNumber)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
studentName := studentFirstName + " " + studentLastName
|
||||
dateStr := record.Date.Format("02.01.2006")
|
||||
|
||||
// Send Matrix notification
|
||||
err = s.matrix.SendAbsenceNotification(ctx, matrixDMRoom, studentName, dateStr, slotNumber)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to send absence notification: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update notification status
|
||||
s.db.Pool.Exec(ctx, `
|
||||
UPDATE attendance_records
|
||||
SET updated_at = NOW()
|
||||
WHERE id = $1`, record.ID)
|
||||
|
||||
// Log the notification
|
||||
s.createAbsenceNotificationLog(ctx, record.ID, studentName, dateStr, slotNumber)
|
||||
}
|
||||
|
||||
func (s *AttendanceService) notifyParentsOfAbsenceByStudentID(ctx context.Context, studentID uuid.UUID, date time.Time, slotID uuid.UUID) {
|
||||
record := &models.AttendanceRecord{
|
||||
StudentID: studentID,
|
||||
Date: date,
|
||||
SlotID: slotID,
|
||||
}
|
||||
s.notifyParentsOfAbsence(ctx, record)
|
||||
}
|
||||
|
||||
func (s *AttendanceService) createAbsenceNotificationLog(ctx context.Context, recordID uuid.UUID, studentName, dateStr string, slotNumber int) {
|
||||
// Get parent IDs for this student
|
||||
query := `
|
||||
SELECT p.id
|
||||
FROM parents p
|
||||
JOIN student_parents sp ON p.id = sp.parent_id
|
||||
JOIN attendance_records ar ON sp.student_id = ar.student_id
|
||||
WHERE ar.id = $1`
|
||||
|
||||
rows, err := s.db.Pool.Query(ctx, query, recordID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
message := fmt.Sprintf("Abwesenheitsmeldung: %s war am %s in der %d. Stunde nicht anwesend.", studentName, dateStr, slotNumber)
|
||||
|
||||
for rows.Next() {
|
||||
var parentID uuid.UUID
|
||||
if err := rows.Scan(&parentID); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Insert notification log
|
||||
s.db.Pool.Exec(ctx, `
|
||||
INSERT INTO absence_notifications (id, attendance_record_id, parent_id, channel, message_content, sent_at, created_at)
|
||||
VALUES ($1, $2, $3, 'matrix', $4, NOW(), NOW())`,
|
||||
uuid.New(), recordID, parentID, message)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,388 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestValidateAttendanceRecord tests attendance record validation
|
||||
func TestValidateAttendanceRecord(t *testing.T) {
|
||||
slotID := uuid.New()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
record models.AttendanceRecord
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "valid present record",
|
||||
record: models.AttendanceRecord{
|
||||
StudentID: uuid.New(),
|
||||
SlotID: slotID,
|
||||
Date: time.Now(),
|
||||
Status: models.AttendancePresent,
|
||||
RecordedBy: uuid.New(),
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "valid absent record",
|
||||
record: models.AttendanceRecord{
|
||||
StudentID: uuid.New(),
|
||||
SlotID: slotID,
|
||||
Date: time.Now(),
|
||||
Status: models.AttendanceAbsent,
|
||||
RecordedBy: uuid.New(),
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "valid late record",
|
||||
record: models.AttendanceRecord{
|
||||
StudentID: uuid.New(),
|
||||
SlotID: slotID,
|
||||
Date: time.Now(),
|
||||
Status: models.AttendanceLate,
|
||||
RecordedBy: uuid.New(),
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "missing student ID",
|
||||
record: models.AttendanceRecord{
|
||||
StudentID: uuid.Nil,
|
||||
SlotID: slotID,
|
||||
Date: time.Now(),
|
||||
Status: models.AttendancePresent,
|
||||
RecordedBy: uuid.New(),
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "invalid status",
|
||||
record: models.AttendanceRecord{
|
||||
StudentID: uuid.New(),
|
||||
SlotID: slotID,
|
||||
Date: time.Now(),
|
||||
Status: "invalid_status",
|
||||
RecordedBy: uuid.New(),
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "future date",
|
||||
record: models.AttendanceRecord{
|
||||
StudentID: uuid.New(),
|
||||
SlotID: slotID,
|
||||
Date: time.Now().AddDate(0, 0, 7),
|
||||
Status: models.AttendancePresent,
|
||||
RecordedBy: uuid.New(),
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "missing slot ID",
|
||||
record: models.AttendanceRecord{
|
||||
StudentID: uuid.New(),
|
||||
SlotID: uuid.Nil,
|
||||
Date: time.Now(),
|
||||
Status: models.AttendancePresent,
|
||||
RecordedBy: uuid.New(),
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := validateAttendanceRecord(tt.record)
|
||||
if isValid != tt.expectValid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.expectValid, isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// validateAttendanceRecord validates an attendance record
|
||||
func validateAttendanceRecord(record models.AttendanceRecord) bool {
|
||||
if record.StudentID == uuid.Nil {
|
||||
return false
|
||||
}
|
||||
if record.SlotID == uuid.Nil {
|
||||
return false
|
||||
}
|
||||
if record.RecordedBy == uuid.Nil {
|
||||
return false
|
||||
}
|
||||
if record.Date.After(time.Now().AddDate(0, 0, 1)) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Validate status
|
||||
validStatuses := map[string]bool{
|
||||
models.AttendancePresent: true,
|
||||
models.AttendanceAbsent: true,
|
||||
models.AttendanceAbsentExcused: true,
|
||||
models.AttendanceAbsentUnexcused: true,
|
||||
models.AttendanceLate: true,
|
||||
models.AttendanceLateExcused: true,
|
||||
models.AttendancePending: true,
|
||||
}
|
||||
|
||||
if !validStatuses[record.Status] {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// TestValidateAbsenceReport tests absence report validation
|
||||
func TestValidateAbsenceReport(t *testing.T) {
|
||||
now := time.Now()
|
||||
today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC)
|
||||
reason := "Krankheit"
|
||||
medicalReason := "Arzttermin"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
report models.AbsenceReport
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "valid single day absence",
|
||||
report: models.AbsenceReport{
|
||||
StudentID: uuid.New(),
|
||||
ReportedBy: uuid.New(),
|
||||
StartDate: today,
|
||||
EndDate: today,
|
||||
Reason: &reason,
|
||||
ReasonCategory: "illness",
|
||||
Status: "reported",
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "valid multi-day absence",
|
||||
report: models.AbsenceReport{
|
||||
StudentID: uuid.New(),
|
||||
ReportedBy: uuid.New(),
|
||||
StartDate: today,
|
||||
EndDate: today.AddDate(0, 0, 3),
|
||||
Reason: &medicalReason,
|
||||
ReasonCategory: "appointment",
|
||||
Status: "reported",
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "end before start",
|
||||
report: models.AbsenceReport{
|
||||
StudentID: uuid.New(),
|
||||
ReportedBy: uuid.New(),
|
||||
StartDate: today.AddDate(0, 0, 3),
|
||||
EndDate: today,
|
||||
Reason: &reason,
|
||||
ReasonCategory: "illness",
|
||||
Status: "reported",
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "missing reason category",
|
||||
report: models.AbsenceReport{
|
||||
StudentID: uuid.New(),
|
||||
ReportedBy: uuid.New(),
|
||||
StartDate: today,
|
||||
EndDate: today,
|
||||
Reason: &reason,
|
||||
ReasonCategory: "",
|
||||
Status: "reported",
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "invalid reason category",
|
||||
report: models.AbsenceReport{
|
||||
StudentID: uuid.New(),
|
||||
ReportedBy: uuid.New(),
|
||||
StartDate: today,
|
||||
EndDate: today,
|
||||
Reason: &reason,
|
||||
ReasonCategory: "invalid_type",
|
||||
Status: "reported",
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := validateAbsenceReport(tt.report)
|
||||
if isValid != tt.expectValid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.expectValid, isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// validateAbsenceReport validates an absence report
|
||||
func validateAbsenceReport(report models.AbsenceReport) bool {
|
||||
if report.StudentID == uuid.Nil {
|
||||
return false
|
||||
}
|
||||
if report.ReportedBy == uuid.Nil {
|
||||
return false
|
||||
}
|
||||
if report.EndDate.Before(report.StartDate) {
|
||||
return false
|
||||
}
|
||||
if report.ReasonCategory == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Validate reason category
|
||||
validCategories := map[string]bool{
|
||||
"illness": true,
|
||||
"appointment": true,
|
||||
"family": true,
|
||||
"other": true,
|
||||
}
|
||||
|
||||
if !validCategories[report.ReasonCategory] {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// TestCalculateAttendanceStats tests attendance statistics calculation
|
||||
func TestCalculateAttendanceStats(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
records []models.AttendanceRecord
|
||||
expectedPresent int
|
||||
expectedAbsent int
|
||||
expectedLate int
|
||||
}{
|
||||
{
|
||||
name: "all present",
|
||||
records: []models.AttendanceRecord{
|
||||
{Status: models.AttendancePresent},
|
||||
{Status: models.AttendancePresent},
|
||||
{Status: models.AttendancePresent},
|
||||
},
|
||||
expectedPresent: 3,
|
||||
expectedAbsent: 0,
|
||||
expectedLate: 0,
|
||||
},
|
||||
{
|
||||
name: "mixed attendance",
|
||||
records: []models.AttendanceRecord{
|
||||
{Status: models.AttendancePresent},
|
||||
{Status: models.AttendanceAbsent},
|
||||
{Status: models.AttendanceLate},
|
||||
{Status: models.AttendancePresent},
|
||||
{Status: models.AttendanceAbsentExcused},
|
||||
},
|
||||
expectedPresent: 2,
|
||||
expectedAbsent: 2,
|
||||
expectedLate: 1,
|
||||
},
|
||||
{
|
||||
name: "empty records",
|
||||
records: []models.AttendanceRecord{},
|
||||
expectedPresent: 0,
|
||||
expectedAbsent: 0,
|
||||
expectedLate: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
present, absent, late := calculateAttendanceStats(tt.records)
|
||||
if present != tt.expectedPresent {
|
||||
t.Errorf("expected present=%d, got present=%d", tt.expectedPresent, present)
|
||||
}
|
||||
if absent != tt.expectedAbsent {
|
||||
t.Errorf("expected absent=%d, got absent=%d", tt.expectedAbsent, absent)
|
||||
}
|
||||
if late != tt.expectedLate {
|
||||
t.Errorf("expected late=%d, got late=%d", tt.expectedLate, late)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// calculateAttendanceStats calculates attendance statistics
|
||||
func calculateAttendanceStats(records []models.AttendanceRecord) (present, absent, late int) {
|
||||
for _, r := range records {
|
||||
switch r.Status {
|
||||
case models.AttendancePresent:
|
||||
present++
|
||||
case models.AttendanceAbsent, models.AttendanceAbsentExcused, models.AttendanceAbsentUnexcused:
|
||||
absent++
|
||||
case models.AttendanceLate, models.AttendanceLateExcused:
|
||||
late++
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// TestAttendanceRateCalculation tests attendance rate percentage calculation
|
||||
func TestAttendanceRateCalculation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
present int
|
||||
total int
|
||||
expectedRate float64
|
||||
}{
|
||||
{
|
||||
name: "100% attendance",
|
||||
present: 26,
|
||||
total: 26,
|
||||
expectedRate: 100.0,
|
||||
},
|
||||
{
|
||||
name: "92.3% attendance",
|
||||
present: 24,
|
||||
total: 26,
|
||||
expectedRate: 92.31,
|
||||
},
|
||||
{
|
||||
name: "0% attendance",
|
||||
present: 0,
|
||||
total: 26,
|
||||
expectedRate: 0.0,
|
||||
},
|
||||
{
|
||||
name: "empty class",
|
||||
present: 0,
|
||||
total: 0,
|
||||
expectedRate: 0.0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rate := calculateAttendanceRate(tt.present, tt.total)
|
||||
// Allow small floating point differences
|
||||
if rate < tt.expectedRate-0.1 || rate > tt.expectedRate+0.1 {
|
||||
t.Errorf("expected rate=%.2f, got rate=%.2f", tt.expectedRate, rate)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// calculateAttendanceRate calculates attendance rate as percentage
|
||||
func calculateAttendanceRate(present, total int) float64 {
|
||||
if total == 0 {
|
||||
return 0.0
|
||||
}
|
||||
rate := float64(present) / float64(total) * 100
|
||||
// Round to 2 decimal places
|
||||
return float64(int(rate*100)) / 100
|
||||
}
|
||||
@@ -0,0 +1,568 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidCredentials = errors.New("invalid email or password")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrUserExists = errors.New("user with this email already exists")
|
||||
ErrInvalidToken = errors.New("invalid or expired token")
|
||||
ErrAccountLocked = errors.New("account is temporarily locked")
|
||||
ErrAccountSuspended = errors.New("account is suspended")
|
||||
ErrEmailNotVerified = errors.New("email not verified")
|
||||
)
|
||||
|
||||
// AuthService handles authentication logic
|
||||
type AuthService struct {
|
||||
db *pgxpool.Pool
|
||||
jwtSecret string
|
||||
jwtRefreshSecret string
|
||||
accessTokenExp time.Duration
|
||||
refreshTokenExp time.Duration
|
||||
}
|
||||
|
||||
// NewAuthService creates a new AuthService
|
||||
func NewAuthService(db *pgxpool.Pool, jwtSecret, jwtRefreshSecret string) *AuthService {
|
||||
return &AuthService{
|
||||
db: db,
|
||||
jwtSecret: jwtSecret,
|
||||
jwtRefreshSecret: jwtRefreshSecret,
|
||||
accessTokenExp: time.Hour * 1, // 1 hour
|
||||
refreshTokenExp: time.Hour * 24 * 30, // 30 days
|
||||
}
|
||||
}
|
||||
|
||||
// HashPassword hashes a password using bcrypt
|
||||
func (s *AuthService) HashPassword(password string) (string, error) {
|
||||
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
return string(bytes), nil
|
||||
}
|
||||
|
||||
// VerifyPassword verifies a password against a hash
|
||||
func (s *AuthService) VerifyPassword(password, hash string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// GenerateSecureToken generates a cryptographically secure token
|
||||
func (s *AuthService) GenerateSecureToken(length int) (string, error) {
|
||||
bytes := make([]byte, length)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate token: %w", err)
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// HashToken creates a SHA256 hash of a token for storage
|
||||
func (s *AuthService) HashToken(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// JWTClaims for access tokens
|
||||
type JWTClaims struct {
|
||||
UserID string `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
AccountStatus string `json:"account_status"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// GenerateAccessToken creates a new JWT access token
|
||||
func (s *AuthService) GenerateAccessToken(user *models.User) (string, error) {
|
||||
claims := JWTClaims{
|
||||
UserID: user.ID.String(),
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
AccountStatus: user.AccountStatus,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(s.accessTokenExp)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
Subject: user.ID.String(),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(s.jwtSecret))
|
||||
}
|
||||
|
||||
// GenerateRefreshToken creates a new refresh token
|
||||
func (s *AuthService) GenerateRefreshToken() (string, string, error) {
|
||||
token, err := s.GenerateSecureToken(32)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
hash := s.HashToken(token)
|
||||
return token, hash, nil
|
||||
}
|
||||
|
||||
// ValidateAccessToken validates a JWT access token
|
||||
func (s *AuthService) ValidateAccessToken(tokenString string) (*JWTClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(s.jwtSecret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*JWTClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// Register creates a new user account
|
||||
func (s *AuthService) Register(ctx context.Context, req *models.RegisterRequest) (*models.User, string, error) {
|
||||
// Check if user already exists
|
||||
var exists bool
|
||||
err := s.db.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)", req.Email).Scan(&exists)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to check existing user: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, "", ErrUserExists
|
||||
}
|
||||
|
||||
// Hash password
|
||||
passwordHash, err := s.HashPassword(req.Password)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// Create user
|
||||
user := &models.User{
|
||||
ID: uuid.New(),
|
||||
Email: req.Email,
|
||||
PasswordHash: &passwordHash,
|
||||
Name: req.Name,
|
||||
Role: "user",
|
||||
EmailVerified: false,
|
||||
AccountStatus: "active",
|
||||
}
|
||||
|
||||
_, err = s.db.Exec(ctx, `
|
||||
INSERT INTO users (id, email, password_hash, name, role, email_verified, account_status, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), NOW())
|
||||
`, user.ID, user.Email, user.PasswordHash, user.Name, user.Role, user.EmailVerified, user.AccountStatus)
|
||||
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
// Generate email verification token
|
||||
verificationToken, err := s.GenerateSecureToken(32)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// Store verification token
|
||||
_, err = s.db.Exec(ctx, `
|
||||
INSERT INTO email_verification_tokens (user_id, token, expires_at, created_at)
|
||||
VALUES ($1, $2, $3, NOW())
|
||||
`, user.ID, verificationToken, time.Now().Add(24*time.Hour))
|
||||
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to create verification token: %w", err)
|
||||
}
|
||||
|
||||
// Create notification preferences
|
||||
_, err = s.db.Exec(ctx, `
|
||||
INSERT INTO notification_preferences (user_id, email_enabled, push_enabled, in_app_enabled, reminder_frequency, created_at, updated_at)
|
||||
VALUES ($1, true, true, true, 'weekly', NOW(), NOW())
|
||||
`, user.ID)
|
||||
|
||||
if err != nil {
|
||||
// Non-critical error, just log
|
||||
fmt.Printf("Warning: failed to create notification preferences: %v\n", err)
|
||||
}
|
||||
|
||||
return user, verificationToken, nil
|
||||
}
|
||||
|
||||
// Login authenticates a user and returns tokens
|
||||
func (s *AuthService) Login(ctx context.Context, req *models.LoginRequest, ipAddress, userAgent string) (*models.LoginResponse, error) {
|
||||
var user models.User
|
||||
var passwordHash *string
|
||||
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id, email, password_hash, name, role, email_verified, account_status,
|
||||
failed_login_attempts, locked_until, created_at, updated_at
|
||||
FROM users WHERE email = $1
|
||||
`, req.Email).Scan(
|
||||
&user.ID, &user.Email, &passwordHash, &user.Name, &user.Role, &user.EmailVerified,
|
||||
&user.AccountStatus, &user.FailedLoginAttempts, &user.LockedUntil, &user.CreatedAt, &user.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
// Check if account is locked
|
||||
if user.LockedUntil != nil && user.LockedUntil.After(time.Now()) {
|
||||
return nil, ErrAccountLocked
|
||||
}
|
||||
|
||||
// Check if account is suspended
|
||||
if user.AccountStatus == "suspended" {
|
||||
return nil, ErrAccountSuspended
|
||||
}
|
||||
|
||||
// Verify password
|
||||
if passwordHash == nil || !s.VerifyPassword(req.Password, *passwordHash) {
|
||||
// Increment failed login attempts
|
||||
_, _ = s.db.Exec(ctx, `
|
||||
UPDATE users SET
|
||||
failed_login_attempts = failed_login_attempts + 1,
|
||||
locked_until = CASE WHEN failed_login_attempts >= 4 THEN NOW() + INTERVAL '30 minutes' ELSE locked_until END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`, user.ID)
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
// Reset failed login attempts and update last login
|
||||
_, _ = s.db.Exec(ctx, `
|
||||
UPDATE users SET
|
||||
failed_login_attempts = 0,
|
||||
locked_until = NULL,
|
||||
last_login_at = NOW(),
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`, user.ID)
|
||||
|
||||
// Generate tokens
|
||||
accessToken, err := s.GenerateAccessToken(&user)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate access token: %w", err)
|
||||
}
|
||||
|
||||
refreshToken, refreshTokenHash, err := s.GenerateRefreshToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
|
||||
}
|
||||
|
||||
// Store session
|
||||
_, err = s.db.Exec(ctx, `
|
||||
INSERT INTO user_sessions (user_id, token_hash, ip_address, user_agent, expires_at, created_at, last_activity_at)
|
||||
VALUES ($1, $2, $3, $4, $5, NOW(), NOW())
|
||||
`, user.ID, refreshTokenHash, ipAddress, userAgent, time.Now().Add(s.refreshTokenExp))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create session: %w", err)
|
||||
}
|
||||
|
||||
return &models.LoginResponse{
|
||||
User: user,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresIn: int(s.accessTokenExp.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshToken refreshes the access token using a refresh token
|
||||
func (s *AuthService) RefreshToken(ctx context.Context, refreshToken string) (*models.LoginResponse, error) {
|
||||
tokenHash := s.HashToken(refreshToken)
|
||||
|
||||
var session models.UserSession
|
||||
var userID uuid.UUID
|
||||
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id, user_id, expires_at, revoked_at FROM user_sessions
|
||||
WHERE token_hash = $1
|
||||
`, tokenHash).Scan(&session.ID, &userID, &session.ExpiresAt, &session.RevokedAt)
|
||||
|
||||
if err != nil {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
// Check if session is expired or revoked
|
||||
if session.RevokedAt != nil || session.ExpiresAt.Before(time.Now()) {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
// Get user
|
||||
var user models.User
|
||||
err = s.db.QueryRow(ctx, `
|
||||
SELECT id, email, name, role, email_verified, account_status, created_at, updated_at
|
||||
FROM users WHERE id = $1
|
||||
`, userID).Scan(
|
||||
&user.ID, &user.Email, &user.Name, &user.Role, &user.EmailVerified,
|
||||
&user.AccountStatus, &user.CreatedAt, &user.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
|
||||
// Check account status
|
||||
if user.AccountStatus == "suspended" {
|
||||
return nil, ErrAccountSuspended
|
||||
}
|
||||
|
||||
// Generate new access token
|
||||
accessToken, err := s.GenerateAccessToken(&user)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate access token: %w", err)
|
||||
}
|
||||
|
||||
// Update session last activity
|
||||
_, _ = s.db.Exec(ctx, `
|
||||
UPDATE user_sessions SET last_activity_at = NOW() WHERE id = $1
|
||||
`, session.ID)
|
||||
|
||||
return &models.LoginResponse{
|
||||
User: user,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresIn: int(s.accessTokenExp.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// VerifyEmail verifies a user's email address
|
||||
func (s *AuthService) VerifyEmail(ctx context.Context, token string) error {
|
||||
var tokenID uuid.UUID
|
||||
var userID uuid.UUID
|
||||
var expiresAt time.Time
|
||||
var usedAt *time.Time
|
||||
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id, user_id, expires_at, used_at FROM email_verification_tokens
|
||||
WHERE token = $1
|
||||
`, token).Scan(&tokenID, &userID, &expiresAt, &usedAt)
|
||||
|
||||
if err != nil {
|
||||
return ErrInvalidToken
|
||||
}
|
||||
|
||||
if usedAt != nil || expiresAt.Before(time.Now()) {
|
||||
return ErrInvalidToken
|
||||
}
|
||||
|
||||
// Mark token as used
|
||||
_, err = s.db.Exec(ctx, `UPDATE email_verification_tokens SET used_at = NOW() WHERE id = $1`, tokenID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update token: %w", err)
|
||||
}
|
||||
|
||||
// Verify user email
|
||||
_, err = s.db.Exec(ctx, `
|
||||
UPDATE users SET email_verified = true, email_verified_at = NOW(), updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`, userID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to verify email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreatePasswordResetToken creates a password reset token
|
||||
func (s *AuthService) CreatePasswordResetToken(ctx context.Context, email, ipAddress string) (string, *uuid.UUID, error) {
|
||||
var userID uuid.UUID
|
||||
err := s.db.QueryRow(ctx, "SELECT id FROM users WHERE email = $1", email).Scan(&userID)
|
||||
if err != nil {
|
||||
// Don't reveal if user exists
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
token, err := s.GenerateSecureToken(32)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
_, err = s.db.Exec(ctx, `
|
||||
INSERT INTO password_reset_tokens (user_id, token, expires_at, ip_address, created_at)
|
||||
VALUES ($1, $2, $3, $4, NOW())
|
||||
`, userID, token, time.Now().Add(time.Hour), ipAddress)
|
||||
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to create reset token: %w", err)
|
||||
}
|
||||
|
||||
return token, &userID, nil
|
||||
}
|
||||
|
||||
// ResetPassword resets a user's password using a reset token
|
||||
func (s *AuthService) ResetPassword(ctx context.Context, token, newPassword string) error {
|
||||
var tokenID uuid.UUID
|
||||
var userID uuid.UUID
|
||||
var expiresAt time.Time
|
||||
var usedAt *time.Time
|
||||
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id, user_id, expires_at, used_at FROM password_reset_tokens
|
||||
WHERE token = $1
|
||||
`, token).Scan(&tokenID, &userID, &expiresAt, &usedAt)
|
||||
|
||||
if err != nil {
|
||||
return ErrInvalidToken
|
||||
}
|
||||
|
||||
if usedAt != nil || expiresAt.Before(time.Now()) {
|
||||
return ErrInvalidToken
|
||||
}
|
||||
|
||||
// Hash new password
|
||||
passwordHash, err := s.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Mark token as used
|
||||
_, err = s.db.Exec(ctx, `UPDATE password_reset_tokens SET used_at = NOW() WHERE id = $1`, tokenID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update token: %w", err)
|
||||
}
|
||||
|
||||
// Update password
|
||||
_, err = s.db.Exec(ctx, `
|
||||
UPDATE users SET password_hash = $1, updated_at = NOW() WHERE id = $2
|
||||
`, passwordHash, userID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update password: %w", err)
|
||||
}
|
||||
|
||||
// Revoke all sessions for security
|
||||
_, err = s.db.Exec(ctx, `UPDATE user_sessions SET revoked_at = NOW() WHERE user_id = $1 AND revoked_at IS NULL`, userID)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: failed to revoke sessions: %v\n", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChangePassword changes a user's password (requires current password)
|
||||
func (s *AuthService) ChangePassword(ctx context.Context, userID uuid.UUID, currentPassword, newPassword string) error {
|
||||
var passwordHash *string
|
||||
err := s.db.QueryRow(ctx, "SELECT password_hash FROM users WHERE id = $1", userID).Scan(&passwordHash)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
|
||||
if passwordHash == nil || !s.VerifyPassword(currentPassword, *passwordHash) {
|
||||
return ErrInvalidCredentials
|
||||
}
|
||||
|
||||
newPasswordHash, err := s.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = s.db.Exec(ctx, `UPDATE users SET password_hash = $1, updated_at = NOW() WHERE id = $2`, newPasswordHash, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update password: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserByID retrieves a user by ID
|
||||
func (s *AuthService) GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error) {
|
||||
var user models.User
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id, email, name, role, email_verified, email_verified_at, account_status,
|
||||
last_login_at, created_at, updated_at
|
||||
FROM users WHERE id = $1
|
||||
`, userID).Scan(
|
||||
&user.ID, &user.Email, &user.Name, &user.Role, &user.EmailVerified, &user.EmailVerifiedAt,
|
||||
&user.AccountStatus, &user.LastLoginAt, &user.CreatedAt, &user.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// UpdateProfile updates a user's profile
|
||||
func (s *AuthService) UpdateProfile(ctx context.Context, userID uuid.UUID, req *models.UpdateProfileRequest) (*models.User, error) {
|
||||
_, err := s.db.Exec(ctx, `UPDATE users SET name = $1, updated_at = NOW() WHERE id = $2`, req.Name, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update profile: %w", err)
|
||||
}
|
||||
|
||||
return s.GetUserByID(ctx, userID)
|
||||
}
|
||||
|
||||
// GetActiveSessions retrieves all active sessions for a user
|
||||
func (s *AuthService) GetActiveSessions(ctx context.Context, userID uuid.UUID) ([]models.UserSession, error) {
|
||||
rows, err := s.db.Query(ctx, `
|
||||
SELECT id, user_id, device_info, ip_address, user_agent, expires_at, created_at, last_activity_at
|
||||
FROM user_sessions
|
||||
WHERE user_id = $1 AND revoked_at IS NULL AND expires_at > NOW()
|
||||
ORDER BY last_activity_at DESC
|
||||
`, userID)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get sessions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sessions []models.UserSession
|
||||
for rows.Next() {
|
||||
var session models.UserSession
|
||||
err := rows.Scan(
|
||||
&session.ID, &session.UserID, &session.DeviceInfo, &session.IPAddress,
|
||||
&session.UserAgent, &session.ExpiresAt, &session.CreatedAt, &session.LastActivityAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan session: %w", err)
|
||||
}
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
// RevokeSession revokes a specific session
|
||||
func (s *AuthService) RevokeSession(ctx context.Context, userID, sessionID uuid.UUID) error {
|
||||
result, err := s.db.Exec(ctx, `
|
||||
UPDATE user_sessions SET revoked_at = NOW() WHERE id = $1 AND user_id = $2 AND revoked_at IS NULL
|
||||
`, sessionID, userID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to revoke session: %w", err)
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
return errors.New("session not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Logout revokes a session by refresh token
|
||||
func (s *AuthService) Logout(ctx context.Context, refreshToken string) error {
|
||||
tokenHash := s.HashToken(refreshToken)
|
||||
_, err := s.db.Exec(ctx, `UPDATE user_sessions SET revoked_at = NOW() WHERE token_hash = $1`, tokenHash)
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,367 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestHashPassword tests password hashing
|
||||
func TestHashPassword(t *testing.T) {
|
||||
// Create service without DB for unit tests
|
||||
s := &AuthService{}
|
||||
|
||||
password := "testPassword123!"
|
||||
hash, err := s.HashPassword(password)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword failed: %v", err)
|
||||
}
|
||||
|
||||
if hash == "" {
|
||||
t.Error("Hash should not be empty")
|
||||
}
|
||||
|
||||
if hash == password {
|
||||
t.Error("Hash should not equal the original password")
|
||||
}
|
||||
|
||||
// Hash should be different each time (bcrypt uses random salt)
|
||||
hash2, _ := s.HashPassword(password)
|
||||
if hash == hash2 {
|
||||
t.Error("Same password should produce different hashes due to salt")
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyPassword tests password verification
|
||||
func TestVerifyPassword(t *testing.T) {
|
||||
s := &AuthService{}
|
||||
|
||||
password := "testPassword123!"
|
||||
hash, _ := s.HashPassword(password)
|
||||
|
||||
// Should verify correct password
|
||||
if !s.VerifyPassword(password, hash) {
|
||||
t.Error("VerifyPassword should return true for correct password")
|
||||
}
|
||||
|
||||
// Should reject incorrect password
|
||||
if s.VerifyPassword("wrongPassword", hash) {
|
||||
t.Error("VerifyPassword should return false for incorrect password")
|
||||
}
|
||||
|
||||
// Should reject empty password
|
||||
if s.VerifyPassword("", hash) {
|
||||
t.Error("VerifyPassword should return false for empty password")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateSecureToken tests token generation
|
||||
func TestGenerateSecureToken(t *testing.T) {
|
||||
s := &AuthService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
length int
|
||||
}{
|
||||
{"short token", 16},
|
||||
{"standard token", 32},
|
||||
{"long token", 64},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token, err := s.GenerateSecureToken(tt.length)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateSecureToken failed: %v", err)
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
t.Error("Token should not be empty")
|
||||
}
|
||||
|
||||
// Tokens should be unique
|
||||
token2, _ := s.GenerateSecureToken(tt.length)
|
||||
if token == token2 {
|
||||
t.Error("Generated tokens should be unique")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHashToken tests token hashing for storage
|
||||
func TestHashToken(t *testing.T) {
|
||||
s := &AuthService{}
|
||||
|
||||
token := "test-token-123"
|
||||
hash := s.HashToken(token)
|
||||
|
||||
if hash == "" {
|
||||
t.Error("Hash should not be empty")
|
||||
}
|
||||
|
||||
if hash == token {
|
||||
t.Error("Hash should not equal the original token")
|
||||
}
|
||||
|
||||
// Same token should produce same hash (deterministic)
|
||||
hash2 := s.HashToken(token)
|
||||
if hash != hash2 {
|
||||
t.Error("Same token should produce same hash")
|
||||
}
|
||||
|
||||
// Different tokens should produce different hashes
|
||||
differentHash := s.HashToken("different-token")
|
||||
if hash == differentHash {
|
||||
t.Error("Different tokens should produce different hashes")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateAccessToken tests JWT access token generation
|
||||
func TestGenerateAccessToken(t *testing.T) {
|
||||
s := &AuthService{
|
||||
jwtSecret: "test-secret-key-for-testing-purposes",
|
||||
accessTokenExp: time.Hour,
|
||||
}
|
||||
|
||||
user := &models.User{
|
||||
ID: uuid.New(),
|
||||
Email: "test@example.com",
|
||||
Role: "user",
|
||||
AccountStatus: "active",
|
||||
}
|
||||
|
||||
token, err := s.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateAccessToken failed: %v", err)
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
t.Error("Token should not be empty")
|
||||
}
|
||||
|
||||
// Token should have three parts (header.payload.signature)
|
||||
parts := 0
|
||||
for _, c := range token {
|
||||
if c == '.' {
|
||||
parts++
|
||||
}
|
||||
}
|
||||
if parts != 2 {
|
||||
t.Errorf("JWT token should have 3 parts, got %d dots", parts)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateAccessToken tests JWT token validation
|
||||
func TestValidateAccessToken(t *testing.T) {
|
||||
secret := "test-secret-key-for-testing-purposes"
|
||||
s := &AuthService{
|
||||
jwtSecret: secret,
|
||||
accessTokenExp: time.Hour,
|
||||
}
|
||||
|
||||
user := &models.User{
|
||||
ID: uuid.New(),
|
||||
Email: "test@example.com",
|
||||
Role: "admin",
|
||||
AccountStatus: "active",
|
||||
}
|
||||
|
||||
token, _ := s.GenerateAccessToken(user)
|
||||
|
||||
// Should validate valid token
|
||||
claims, err := s.ValidateAccessToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateAccessToken failed: %v", err)
|
||||
}
|
||||
|
||||
if claims.UserID != user.ID.String() {
|
||||
t.Errorf("Expected UserID %s, got %s", user.ID.String(), claims.UserID)
|
||||
}
|
||||
|
||||
if claims.Email != user.Email {
|
||||
t.Errorf("Expected Email %s, got %s", user.Email, claims.Email)
|
||||
}
|
||||
|
||||
if claims.Role != user.Role {
|
||||
t.Errorf("Expected Role %s, got %s", user.Role, claims.Role)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateAccessToken_Invalid tests invalid token scenarios
|
||||
func TestValidateAccessToken_Invalid(t *testing.T) {
|
||||
s := &AuthService{
|
||||
jwtSecret: "test-secret-key-for-testing-purposes",
|
||||
accessTokenExp: time.Hour,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{"empty token", ""},
|
||||
{"invalid format", "not-a-jwt-token"},
|
||||
{"invalid signature", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiMTIzIn0.invalidsignature"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := s.ValidateAccessToken(tt.token)
|
||||
if err == nil {
|
||||
t.Error("ValidateAccessToken should fail for invalid token")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateAccessToken_WrongSecret tests token with wrong secret
|
||||
func TestValidateAccessToken_WrongSecret(t *testing.T) {
|
||||
s1 := &AuthService{
|
||||
jwtSecret: "secret-one",
|
||||
accessTokenExp: time.Hour,
|
||||
}
|
||||
|
||||
s2 := &AuthService{
|
||||
jwtSecret: "secret-two",
|
||||
accessTokenExp: time.Hour,
|
||||
}
|
||||
|
||||
user := &models.User{
|
||||
ID: uuid.New(),
|
||||
Email: "test@example.com",
|
||||
Role: "user",
|
||||
AccountStatus: "active",
|
||||
}
|
||||
|
||||
// Generate token with first secret
|
||||
token, _ := s1.GenerateAccessToken(user)
|
||||
|
||||
// Try to validate with second secret (should fail)
|
||||
_, err := s2.ValidateAccessToken(token)
|
||||
if err == nil {
|
||||
t.Error("ValidateAccessToken should fail when using wrong secret")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateRefreshToken tests refresh token generation
|
||||
func TestGenerateRefreshToken(t *testing.T) {
|
||||
s := &AuthService{}
|
||||
|
||||
token, hash, err := s.GenerateRefreshToken()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateRefreshToken failed: %v", err)
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
t.Error("Token should not be empty")
|
||||
}
|
||||
|
||||
if hash == "" {
|
||||
t.Error("Hash should not be empty")
|
||||
}
|
||||
|
||||
// Verify hash matches token
|
||||
expectedHash := s.HashToken(token)
|
||||
if hash != expectedHash {
|
||||
t.Error("Returned hash should match hashed token")
|
||||
}
|
||||
|
||||
// Tokens should be unique
|
||||
token2, hash2, _ := s.GenerateRefreshToken()
|
||||
if token == token2 {
|
||||
t.Error("Generated tokens should be unique")
|
||||
}
|
||||
if hash == hash2 {
|
||||
t.Error("Generated hashes should be unique")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPasswordStrength tests various password scenarios
|
||||
func TestPasswordStrength(t *testing.T) {
|
||||
s := &AuthService{}
|
||||
|
||||
passwords := []struct {
|
||||
password string
|
||||
valid bool
|
||||
}{
|
||||
{"short", true}, // bcrypt accepts any length
|
||||
{"12345678", true}, // numbers only
|
||||
{"password", true}, // letters only
|
||||
{"Pass123!", true}, // mixed
|
||||
{"", true}, // empty (bcrypt allows)
|
||||
{string(make([]byte, 72)), true}, // max bcrypt length
|
||||
}
|
||||
|
||||
for _, p := range passwords {
|
||||
hash, err := s.HashPassword(p.password)
|
||||
if p.valid && err != nil {
|
||||
t.Errorf("HashPassword failed for valid password %q: %v", p.password, err)
|
||||
}
|
||||
if p.valid && !s.VerifyPassword(p.password, hash) {
|
||||
t.Errorf("VerifyPassword failed for password %q", p.password)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkHashPassword benchmarks password hashing
|
||||
func BenchmarkHashPassword(b *testing.B) {
|
||||
s := &AuthService{}
|
||||
password := "testPassword123!"
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.HashPassword(password)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkVerifyPassword benchmarks password verification
|
||||
func BenchmarkVerifyPassword(b *testing.B) {
|
||||
s := &AuthService{}
|
||||
password := "testPassword123!"
|
||||
hash, _ := s.HashPassword(password)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.VerifyPassword(password, hash)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkGenerateAccessToken benchmarks JWT token generation
|
||||
func BenchmarkGenerateAccessToken(b *testing.B) {
|
||||
s := &AuthService{
|
||||
jwtSecret: "test-secret-key-for-testing-purposes",
|
||||
accessTokenExp: time.Hour,
|
||||
}
|
||||
|
||||
user := &models.User{
|
||||
ID: uuid.New(),
|
||||
Email: "test@example.com",
|
||||
Role: "user",
|
||||
AccountStatus: "active",
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.GenerateAccessToken(user)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkValidateAccessToken benchmarks JWT token validation
|
||||
func BenchmarkValidateAccessToken(b *testing.B) {
|
||||
s := &AuthService{
|
||||
jwtSecret: "test-secret-key-for-testing-purposes",
|
||||
accessTokenExp: time.Hour,
|
||||
}
|
||||
|
||||
user := &models.User{
|
||||
ID: uuid.New(),
|
||||
Email: "test@example.com",
|
||||
Role: "user",
|
||||
AccountStatus: "active",
|
||||
}
|
||||
|
||||
token, _ := s.GenerateAccessToken(user)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.ValidateAccessToken(token)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,518 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestConsentService_CreateConsent tests creating a new consent
|
||||
func TestConsentService_CreateConsent(t *testing.T) {
|
||||
// This is a unit test with table-driven approach
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uuid.UUID
|
||||
versionID uuid.UUID
|
||||
consented bool
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "valid consent - accepted",
|
||||
userID: uuid.New(),
|
||||
versionID: uuid.New(),
|
||||
consented: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid consent - declined",
|
||||
userID: uuid.New(),
|
||||
versionID: uuid.New(),
|
||||
consented: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty user ID",
|
||||
userID: uuid.Nil,
|
||||
versionID: uuid.New(),
|
||||
consented: true,
|
||||
expectError: true,
|
||||
errorContains: "user ID",
|
||||
},
|
||||
{
|
||||
name: "empty version ID",
|
||||
userID: uuid.New(),
|
||||
versionID: uuid.Nil,
|
||||
consented: true,
|
||||
expectError: true,
|
||||
errorContains: "version ID",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Validate inputs (in real implementation this would be in the service)
|
||||
var hasError bool
|
||||
if tt.userID == uuid.Nil {
|
||||
hasError = true
|
||||
} else if tt.versionID == uuid.Nil {
|
||||
hasError = true
|
||||
}
|
||||
|
||||
// Assert
|
||||
if tt.expectError && !hasError {
|
||||
t.Errorf("Expected error containing '%s', got nil", tt.errorContains)
|
||||
}
|
||||
if !tt.expectError && hasError {
|
||||
t.Error("Expected no error, got error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConsentService_WithdrawConsent tests withdrawing consent
|
||||
func TestConsentService_WithdrawConsent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
consentID uuid.UUID
|
||||
userID uuid.UUID
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "valid withdrawal",
|
||||
consentID: uuid.New(),
|
||||
userID: uuid.New(),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty consent ID",
|
||||
consentID: uuid.Nil,
|
||||
userID: uuid.New(),
|
||||
expectError: true,
|
||||
errorContains: "consent ID",
|
||||
},
|
||||
{
|
||||
name: "empty user ID",
|
||||
consentID: uuid.New(),
|
||||
userID: uuid.Nil,
|
||||
expectError: true,
|
||||
errorContains: "user ID",
|
||||
},
|
||||
{
|
||||
name: "both empty",
|
||||
consentID: uuid.Nil,
|
||||
userID: uuid.Nil,
|
||||
expectError: true,
|
||||
errorContains: "ID",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Validate
|
||||
var hasError bool
|
||||
if tt.consentID == uuid.Nil || tt.userID == uuid.Nil {
|
||||
hasError = true
|
||||
}
|
||||
|
||||
// Assert
|
||||
if tt.expectError && !hasError {
|
||||
t.Errorf("Expected error containing '%s', got nil", tt.errorContains)
|
||||
}
|
||||
if !tt.expectError && hasError {
|
||||
t.Error("Expected no error, got error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConsentService_CheckConsent tests checking consent status
|
||||
func TestConsentService_CheckConsent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uuid.UUID
|
||||
documentType string
|
||||
language string
|
||||
hasConsent bool
|
||||
needsUpdate bool
|
||||
expectedConsent bool
|
||||
expectedNeedsUpd bool
|
||||
}{
|
||||
{
|
||||
name: "user has current consent",
|
||||
userID: uuid.New(),
|
||||
documentType: "terms",
|
||||
language: "de",
|
||||
hasConsent: true,
|
||||
needsUpdate: false,
|
||||
expectedConsent: true,
|
||||
expectedNeedsUpd: false,
|
||||
},
|
||||
{
|
||||
name: "user has outdated consent",
|
||||
userID: uuid.New(),
|
||||
documentType: "privacy",
|
||||
language: "de",
|
||||
hasConsent: true,
|
||||
needsUpdate: true,
|
||||
expectedConsent: true,
|
||||
expectedNeedsUpd: true,
|
||||
},
|
||||
{
|
||||
name: "user has no consent",
|
||||
userID: uuid.New(),
|
||||
documentType: "cookies",
|
||||
language: "de",
|
||||
hasConsent: false,
|
||||
needsUpdate: true,
|
||||
expectedConsent: false,
|
||||
expectedNeedsUpd: true,
|
||||
},
|
||||
{
|
||||
name: "english language",
|
||||
userID: uuid.New(),
|
||||
documentType: "terms",
|
||||
language: "en",
|
||||
hasConsent: true,
|
||||
needsUpdate: false,
|
||||
expectedConsent: true,
|
||||
expectedNeedsUpd: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Simulate consent check logic
|
||||
hasConsent := tt.hasConsent
|
||||
needsUpdate := tt.needsUpdate
|
||||
|
||||
// Assert
|
||||
if hasConsent != tt.expectedConsent {
|
||||
t.Errorf("Expected hasConsent=%v, got %v", tt.expectedConsent, hasConsent)
|
||||
}
|
||||
if needsUpdate != tt.expectedNeedsUpd {
|
||||
t.Errorf("Expected needsUpdate=%v, got %v", tt.expectedNeedsUpd, needsUpdate)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConsentService_GetConsentHistory tests retrieving consent history
|
||||
func TestConsentService_GetConsentHistory(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uuid.UUID
|
||||
expectError bool
|
||||
expectEmpty bool
|
||||
}{
|
||||
{
|
||||
name: "valid user with consents",
|
||||
userID: uuid.New(),
|
||||
expectError: false,
|
||||
expectEmpty: false,
|
||||
},
|
||||
{
|
||||
name: "valid user without consents",
|
||||
userID: uuid.New(),
|
||||
expectError: false,
|
||||
expectEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "invalid user ID",
|
||||
userID: uuid.Nil,
|
||||
expectError: true,
|
||||
expectEmpty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Validate
|
||||
var err error
|
||||
if tt.userID == uuid.Nil {
|
||||
err = &ValidationError{Field: "user ID", Message: "required"}
|
||||
}
|
||||
|
||||
// Assert error expectation
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConsentService_UpdateConsent tests updating existing consent
|
||||
func TestConsentService_UpdateConsent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
consentID uuid.UUID
|
||||
userID uuid.UUID
|
||||
newConsented bool
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "update to consented",
|
||||
consentID: uuid.New(),
|
||||
userID: uuid.New(),
|
||||
newConsented: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "update to not consented",
|
||||
consentID: uuid.New(),
|
||||
userID: uuid.New(),
|
||||
newConsented: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid consent ID",
|
||||
consentID: uuid.Nil,
|
||||
userID: uuid.New(),
|
||||
newConsented: true,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.consentID == uuid.Nil {
|
||||
err = &ValidationError{Field: "consent ID", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConsentService_GetConsentStats tests getting consent statistics
|
||||
func TestConsentService_GetConsentStats(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
documentType string
|
||||
totalUsers int
|
||||
consentedUsers int
|
||||
expectedRate float64
|
||||
}{
|
||||
{
|
||||
name: "100% consent rate",
|
||||
documentType: "terms",
|
||||
totalUsers: 100,
|
||||
consentedUsers: 100,
|
||||
expectedRate: 100.0,
|
||||
},
|
||||
{
|
||||
name: "50% consent rate",
|
||||
documentType: "privacy",
|
||||
totalUsers: 100,
|
||||
consentedUsers: 50,
|
||||
expectedRate: 50.0,
|
||||
},
|
||||
{
|
||||
name: "0% consent rate",
|
||||
documentType: "cookies",
|
||||
totalUsers: 100,
|
||||
consentedUsers: 0,
|
||||
expectedRate: 0.0,
|
||||
},
|
||||
{
|
||||
name: "no users",
|
||||
documentType: "terms",
|
||||
totalUsers: 0,
|
||||
consentedUsers: 0,
|
||||
expectedRate: 0.0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Calculate consent rate
|
||||
var consentRate float64
|
||||
if tt.totalUsers > 0 {
|
||||
consentRate = float64(tt.consentedUsers) / float64(tt.totalUsers) * 100
|
||||
}
|
||||
|
||||
// Assert
|
||||
if consentRate != tt.expectedRate {
|
||||
t.Errorf("Expected consent rate %.2f%%, got %.2f%%", tt.expectedRate, consentRate)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConsentService_BulkConsentCheck tests checking multiple consents at once
|
||||
func TestConsentService_BulkConsentCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uuid.UUID
|
||||
documentTypes []string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "check multiple documents",
|
||||
userID: uuid.New(),
|
||||
documentTypes: []string{"terms", "privacy", "cookies"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "check single document",
|
||||
userID: uuid.New(),
|
||||
documentTypes: []string{"terms"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty document list",
|
||||
userID: uuid.New(),
|
||||
documentTypes: []string{},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid user ID",
|
||||
userID: uuid.Nil,
|
||||
documentTypes: []string{"terms"},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.userID == uuid.Nil {
|
||||
err = &ValidationError{Field: "user ID", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConsentService_ConsentVersionComparison tests version comparison logic
|
||||
func TestConsentService_ConsentVersionComparison(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
currentVersion string
|
||||
consentedVersion string
|
||||
needsUpdate bool
|
||||
}{
|
||||
{
|
||||
name: "same version",
|
||||
currentVersion: "1.0.0",
|
||||
consentedVersion: "1.0.0",
|
||||
needsUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "minor version update",
|
||||
currentVersion: "1.1.0",
|
||||
consentedVersion: "1.0.0",
|
||||
needsUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "major version update",
|
||||
currentVersion: "2.0.0",
|
||||
consentedVersion: "1.0.0",
|
||||
needsUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "patch version update",
|
||||
currentVersion: "1.0.1",
|
||||
consentedVersion: "1.0.0",
|
||||
needsUpdate: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Simple version comparison (in real implementation use proper semver)
|
||||
needsUpdate := tt.currentVersion != tt.consentedVersion
|
||||
|
||||
if needsUpdate != tt.needsUpdate {
|
||||
t.Errorf("Expected needsUpdate=%v, got %v", tt.needsUpdate, needsUpdate)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConsentService_ConsentDeadlineCheck tests deadline validation
|
||||
func TestConsentService_ConsentDeadlineCheck(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
deadline time.Time
|
||||
isOverdue bool
|
||||
daysLeft int
|
||||
}{
|
||||
{
|
||||
name: "deadline in 30 days",
|
||||
deadline: now.AddDate(0, 0, 30),
|
||||
isOverdue: false,
|
||||
daysLeft: 30,
|
||||
},
|
||||
{
|
||||
name: "deadline in 7 days",
|
||||
deadline: now.AddDate(0, 0, 7),
|
||||
isOverdue: false,
|
||||
daysLeft: 7,
|
||||
},
|
||||
{
|
||||
name: "deadline today",
|
||||
deadline: now,
|
||||
isOverdue: false,
|
||||
daysLeft: 0,
|
||||
},
|
||||
{
|
||||
name: "deadline 1 day overdue",
|
||||
deadline: now.AddDate(0, 0, -1),
|
||||
isOverdue: true,
|
||||
daysLeft: -1,
|
||||
},
|
||||
{
|
||||
name: "deadline 30 days overdue",
|
||||
deadline: now.AddDate(0, 0, -30),
|
||||
isOverdue: true,
|
||||
daysLeft: -30,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Calculate if overdue
|
||||
isOverdue := tt.deadline.Before(now)
|
||||
daysLeft := int(tt.deadline.Sub(now).Hours() / 24)
|
||||
|
||||
if isOverdue != tt.isOverdue {
|
||||
t.Errorf("Expected isOverdue=%v, got %v", tt.isOverdue, isOverdue)
|
||||
}
|
||||
|
||||
// Allow 1 day difference due to time precision
|
||||
if abs(daysLeft-tt.daysLeft) > 1 {
|
||||
t.Errorf("Expected daysLeft=%d, got %d", tt.daysLeft, daysLeft)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// abs returns the absolute value of an integer
|
||||
func abs(n int) int {
|
||||
if n < 0 {
|
||||
return -n
|
||||
}
|
||||
return n
|
||||
}
|
||||
@@ -0,0 +1,434 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// DeadlineService handles consent deadlines and account suspensions
|
||||
type DeadlineService struct {
|
||||
pool *pgxpool.Pool
|
||||
notificationService *NotificationService
|
||||
}
|
||||
|
||||
// ConsentDeadline represents a consent deadline for a user
|
||||
type ConsentDeadline struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
DocumentVersionID uuid.UUID `json:"document_version_id"`
|
||||
DeadlineAt time.Time `json:"deadline_at"`
|
||||
ReminderCount int `json:"reminder_count"`
|
||||
LastReminderAt *time.Time `json:"last_reminder_at"`
|
||||
ConsentGivenAt *time.Time `json:"consent_given_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
// Joined fields
|
||||
DocumentName string `json:"document_name"`
|
||||
VersionNumber string `json:"version_number"`
|
||||
}
|
||||
|
||||
// AccountSuspension represents an account suspension
|
||||
type AccountSuspension struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
Reason string `json:"reason"`
|
||||
Details map[string]interface{} `json:"details"`
|
||||
SuspendedAt time.Time `json:"suspended_at"`
|
||||
LiftedAt *time.Time `json:"lifted_at"`
|
||||
LiftedBy *uuid.UUID `json:"lifted_by"`
|
||||
}
|
||||
|
||||
// NewDeadlineService creates a new deadline service
|
||||
func NewDeadlineService(pool *pgxpool.Pool, notificationService *NotificationService) *DeadlineService {
|
||||
return &DeadlineService{
|
||||
pool: pool,
|
||||
notificationService: notificationService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateDeadlinesForPublishedVersion creates consent deadlines for all active users
|
||||
// when a new mandatory document version is published
|
||||
func (s *DeadlineService) CreateDeadlinesForPublishedVersion(ctx context.Context, versionID uuid.UUID) error {
|
||||
// Get version info
|
||||
var documentName, versionNumber string
|
||||
var isMandatory bool
|
||||
err := s.pool.QueryRow(ctx, `
|
||||
SELECT ld.name, dv.version, ld.is_mandatory
|
||||
FROM document_versions dv
|
||||
JOIN legal_documents ld ON dv.document_id = ld.id
|
||||
WHERE dv.id = $1
|
||||
`, versionID).Scan(&documentName, &versionNumber, &isMandatory)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get version info: %w", err)
|
||||
}
|
||||
|
||||
// Only create deadlines for mandatory documents
|
||||
if !isMandatory {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deadline is 30 days from now
|
||||
deadlineAt := time.Now().AddDate(0, 0, 30)
|
||||
|
||||
// Get all active users who haven't given consent to this version
|
||||
_, err = s.pool.Exec(ctx, `
|
||||
INSERT INTO consent_deadlines (user_id, document_version_id, deadline_at)
|
||||
SELECT u.id, $1, $2
|
||||
FROM users u
|
||||
WHERE u.account_status = 'active'
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM user_consents uc
|
||||
WHERE uc.user_id = u.id AND uc.document_version_id = $1 AND uc.consented = TRUE
|
||||
)
|
||||
ON CONFLICT (user_id, document_version_id) DO NOTHING
|
||||
`, versionID, deadlineAt)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create deadlines: %w", err)
|
||||
}
|
||||
|
||||
// Notify users via notification service
|
||||
if s.notificationService != nil {
|
||||
go s.notificationService.NotifyConsentRequired(ctx, documentName, versionID.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkConsentGiven marks a deadline as fulfilled when user gives consent
|
||||
func (s *DeadlineService) MarkConsentGiven(ctx context.Context, userID, versionID uuid.UUID) error {
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE consent_deadlines
|
||||
SET consent_given_at = NOW()
|
||||
WHERE user_id = $1 AND document_version_id = $2 AND consent_given_at IS NULL
|
||||
`, userID, versionID)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if user should be unsuspended
|
||||
return s.checkAndLiftSuspension(ctx, userID)
|
||||
}
|
||||
|
||||
// GetPendingDeadlines returns all pending deadlines for a user
|
||||
func (s *DeadlineService) GetPendingDeadlines(ctx context.Context, userID uuid.UUID) ([]ConsentDeadline, error) {
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT cd.id, cd.user_id, cd.document_version_id, cd.deadline_at,
|
||||
cd.reminder_count, cd.last_reminder_at, cd.consent_given_at, cd.created_at,
|
||||
ld.name as document_name, dv.version as version_number
|
||||
FROM consent_deadlines cd
|
||||
JOIN document_versions dv ON cd.document_version_id = dv.id
|
||||
JOIN legal_documents ld ON dv.document_id = ld.id
|
||||
WHERE cd.user_id = $1 AND cd.consent_given_at IS NULL
|
||||
ORDER BY cd.deadline_at ASC
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var deadlines []ConsentDeadline
|
||||
for rows.Next() {
|
||||
var d ConsentDeadline
|
||||
if err := rows.Scan(&d.ID, &d.UserID, &d.DocumentVersionID, &d.DeadlineAt,
|
||||
&d.ReminderCount, &d.LastReminderAt, &d.ConsentGivenAt, &d.CreatedAt,
|
||||
&d.DocumentName, &d.VersionNumber); err != nil {
|
||||
continue
|
||||
}
|
||||
deadlines = append(deadlines, d)
|
||||
}
|
||||
|
||||
return deadlines, nil
|
||||
}
|
||||
|
||||
// ProcessDailyDeadlines is meant to be called by a cron job daily
|
||||
// It sends reminders and suspends accounts that have missed deadlines
|
||||
func (s *DeadlineService) ProcessDailyDeadlines(ctx context.Context) error {
|
||||
now := time.Now()
|
||||
|
||||
// 1. Send reminders for upcoming deadlines
|
||||
if err := s.sendReminders(ctx, now); err != nil {
|
||||
fmt.Printf("Error sending reminders: %v\n", err)
|
||||
}
|
||||
|
||||
// 2. Suspend accounts with expired deadlines
|
||||
if err := s.suspendExpiredAccounts(ctx, now); err != nil {
|
||||
fmt.Printf("Error suspending accounts: %v\n", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendReminders sends reminder notifications based on days remaining
|
||||
func (s *DeadlineService) sendReminders(ctx context.Context, now time.Time) error {
|
||||
// Reminder schedule: Day 7, 14, 21, 28
|
||||
reminderDays := []int{7, 14, 21, 28}
|
||||
|
||||
for _, days := range reminderDays {
|
||||
targetDate := now.AddDate(0, 0, days)
|
||||
dayStart := time.Date(targetDate.Year(), targetDate.Month(), targetDate.Day(), 0, 0, 0, 0, targetDate.Location())
|
||||
dayEnd := dayStart.AddDate(0, 0, 1)
|
||||
|
||||
// Find deadlines that fall on this reminder day
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT cd.id, cd.user_id, cd.document_version_id, cd.deadline_at, cd.reminder_count,
|
||||
ld.name as document_name
|
||||
FROM consent_deadlines cd
|
||||
JOIN document_versions dv ON cd.document_version_id = dv.id
|
||||
JOIN legal_documents ld ON dv.document_id = ld.id
|
||||
WHERE cd.consent_given_at IS NULL
|
||||
AND cd.deadline_at >= $1 AND cd.deadline_at < $2
|
||||
AND (cd.last_reminder_at IS NULL OR cd.last_reminder_at < $3)
|
||||
`, dayStart, dayEnd, dayStart)
|
||||
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var id, userID, versionID uuid.UUID
|
||||
var deadlineAt time.Time
|
||||
var reminderCount int
|
||||
var documentName string
|
||||
|
||||
if err := rows.Scan(&id, &userID, &versionID, &deadlineAt, &reminderCount, &documentName); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Send reminder notification
|
||||
daysLeft := 30 - (30 - days)
|
||||
urgency := "freundlich"
|
||||
if days <= 7 {
|
||||
urgency = "dringend"
|
||||
} else if days <= 14 {
|
||||
urgency = "wichtig"
|
||||
}
|
||||
|
||||
title := fmt.Sprintf("Erinnerung: Zustimmung erforderlich (%s)", urgency)
|
||||
body := fmt.Sprintf("Bitte bestätigen Sie '%s' innerhalb von %d Tagen.", documentName, daysLeft)
|
||||
|
||||
if s.notificationService != nil {
|
||||
s.notificationService.CreateNotification(ctx, userID, NotificationTypeConsentReminder, title, body, map[string]interface{}{
|
||||
"document_name": documentName,
|
||||
"days_left": daysLeft,
|
||||
"version_id": versionID.String(),
|
||||
})
|
||||
}
|
||||
|
||||
// Update reminder count and timestamp
|
||||
s.pool.Exec(ctx, `
|
||||
UPDATE consent_deadlines
|
||||
SET reminder_count = reminder_count + 1, last_reminder_at = NOW()
|
||||
WHERE id = $1
|
||||
`, id)
|
||||
}
|
||||
rows.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// suspendExpiredAccounts suspends accounts that have missed their deadline
|
||||
func (s *DeadlineService) suspendExpiredAccounts(ctx context.Context, now time.Time) error {
|
||||
// Find users with expired deadlines
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT DISTINCT cd.user_id, array_agg(ld.name) as documents
|
||||
FROM consent_deadlines cd
|
||||
JOIN document_versions dv ON cd.document_version_id = dv.id
|
||||
JOIN legal_documents ld ON dv.document_id = ld.id
|
||||
JOIN users u ON cd.user_id = u.id
|
||||
WHERE cd.consent_given_at IS NULL
|
||||
AND cd.deadline_at < $1
|
||||
AND u.account_status = 'active'
|
||||
AND ld.is_mandatory = TRUE
|
||||
GROUP BY cd.user_id
|
||||
`, now)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var userID uuid.UUID
|
||||
var documents []string
|
||||
|
||||
if err := rows.Scan(&userID, &documents); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Suspend the account
|
||||
if err := s.suspendAccount(ctx, userID, "consent_deadline_missed", documents); err != nil {
|
||||
fmt.Printf("Failed to suspend user %s: %v\n", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// suspendAccount suspends a user account
|
||||
func (s *DeadlineService) suspendAccount(ctx context.Context, userID uuid.UUID, reason string, documents []string) error {
|
||||
tx, err := s.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
// Update user status
|
||||
_, err = tx.Exec(ctx, `
|
||||
UPDATE users SET account_status = 'suspended', updated_at = NOW()
|
||||
WHERE id = $1 AND account_status = 'active'
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create suspension record
|
||||
_, err = tx.Exec(ctx, `
|
||||
INSERT INTO account_suspensions (user_id, reason, details)
|
||||
VALUES ($1, $2, $3)
|
||||
`, userID, reason, map[string]interface{}{"documents": documents})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Log to audit
|
||||
_, err = tx.Exec(ctx, `
|
||||
INSERT INTO consent_audit_log (user_id, action, entity_type, entity_id, details)
|
||||
VALUES ($1, 'account_suspended', 'user', $1, $2)
|
||||
`, userID, map[string]interface{}{"reason": reason, "documents": documents})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send suspension notification
|
||||
if s.notificationService != nil {
|
||||
title := "Account vorübergehend gesperrt"
|
||||
body := "Ihr Account wurde gesperrt, da ausstehende Zustimmungen nicht innerhalb der Frist erteilt wurden. Bitte bestätigen Sie die ausstehenden Dokumente."
|
||||
s.notificationService.CreateNotification(ctx, userID, NotificationTypeAccountSuspended, title, body, map[string]interface{}{
|
||||
"documents": documents,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkAndLiftSuspension checks if user has completed all required consents and lifts suspension
|
||||
func (s *DeadlineService) checkAndLiftSuspension(ctx context.Context, userID uuid.UUID) error {
|
||||
// Check if user is currently suspended
|
||||
var accountStatus string
|
||||
err := s.pool.QueryRow(ctx, `SELECT account_status FROM users WHERE id = $1`, userID).Scan(&accountStatus)
|
||||
if err != nil || accountStatus != "suspended" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if there are any pending mandatory consents
|
||||
var pendingCount int
|
||||
err = s.pool.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM consent_deadlines cd
|
||||
JOIN document_versions dv ON cd.document_version_id = dv.id
|
||||
JOIN legal_documents ld ON dv.document_id = ld.id
|
||||
WHERE cd.user_id = $1
|
||||
AND cd.consent_given_at IS NULL
|
||||
AND ld.is_mandatory = TRUE
|
||||
`, userID).Scan(&pendingCount)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If no pending consents, lift the suspension
|
||||
if pendingCount == 0 {
|
||||
return s.liftSuspension(ctx, userID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// liftSuspension lifts a user's suspension
|
||||
func (s *DeadlineService) liftSuspension(ctx context.Context, userID uuid.UUID) error {
|
||||
tx, err := s.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
// Update user status
|
||||
_, err = tx.Exec(ctx, `
|
||||
UPDATE users SET account_status = 'active', updated_at = NOW()
|
||||
WHERE id = $1 AND account_status = 'suspended'
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update suspension record
|
||||
_, err = tx.Exec(ctx, `
|
||||
UPDATE account_suspensions
|
||||
SET lifted_at = NOW()
|
||||
WHERE user_id = $1 AND lifted_at IS NULL
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Log to audit
|
||||
_, err = tx.Exec(ctx, `
|
||||
INSERT INTO consent_audit_log (user_id, action, entity_type, entity_id)
|
||||
VALUES ($1, 'account_restored', 'user', $1)
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send restoration notification
|
||||
if s.notificationService != nil {
|
||||
title := "Account wiederhergestellt"
|
||||
body := "Vielen Dank! Ihr Account wurde wiederhergestellt. Sie können die Anwendung wieder vollständig nutzen."
|
||||
s.notificationService.CreateNotification(ctx, userID, NotificationTypeAccountRestored, title, body, nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountSuspension returns the current suspension for a user
|
||||
func (s *DeadlineService) GetAccountSuspension(ctx context.Context, userID uuid.UUID) (*AccountSuspension, error) {
|
||||
var suspension AccountSuspension
|
||||
err := s.pool.QueryRow(ctx, `
|
||||
SELECT id, user_id, reason, details, suspended_at, lifted_at, lifted_by
|
||||
FROM account_suspensions
|
||||
WHERE user_id = $1 AND lifted_at IS NULL
|
||||
ORDER BY suspended_at DESC
|
||||
LIMIT 1
|
||||
`, userID).Scan(&suspension.ID, &suspension.UserID, &suspension.Reason, &suspension.Details,
|
||||
&suspension.SuspendedAt, &suspension.LiftedAt, &suspension.LiftedBy)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &suspension, nil
|
||||
}
|
||||
|
||||
// IsUserSuspended checks if a user is currently suspended
|
||||
func (s *DeadlineService) IsUserSuspended(ctx context.Context, userID uuid.UUID) (bool, error) {
|
||||
var status string
|
||||
err := s.pool.QueryRow(ctx, `SELECT account_status FROM users WHERE id = $1`, userID).Scan(&status)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return status == "suspended", nil
|
||||
}
|
||||
@@ -0,0 +1,439 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestDeadlineService_CreateDeadline tests creating consent deadlines
|
||||
func TestDeadlineService_CreateDeadline(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uuid.UUID
|
||||
versionID uuid.UUID
|
||||
deadlineAt time.Time
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid deadline - 30 days",
|
||||
userID: uuid.New(),
|
||||
versionID: uuid.New(),
|
||||
deadlineAt: time.Now().AddDate(0, 0, 30),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid deadline - 14 days",
|
||||
userID: uuid.New(),
|
||||
versionID: uuid.New(),
|
||||
deadlineAt: time.Now().AddDate(0, 0, 14),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid user ID",
|
||||
userID: uuid.Nil,
|
||||
versionID: uuid.New(),
|
||||
deadlineAt: time.Now().AddDate(0, 0, 30),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid version ID",
|
||||
userID: uuid.New(),
|
||||
versionID: uuid.Nil,
|
||||
deadlineAt: time.Now().AddDate(0, 0, 30),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "deadline in past",
|
||||
userID: uuid.New(),
|
||||
versionID: uuid.New(),
|
||||
deadlineAt: time.Now().AddDate(0, 0, -1),
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.userID == uuid.Nil {
|
||||
err = &ValidationError{Field: "user ID", Message: "required"}
|
||||
} else if tt.versionID == uuid.Nil {
|
||||
err = &ValidationError{Field: "version ID", Message: "required"}
|
||||
} else if tt.deadlineAt.Before(time.Now()) {
|
||||
err = &ValidationError{Field: "deadline", Message: "must be in the future"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeadlineService_CheckDeadlineStatus tests deadline status checking
|
||||
func TestDeadlineService_CheckDeadlineStatus(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
deadlineAt time.Time
|
||||
isOverdue bool
|
||||
daysLeft int
|
||||
urgency string
|
||||
}{
|
||||
{
|
||||
name: "30 days left",
|
||||
deadlineAt: now.AddDate(0, 0, 30),
|
||||
isOverdue: false,
|
||||
daysLeft: 30,
|
||||
urgency: "normal",
|
||||
},
|
||||
{
|
||||
name: "7 days left - warning",
|
||||
deadlineAt: now.AddDate(0, 0, 7),
|
||||
isOverdue: false,
|
||||
daysLeft: 7,
|
||||
urgency: "warning",
|
||||
},
|
||||
{
|
||||
name: "3 days left - urgent",
|
||||
deadlineAt: now.AddDate(0, 0, 3),
|
||||
isOverdue: false,
|
||||
daysLeft: 3,
|
||||
urgency: "urgent",
|
||||
},
|
||||
{
|
||||
name: "1 day left - critical",
|
||||
deadlineAt: now.AddDate(0, 0, 1),
|
||||
isOverdue: false,
|
||||
daysLeft: 1,
|
||||
urgency: "critical",
|
||||
},
|
||||
{
|
||||
name: "overdue by 1 day",
|
||||
deadlineAt: now.AddDate(0, 0, -1),
|
||||
isOverdue: true,
|
||||
daysLeft: -1,
|
||||
urgency: "overdue",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isOverdue := tt.deadlineAt.Before(now)
|
||||
daysLeft := int(tt.deadlineAt.Sub(now).Hours() / 24)
|
||||
|
||||
var urgency string
|
||||
if isOverdue {
|
||||
urgency = "overdue"
|
||||
} else if daysLeft <= 1 {
|
||||
urgency = "critical"
|
||||
} else if daysLeft <= 3 {
|
||||
urgency = "urgent"
|
||||
} else if daysLeft <= 7 {
|
||||
urgency = "warning"
|
||||
} else {
|
||||
urgency = "normal"
|
||||
}
|
||||
|
||||
if isOverdue != tt.isOverdue {
|
||||
t.Errorf("Expected isOverdue=%v, got %v", tt.isOverdue, isOverdue)
|
||||
}
|
||||
|
||||
if abs(daysLeft-tt.daysLeft) > 1 { // Allow 1 day difference
|
||||
t.Errorf("Expected daysLeft=%d, got %d", tt.daysLeft, daysLeft)
|
||||
}
|
||||
|
||||
if urgency != tt.urgency {
|
||||
t.Errorf("Expected urgency=%s, got %s", tt.urgency, urgency)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeadlineService_SendReminders tests reminder scheduling
|
||||
func TestDeadlineService_SendReminders(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
deadlineAt time.Time
|
||||
lastReminderAt *time.Time
|
||||
reminderCount int
|
||||
shouldSend bool
|
||||
nextReminder int // days before deadline
|
||||
}{
|
||||
{
|
||||
name: "first reminder - 14 days before",
|
||||
deadlineAt: now.AddDate(0, 0, 14),
|
||||
lastReminderAt: nil,
|
||||
reminderCount: 0,
|
||||
shouldSend: true,
|
||||
nextReminder: 14,
|
||||
},
|
||||
{
|
||||
name: "second reminder - 7 days before",
|
||||
deadlineAt: now.AddDate(0, 0, 7),
|
||||
lastReminderAt: ptrTime(now.AddDate(0, 0, -7)),
|
||||
reminderCount: 1,
|
||||
shouldSend: true,
|
||||
nextReminder: 7,
|
||||
},
|
||||
{
|
||||
name: "third reminder - 3 days before",
|
||||
deadlineAt: now.AddDate(0, 0, 3),
|
||||
lastReminderAt: ptrTime(now.AddDate(0, 0, -4)),
|
||||
reminderCount: 2,
|
||||
shouldSend: true,
|
||||
nextReminder: 3,
|
||||
},
|
||||
{
|
||||
name: "final reminder - 1 day before",
|
||||
deadlineAt: now.AddDate(0, 0, 1),
|
||||
lastReminderAt: ptrTime(now.AddDate(0, 0, -2)),
|
||||
reminderCount: 3,
|
||||
shouldSend: true,
|
||||
nextReminder: 1,
|
||||
},
|
||||
{
|
||||
name: "too soon for next reminder",
|
||||
deadlineAt: now.AddDate(0, 0, 10),
|
||||
lastReminderAt: ptrTime(now.AddDate(0, 0, -1)),
|
||||
reminderCount: 1,
|
||||
shouldSend: false,
|
||||
nextReminder: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
daysUntilDeadline := int(tt.deadlineAt.Sub(now).Hours() / 24)
|
||||
|
||||
// Reminder schedule: 14, 7, 3, 1 days before deadline
|
||||
reminderDays := []int{14, 7, 3, 1}
|
||||
shouldSend := false
|
||||
|
||||
for _, day := range reminderDays {
|
||||
if daysUntilDeadline == day {
|
||||
// Check if enough time passed since last reminder
|
||||
if tt.lastReminderAt == nil || now.Sub(*tt.lastReminderAt) > 12*time.Hour {
|
||||
shouldSend = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if shouldSend != tt.shouldSend {
|
||||
t.Errorf("Expected shouldSend=%v, got %v", tt.shouldSend, shouldSend)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeadlineService_SuspendAccount tests account suspension logic
|
||||
func TestDeadlineService_SuspendAccount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uuid.UUID
|
||||
reason string
|
||||
shouldSuspend bool
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "suspend for missed deadline",
|
||||
userID: uuid.New(),
|
||||
reason: "consent_deadline_exceeded",
|
||||
shouldSuspend: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid user ID",
|
||||
userID: uuid.Nil,
|
||||
reason: "consent_deadline_exceeded",
|
||||
shouldSuspend: false,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid reason",
|
||||
userID: uuid.New(),
|
||||
reason: "",
|
||||
shouldSuspend: false,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
validReasons := map[string]bool{
|
||||
"consent_deadline_exceeded": true,
|
||||
"mandatory_consent_missing": true,
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.userID == uuid.Nil {
|
||||
err = &ValidationError{Field: "user ID", Message: "required"}
|
||||
} else if !validReasons[tt.reason] && tt.reason != "" {
|
||||
err = &ValidationError{Field: "reason", Message: "invalid suspension reason"}
|
||||
} else if tt.reason == "" {
|
||||
err = &ValidationError{Field: "reason", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeadlineService_LiftSuspension tests lifting account suspension
|
||||
func TestDeadlineService_LiftSuspension(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uuid.UUID
|
||||
adminID uuid.UUID
|
||||
reason string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "lift valid suspension",
|
||||
userID: uuid.New(),
|
||||
adminID: uuid.New(),
|
||||
reason: "consent provided",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid user ID",
|
||||
userID: uuid.Nil,
|
||||
adminID: uuid.New(),
|
||||
reason: "consent provided",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid admin ID",
|
||||
userID: uuid.New(),
|
||||
adminID: uuid.Nil,
|
||||
reason: "consent provided",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.userID == uuid.Nil {
|
||||
err = &ValidationError{Field: "user ID", Message: "required"}
|
||||
} else if tt.adminID == uuid.Nil {
|
||||
err = &ValidationError{Field: "admin ID", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeadlineService_GetOverdueDeadlines tests finding overdue deadlines
|
||||
func TestDeadlineService_GetOverdueDeadlines(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
deadlines []time.Time
|
||||
expected int // number of overdue
|
||||
}{
|
||||
{
|
||||
name: "no overdue deadlines",
|
||||
deadlines: []time.Time{
|
||||
now.AddDate(0, 0, 1),
|
||||
now.AddDate(0, 0, 7),
|
||||
now.AddDate(0, 0, 30),
|
||||
},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "some overdue",
|
||||
deadlines: []time.Time{
|
||||
now.AddDate(0, 0, -1),
|
||||
now.AddDate(0, 0, -5),
|
||||
now.AddDate(0, 0, 7),
|
||||
},
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
name: "all overdue",
|
||||
deadlines: []time.Time{
|
||||
now.AddDate(0, 0, -1),
|
||||
now.AddDate(0, 0, -7),
|
||||
now.AddDate(0, 0, -30),
|
||||
},
|
||||
expected: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
overdueCount := 0
|
||||
for _, deadline := range tt.deadlines {
|
||||
if deadline.Before(now) {
|
||||
overdueCount++
|
||||
}
|
||||
}
|
||||
|
||||
if overdueCount != tt.expected {
|
||||
t.Errorf("Expected %d overdue, got %d", tt.expected, overdueCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeadlineService_ProcessScheduledTasks tests scheduled task processing
|
||||
func TestDeadlineService_ProcessScheduledTasks(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
task string
|
||||
scheduledAt time.Time
|
||||
shouldProcess bool
|
||||
}{
|
||||
{
|
||||
name: "process due task",
|
||||
task: "send_reminder",
|
||||
scheduledAt: now.Add(-1 * time.Hour),
|
||||
shouldProcess: true,
|
||||
},
|
||||
{
|
||||
name: "skip future task",
|
||||
task: "send_reminder",
|
||||
scheduledAt: now.Add(1 * time.Hour),
|
||||
shouldProcess: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
shouldProcess := tt.scheduledAt.Before(now) || tt.scheduledAt.Equal(now)
|
||||
|
||||
if shouldProcess != tt.shouldProcess {
|
||||
t.Errorf("Expected shouldProcess=%v, got %v", tt.shouldProcess, shouldProcess)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func ptrTime(t time.Time) *time.Time {
|
||||
return &t
|
||||
}
|
||||
@@ -0,0 +1,728 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestDocumentService_CreateDocument tests creating a new legal document
|
||||
func TestDocumentService_CreateDocument(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
docType string
|
||||
docName string
|
||||
description string
|
||||
isMandatory bool
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "valid mandatory document",
|
||||
docType: "terms",
|
||||
docName: "Terms of Service",
|
||||
description: "Our terms and conditions",
|
||||
isMandatory: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid optional document",
|
||||
docType: "cookies",
|
||||
docName: "Cookie Policy",
|
||||
description: "How we use cookies",
|
||||
isMandatory: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty document type",
|
||||
docType: "",
|
||||
docName: "Test Document",
|
||||
description: "Test",
|
||||
isMandatory: true,
|
||||
expectError: true,
|
||||
errorContains: "type",
|
||||
},
|
||||
{
|
||||
name: "empty document name",
|
||||
docType: "privacy",
|
||||
docName: "",
|
||||
description: "Test",
|
||||
isMandatory: true,
|
||||
expectError: true,
|
||||
errorContains: "name",
|
||||
},
|
||||
{
|
||||
name: "invalid document type",
|
||||
docType: "invalid_type",
|
||||
docName: "Test",
|
||||
description: "Test",
|
||||
isMandatory: false,
|
||||
expectError: true,
|
||||
errorContains: "type",
|
||||
},
|
||||
}
|
||||
|
||||
validTypes := map[string]bool{
|
||||
"terms": true,
|
||||
"privacy": true,
|
||||
"cookies": true,
|
||||
"community_guidelines": true,
|
||||
"imprint": true,
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Validate inputs
|
||||
var err error
|
||||
if tt.docType == "" {
|
||||
err = &ValidationError{Field: "type", Message: "required"}
|
||||
} else if !validTypes[tt.docType] {
|
||||
err = &ValidationError{Field: "type", Message: "invalid document type"}
|
||||
} else if tt.docName == "" {
|
||||
err = &ValidationError{Field: "name", Message: "required"}
|
||||
}
|
||||
|
||||
// Assert
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error containing '%s', got nil", tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentService_UpdateDocument tests updating a document
|
||||
func TestDocumentService_UpdateDocument(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
documentID uuid.UUID
|
||||
newName string
|
||||
newActive bool
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid update",
|
||||
documentID: uuid.New(),
|
||||
newName: "Updated Name",
|
||||
newActive: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "deactivate document",
|
||||
documentID: uuid.New(),
|
||||
newName: "Test",
|
||||
newActive: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid document ID",
|
||||
documentID: uuid.Nil,
|
||||
newName: "Test",
|
||||
newActive: true,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.documentID == uuid.Nil {
|
||||
err = &ValidationError{Field: "document ID", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentService_CreateVersion tests creating a document version
|
||||
func TestDocumentService_CreateVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
documentID uuid.UUID
|
||||
version string
|
||||
language string
|
||||
title string
|
||||
content string
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "valid version - German",
|
||||
documentID: uuid.New(),
|
||||
version: "1.0.0",
|
||||
language: "de",
|
||||
title: "Nutzungsbedingungen",
|
||||
content: "<h1>Terms</h1><p>Content...</p>",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid version - English",
|
||||
documentID: uuid.New(),
|
||||
version: "1.0.0",
|
||||
language: "en",
|
||||
title: "Terms of Service",
|
||||
content: "<h1>Terms</h1><p>Content...</p>",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid version format",
|
||||
documentID: uuid.New(),
|
||||
version: "1.0",
|
||||
language: "de",
|
||||
title: "Test",
|
||||
content: "Content",
|
||||
expectError: true,
|
||||
errorContains: "version",
|
||||
},
|
||||
{
|
||||
name: "invalid language",
|
||||
documentID: uuid.New(),
|
||||
version: "1.0.0",
|
||||
language: "fr",
|
||||
title: "Test",
|
||||
content: "Content",
|
||||
expectError: true,
|
||||
errorContains: "language",
|
||||
},
|
||||
{
|
||||
name: "empty title",
|
||||
documentID: uuid.New(),
|
||||
version: "1.0.0",
|
||||
language: "de",
|
||||
title: "",
|
||||
content: "Content",
|
||||
expectError: true,
|
||||
errorContains: "title",
|
||||
},
|
||||
{
|
||||
name: "empty content",
|
||||
documentID: uuid.New(),
|
||||
version: "1.0.0",
|
||||
language: "de",
|
||||
title: "Test",
|
||||
content: "",
|
||||
expectError: true,
|
||||
errorContains: "content",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Validate semver format (X.Y.Z pattern)
|
||||
validVersion := regexp.MustCompile(`^\d+\.\d+\.\d+$`).MatchString(tt.version)
|
||||
validLanguage := tt.language == "de" || tt.language == "en"
|
||||
|
||||
var err error
|
||||
if !validVersion {
|
||||
err = &ValidationError{Field: "version", Message: "invalid format"}
|
||||
} else if !validLanguage {
|
||||
err = &ValidationError{Field: "language", Message: "must be 'de' or 'en'"}
|
||||
} else if tt.title == "" {
|
||||
err = &ValidationError{Field: "title", Message: "required"}
|
||||
} else if tt.content == "" {
|
||||
err = &ValidationError{Field: "content", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error containing '%s', got nil", tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentService_VersionStatusTransitions tests version status workflow
|
||||
func TestDocumentService_VersionStatusTransitions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fromStatus string
|
||||
toStatus string
|
||||
isAllowed bool
|
||||
}{
|
||||
// Valid transitions
|
||||
{"draft to review", "draft", "review", true},
|
||||
{"review to approved", "review", "approved", true},
|
||||
{"review to rejected", "review", "rejected", true},
|
||||
{"approved to published", "approved", "published", true},
|
||||
{"approved to scheduled", "approved", "scheduled", true},
|
||||
{"scheduled to published", "scheduled", "published", true},
|
||||
{"published to archived", "published", "archived", true},
|
||||
{"rejected to draft", "rejected", "draft", true},
|
||||
|
||||
// Invalid transitions
|
||||
{"draft to published", "draft", "published", false},
|
||||
{"draft to approved", "draft", "approved", false},
|
||||
{"review to published", "review", "published", false},
|
||||
{"published to draft", "published", "draft", false},
|
||||
{"published to review", "published", "review", false},
|
||||
{"archived to draft", "archived", "draft", false},
|
||||
{"archived to published", "archived", "published", false},
|
||||
}
|
||||
|
||||
// Define valid transitions
|
||||
validTransitions := map[string][]string{
|
||||
"draft": {"review"},
|
||||
"review": {"approved", "rejected"},
|
||||
"approved": {"published", "scheduled"},
|
||||
"scheduled": {"published"},
|
||||
"published": {"archived"},
|
||||
"rejected": {"draft"},
|
||||
"archived": {}, // terminal state
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Check if transition is allowed
|
||||
allowed := false
|
||||
if transitions, ok := validTransitions[tt.fromStatus]; ok {
|
||||
for _, validTo := range transitions {
|
||||
if validTo == tt.toStatus {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if allowed != tt.isAllowed {
|
||||
t.Errorf("Transition %s->%s: expected allowed=%v, got %v",
|
||||
tt.fromStatus, tt.toStatus, tt.isAllowed, allowed)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentService_PublishVersion tests publishing a version
|
||||
func TestDocumentService_PublishVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
versionID uuid.UUID
|
||||
currentStatus string
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "publish approved version",
|
||||
versionID: uuid.New(),
|
||||
currentStatus: "approved",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "publish scheduled version",
|
||||
versionID: uuid.New(),
|
||||
currentStatus: "scheduled",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "cannot publish draft",
|
||||
versionID: uuid.New(),
|
||||
currentStatus: "draft",
|
||||
expectError: true,
|
||||
errorContains: "draft",
|
||||
},
|
||||
{
|
||||
name: "cannot publish review",
|
||||
versionID: uuid.New(),
|
||||
currentStatus: "review",
|
||||
expectError: true,
|
||||
errorContains: "review",
|
||||
},
|
||||
{
|
||||
name: "invalid version ID",
|
||||
versionID: uuid.Nil,
|
||||
currentStatus: "approved",
|
||||
expectError: true,
|
||||
errorContains: "ID",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.versionID == uuid.Nil {
|
||||
err = &ValidationError{Field: "version ID", Message: "required"}
|
||||
} else if tt.currentStatus != "approved" && tt.currentStatus != "scheduled" {
|
||||
err = &ValidationError{Field: "status", Message: "only approved or scheduled versions can be published"}
|
||||
}
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error containing '%s', got nil", tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentService_ArchiveVersion tests archiving a version
|
||||
func TestDocumentService_ArchiveVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
versionID uuid.UUID
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "archive valid version",
|
||||
versionID: uuid.New(),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid version ID",
|
||||
versionID: uuid.Nil,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.versionID == uuid.Nil {
|
||||
err = &ValidationError{Field: "version ID", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentService_DeleteVersion tests deleting a version
|
||||
func TestDocumentService_DeleteVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
versionID uuid.UUID
|
||||
status string
|
||||
canDelete bool
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "delete draft version",
|
||||
versionID: uuid.New(),
|
||||
status: "draft",
|
||||
canDelete: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "delete rejected version",
|
||||
versionID: uuid.New(),
|
||||
status: "rejected",
|
||||
canDelete: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "cannot delete published version",
|
||||
versionID: uuid.New(),
|
||||
status: "published",
|
||||
canDelete: false,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "cannot delete approved version",
|
||||
versionID: uuid.New(),
|
||||
status: "approved",
|
||||
canDelete: false,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "cannot delete archived version",
|
||||
versionID: uuid.New(),
|
||||
status: "archived",
|
||||
canDelete: false,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Only draft and rejected can be deleted
|
||||
canDelete := tt.status == "draft" || tt.status == "rejected"
|
||||
|
||||
var err error
|
||||
if !canDelete {
|
||||
err = &ValidationError{Field: "status", Message: "only draft or rejected versions can be deleted"}
|
||||
}
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if canDelete != tt.canDelete {
|
||||
t.Errorf("Expected canDelete=%v, got %v", tt.canDelete, canDelete)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentService_GetLatestVersion tests retrieving the latest version
|
||||
func TestDocumentService_GetLatestVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
documentID uuid.UUID
|
||||
language string
|
||||
status string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "get latest German version",
|
||||
documentID: uuid.New(),
|
||||
language: "de",
|
||||
status: "published",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "get latest English version",
|
||||
documentID: uuid.New(),
|
||||
language: "en",
|
||||
status: "published",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid document ID",
|
||||
documentID: uuid.Nil,
|
||||
language: "de",
|
||||
status: "published",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var err error
|
||||
if tt.documentID == uuid.Nil {
|
||||
err = &ValidationError{Field: "document ID", Message: "required"}
|
||||
}
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentService_CompareVersions tests version comparison
|
||||
func TestDocumentService_CompareVersions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
version1 string
|
||||
version2 string
|
||||
isDifferent bool
|
||||
}{
|
||||
{
|
||||
name: "same version",
|
||||
version1: "1.0.0",
|
||||
version2: "1.0.0",
|
||||
isDifferent: false,
|
||||
},
|
||||
{
|
||||
name: "different major version",
|
||||
version1: "2.0.0",
|
||||
version2: "1.0.0",
|
||||
isDifferent: true,
|
||||
},
|
||||
{
|
||||
name: "different minor version",
|
||||
version1: "1.1.0",
|
||||
version2: "1.0.0",
|
||||
isDifferent: true,
|
||||
},
|
||||
{
|
||||
name: "different patch version",
|
||||
version1: "1.0.1",
|
||||
version2: "1.0.0",
|
||||
isDifferent: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isDifferent := tt.version1 != tt.version2
|
||||
|
||||
if isDifferent != tt.isDifferent {
|
||||
t.Errorf("Expected isDifferent=%v, got %v", tt.isDifferent, isDifferent)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentService_ScheduledPublishing tests scheduled publishing
|
||||
func TestDocumentService_ScheduledPublishing(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
scheduledAt time.Time
|
||||
shouldPublish bool
|
||||
}{
|
||||
{
|
||||
name: "scheduled for past - should publish",
|
||||
scheduledAt: now.Add(-1 * time.Hour),
|
||||
shouldPublish: true,
|
||||
},
|
||||
{
|
||||
name: "scheduled for now - should publish",
|
||||
scheduledAt: now,
|
||||
shouldPublish: true,
|
||||
},
|
||||
{
|
||||
name: "scheduled for future - should not publish",
|
||||
scheduledAt: now.Add(1 * time.Hour),
|
||||
shouldPublish: false,
|
||||
},
|
||||
{
|
||||
name: "scheduled for tomorrow - should not publish",
|
||||
scheduledAt: now.AddDate(0, 0, 1),
|
||||
shouldPublish: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
shouldPublish := tt.scheduledAt.Before(now) || tt.scheduledAt.Equal(now)
|
||||
|
||||
if shouldPublish != tt.shouldPublish {
|
||||
t.Errorf("Expected shouldPublish=%v, got %v", tt.shouldPublish, shouldPublish)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentService_ApprovalWorkflow tests the approval workflow
|
||||
func TestDocumentService_ApprovalWorkflow(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
action string
|
||||
userRole string
|
||||
isAllowed bool
|
||||
}{
|
||||
// Admin permissions
|
||||
{"admin submit for review", "submit_review", "admin", true},
|
||||
{"admin cannot approve", "approve", "admin", false},
|
||||
{"admin can publish", "publish", "admin", true},
|
||||
|
||||
// DSB permissions
|
||||
{"dsb can approve", "approve", "data_protection_officer", true},
|
||||
{"dsb can reject", "reject", "data_protection_officer", true},
|
||||
{"dsb can publish", "publish", "data_protection_officer", true},
|
||||
|
||||
// User permissions
|
||||
{"user cannot submit", "submit_review", "user", false},
|
||||
{"user cannot approve", "approve", "user", false},
|
||||
{"user cannot publish", "publish", "user", false},
|
||||
}
|
||||
|
||||
permissions := map[string]map[string]bool{
|
||||
"admin": {
|
||||
"submit_review": true,
|
||||
"approve": false,
|
||||
"reject": false,
|
||||
"publish": true,
|
||||
},
|
||||
"data_protection_officer": {
|
||||
"submit_review": true,
|
||||
"approve": true,
|
||||
"reject": true,
|
||||
"publish": true,
|
||||
},
|
||||
"user": {
|
||||
"submit_review": false,
|
||||
"approve": false,
|
||||
"reject": false,
|
||||
"publish": false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rolePerms, ok := permissions[tt.userRole]
|
||||
if !ok {
|
||||
t.Fatalf("Unknown role: %s", tt.userRole)
|
||||
}
|
||||
|
||||
isAllowed := rolePerms[tt.action]
|
||||
|
||||
if isAllowed != tt.isAllowed {
|
||||
t.Errorf("Role %s action %s: expected allowed=%v, got %v",
|
||||
tt.userRole, tt.action, tt.isAllowed, isAllowed)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentService_FourEyesPrinciple tests the four-eyes principle
|
||||
func TestDocumentService_FourEyesPrinciple(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
createdBy uuid.UUID
|
||||
approver uuid.UUID
|
||||
approverRole string
|
||||
canApprove bool
|
||||
}{
|
||||
{
|
||||
name: "different users - DSB can approve",
|
||||
createdBy: uuid.New(),
|
||||
approver: uuid.New(),
|
||||
approverRole: "data_protection_officer",
|
||||
canApprove: true,
|
||||
},
|
||||
{
|
||||
name: "same user - DSB cannot approve own",
|
||||
createdBy: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"),
|
||||
approver: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"),
|
||||
approverRole: "data_protection_officer",
|
||||
canApprove: false,
|
||||
},
|
||||
{
|
||||
name: "same user - admin CAN approve own (exception)",
|
||||
createdBy: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"),
|
||||
approver: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"),
|
||||
approverRole: "admin",
|
||||
canApprove: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Four-eyes principle: DSB cannot approve their own work
|
||||
// Exception: Admins can (for development/testing)
|
||||
canApprove := tt.createdBy != tt.approver || tt.approverRole == "admin"
|
||||
|
||||
if canApprove != tt.canApprove {
|
||||
t.Errorf("Expected canApprove=%v, got %v", tt.canApprove, canApprove)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,947 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// DSRService handles Data Subject Request business logic
|
||||
type DSRService struct {
|
||||
pool *pgxpool.Pool
|
||||
notificationService *NotificationService
|
||||
emailService *EmailService
|
||||
}
|
||||
|
||||
// NewDSRService creates a new DSRService
|
||||
func NewDSRService(pool *pgxpool.Pool, notificationService *NotificationService, emailService *EmailService) *DSRService {
|
||||
return &DSRService{
|
||||
pool: pool,
|
||||
notificationService: notificationService,
|
||||
emailService: emailService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetPool returns the database pool for direct queries
|
||||
func (s *DSRService) GetPool() *pgxpool.Pool {
|
||||
return s.pool
|
||||
}
|
||||
|
||||
// generateRequestNumber generates a unique request number like DSR-2025-000001
|
||||
func (s *DSRService) generateRequestNumber(ctx context.Context) (string, error) {
|
||||
var seqNum int64
|
||||
err := s.pool.QueryRow(ctx, "SELECT nextval('dsr_request_number_seq')").Scan(&seqNum)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get next sequence number: %w", err)
|
||||
}
|
||||
year := time.Now().Year()
|
||||
return fmt.Sprintf("DSR-%d-%06d", year, seqNum), nil
|
||||
}
|
||||
|
||||
// CreateRequest creates a new data subject request
|
||||
func (s *DSRService) CreateRequest(ctx context.Context, req models.CreateDSRRequest, createdBy *uuid.UUID) (*models.DataSubjectRequest, error) {
|
||||
// Validate request type
|
||||
requestType := models.DSRRequestType(req.RequestType)
|
||||
if !isValidRequestType(requestType) {
|
||||
return nil, fmt.Errorf("invalid request type: %s", req.RequestType)
|
||||
}
|
||||
|
||||
// Generate request number
|
||||
requestNumber, err := s.generateRequestNumber(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Calculate deadline
|
||||
deadlineDays := requestType.DeadlineDays()
|
||||
deadline := time.Now().AddDate(0, 0, deadlineDays)
|
||||
|
||||
// Determine priority
|
||||
priority := models.DSRPriorityNormal
|
||||
if req.Priority != "" {
|
||||
priority = models.DSRPriority(req.Priority)
|
||||
} else if requestType.IsExpedited() {
|
||||
priority = models.DSRPriorityExpedited
|
||||
}
|
||||
|
||||
// Determine source
|
||||
source := models.DSRSourceAPI
|
||||
if req.Source != "" {
|
||||
source = models.DSRSource(req.Source)
|
||||
}
|
||||
|
||||
// Serialize request details
|
||||
detailsJSON, err := json.Marshal(req.RequestDetails)
|
||||
if err != nil {
|
||||
detailsJSON = []byte("{}")
|
||||
}
|
||||
|
||||
// Try to find existing user by email
|
||||
var userID *uuid.UUID
|
||||
var foundUserID uuid.UUID
|
||||
err = s.pool.QueryRow(ctx, "SELECT id FROM users WHERE email = $1", req.RequesterEmail).Scan(&foundUserID)
|
||||
if err == nil {
|
||||
userID = &foundUserID
|
||||
}
|
||||
|
||||
// Insert request
|
||||
var dsr models.DataSubjectRequest
|
||||
err = s.pool.QueryRow(ctx, `
|
||||
INSERT INTO data_subject_requests (
|
||||
user_id, request_number, request_type, status, priority, source,
|
||||
requester_email, requester_name, requester_phone,
|
||||
request_details, deadline_at, legal_deadline_days, created_by
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
|
||||
RETURNING id, user_id, request_number, request_type, status, priority, source,
|
||||
requester_email, requester_name, requester_phone, identity_verified,
|
||||
request_details, deadline_at, legal_deadline_days, created_at, updated_at, created_by
|
||||
`, userID, requestNumber, requestType, models.DSRStatusIntake, priority, source,
|
||||
req.RequesterEmail, req.RequesterName, req.RequesterPhone,
|
||||
detailsJSON, deadline, deadlineDays, createdBy,
|
||||
).Scan(
|
||||
&dsr.ID, &dsr.UserID, &dsr.RequestNumber, &dsr.RequestType, &dsr.Status,
|
||||
&dsr.Priority, &dsr.Source, &dsr.RequesterEmail, &dsr.RequesterName,
|
||||
&dsr.RequesterPhone, &dsr.IdentityVerified, &detailsJSON,
|
||||
&dsr.DeadlineAt, &dsr.LegalDeadlineDays, &dsr.CreatedAt, &dsr.UpdatedAt, &dsr.CreatedBy,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create DSR: %w", err)
|
||||
}
|
||||
|
||||
// Parse details back
|
||||
json.Unmarshal(detailsJSON, &dsr.RequestDetails)
|
||||
|
||||
// Record initial status
|
||||
s.recordStatusChange(ctx, dsr.ID, nil, models.DSRStatusIntake, createdBy, "Anfrage eingegangen")
|
||||
|
||||
// Notify DPOs about new request
|
||||
go s.notifyNewRequest(context.Background(), &dsr)
|
||||
|
||||
return &dsr, nil
|
||||
}
|
||||
|
||||
// GetByID retrieves a DSR by ID
|
||||
func (s *DSRService) GetByID(ctx context.Context, id uuid.UUID) (*models.DataSubjectRequest, error) {
|
||||
var dsr models.DataSubjectRequest
|
||||
var detailsJSON, resultDataJSON []byte
|
||||
|
||||
err := s.pool.QueryRow(ctx, `
|
||||
SELECT id, user_id, request_number, request_type, status, priority, source,
|
||||
requester_email, requester_name, requester_phone,
|
||||
identity_verified, identity_verified_at, identity_verified_by, identity_verification_method,
|
||||
request_details, deadline_at, legal_deadline_days, extended_deadline_at, extension_reason,
|
||||
assigned_to, processing_notes, completed_at, completed_by, result_summary, result_data,
|
||||
rejected_at, rejected_by, rejection_reason, rejection_legal_basis,
|
||||
created_at, updated_at, created_by
|
||||
FROM data_subject_requests WHERE id = $1
|
||||
`, id).Scan(
|
||||
&dsr.ID, &dsr.UserID, &dsr.RequestNumber, &dsr.RequestType, &dsr.Status,
|
||||
&dsr.Priority, &dsr.Source, &dsr.RequesterEmail, &dsr.RequesterName,
|
||||
&dsr.RequesterPhone, &dsr.IdentityVerified, &dsr.IdentityVerifiedAt,
|
||||
&dsr.IdentityVerifiedBy, &dsr.IdentityVerificationMethod,
|
||||
&detailsJSON, &dsr.DeadlineAt, &dsr.LegalDeadlineDays,
|
||||
&dsr.ExtendedDeadlineAt, &dsr.ExtensionReason, &dsr.AssignedTo,
|
||||
&dsr.ProcessingNotes, &dsr.CompletedAt, &dsr.CompletedBy,
|
||||
&dsr.ResultSummary, &resultDataJSON, &dsr.RejectedAt, &dsr.RejectedBy,
|
||||
&dsr.RejectionReason, &dsr.RejectionLegalBasis,
|
||||
&dsr.CreatedAt, &dsr.UpdatedAt, &dsr.CreatedBy,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DSR not found: %w", err)
|
||||
}
|
||||
|
||||
json.Unmarshal(detailsJSON, &dsr.RequestDetails)
|
||||
json.Unmarshal(resultDataJSON, &dsr.ResultData)
|
||||
|
||||
return &dsr, nil
|
||||
}
|
||||
|
||||
// GetByNumber retrieves a DSR by request number
|
||||
func (s *DSRService) GetByNumber(ctx context.Context, requestNumber string) (*models.DataSubjectRequest, error) {
|
||||
var id uuid.UUID
|
||||
err := s.pool.QueryRow(ctx, "SELECT id FROM data_subject_requests WHERE request_number = $1", requestNumber).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DSR not found: %w", err)
|
||||
}
|
||||
return s.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// List retrieves DSRs with filters and pagination
|
||||
func (s *DSRService) List(ctx context.Context, filters models.DSRListFilters, limit, offset int) ([]models.DataSubjectRequest, int, error) {
|
||||
// Build query
|
||||
baseQuery := "FROM data_subject_requests WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
argIndex := 1
|
||||
|
||||
if filters.Status != nil && *filters.Status != "" {
|
||||
baseQuery += fmt.Sprintf(" AND status = $%d", argIndex)
|
||||
args = append(args, *filters.Status)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filters.RequestType != nil && *filters.RequestType != "" {
|
||||
baseQuery += fmt.Sprintf(" AND request_type = $%d", argIndex)
|
||||
args = append(args, *filters.RequestType)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filters.AssignedTo != nil && *filters.AssignedTo != "" {
|
||||
baseQuery += fmt.Sprintf(" AND assigned_to = $%d", argIndex)
|
||||
args = append(args, *filters.AssignedTo)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filters.Priority != nil && *filters.Priority != "" {
|
||||
baseQuery += fmt.Sprintf(" AND priority = $%d", argIndex)
|
||||
args = append(args, *filters.Priority)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filters.OverdueOnly {
|
||||
baseQuery += " AND deadline_at < NOW() AND status NOT IN ('completed', 'rejected', 'cancelled')"
|
||||
}
|
||||
|
||||
if filters.FromDate != nil {
|
||||
baseQuery += fmt.Sprintf(" AND created_at >= $%d", argIndex)
|
||||
args = append(args, *filters.FromDate)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filters.ToDate != nil {
|
||||
baseQuery += fmt.Sprintf(" AND created_at <= $%d", argIndex)
|
||||
args = append(args, *filters.ToDate)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filters.Search != nil && *filters.Search != "" {
|
||||
searchPattern := "%" + *filters.Search + "%"
|
||||
baseQuery += fmt.Sprintf(" AND (request_number ILIKE $%d OR requester_email ILIKE $%d OR requester_name ILIKE $%d)", argIndex, argIndex, argIndex)
|
||||
args = append(args, searchPattern)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
// Get total count
|
||||
var total int
|
||||
err := s.pool.QueryRow(ctx, "SELECT COUNT(*) "+baseQuery, args...).Scan(&total)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to count DSRs: %w", err)
|
||||
}
|
||||
|
||||
// Get paginated results
|
||||
query := fmt.Sprintf(`
|
||||
SELECT id, user_id, request_number, request_type, status, priority, source,
|
||||
requester_email, requester_name, requester_phone, identity_verified,
|
||||
deadline_at, legal_deadline_days, assigned_to, created_at, updated_at
|
||||
%s ORDER BY created_at DESC LIMIT $%d OFFSET $%d
|
||||
`, baseQuery, argIndex, argIndex+1)
|
||||
args = append(args, limit, offset)
|
||||
|
||||
rows, err := s.pool.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to query DSRs: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var dsrs []models.DataSubjectRequest
|
||||
for rows.Next() {
|
||||
var dsr models.DataSubjectRequest
|
||||
err := rows.Scan(
|
||||
&dsr.ID, &dsr.UserID, &dsr.RequestNumber, &dsr.RequestType, &dsr.Status,
|
||||
&dsr.Priority, &dsr.Source, &dsr.RequesterEmail, &dsr.RequesterName,
|
||||
&dsr.RequesterPhone, &dsr.IdentityVerified, &dsr.DeadlineAt,
|
||||
&dsr.LegalDeadlineDays, &dsr.AssignedTo, &dsr.CreatedAt, &dsr.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to scan DSR: %w", err)
|
||||
}
|
||||
dsrs = append(dsrs, dsr)
|
||||
}
|
||||
|
||||
return dsrs, total, nil
|
||||
}
|
||||
|
||||
// ListByUser retrieves DSRs for a specific user
|
||||
func (s *DSRService) ListByUser(ctx context.Context, userID uuid.UUID) ([]models.DataSubjectRequest, error) {
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT id, user_id, request_number, request_type, status, priority, source,
|
||||
requester_email, requester_name, deadline_at, created_at, updated_at
|
||||
FROM data_subject_requests
|
||||
WHERE user_id = $1 OR requester_email = (SELECT email FROM users WHERE id = $1)
|
||||
ORDER BY created_at DESC
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query user DSRs: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var dsrs []models.DataSubjectRequest
|
||||
for rows.Next() {
|
||||
var dsr models.DataSubjectRequest
|
||||
err := rows.Scan(
|
||||
&dsr.ID, &dsr.UserID, &dsr.RequestNumber, &dsr.RequestType, &dsr.Status,
|
||||
&dsr.Priority, &dsr.Source, &dsr.RequesterEmail, &dsr.RequesterName,
|
||||
&dsr.DeadlineAt, &dsr.CreatedAt, &dsr.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan DSR: %w", err)
|
||||
}
|
||||
dsrs = append(dsrs, dsr)
|
||||
}
|
||||
|
||||
return dsrs, nil
|
||||
}
|
||||
|
||||
// UpdateStatus changes the status of a DSR
|
||||
func (s *DSRService) UpdateStatus(ctx context.Context, id uuid.UUID, newStatus models.DSRStatus, comment string, changedBy *uuid.UUID) error {
|
||||
// Get current status
|
||||
var currentStatus models.DSRStatus
|
||||
err := s.pool.QueryRow(ctx, "SELECT status FROM data_subject_requests WHERE id = $1", id).Scan(¤tStatus)
|
||||
if err != nil {
|
||||
return fmt.Errorf("DSR not found: %w", err)
|
||||
}
|
||||
|
||||
// Validate transition
|
||||
if !isValidStatusTransition(currentStatus, newStatus) {
|
||||
return fmt.Errorf("invalid status transition from %s to %s", currentStatus, newStatus)
|
||||
}
|
||||
|
||||
// Update status
|
||||
_, err = s.pool.Exec(ctx, `
|
||||
UPDATE data_subject_requests SET status = $1, updated_at = NOW() WHERE id = $2
|
||||
`, newStatus, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update status: %w", err)
|
||||
}
|
||||
|
||||
// Record status change
|
||||
s.recordStatusChange(ctx, id, ¤tStatus, newStatus, changedBy, comment)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyIdentity marks identity as verified
|
||||
func (s *DSRService) VerifyIdentity(ctx context.Context, id uuid.UUID, method string, verifiedBy uuid.UUID) error {
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE data_subject_requests
|
||||
SET identity_verified = TRUE,
|
||||
identity_verified_at = NOW(),
|
||||
identity_verified_by = $1,
|
||||
identity_verification_method = $2,
|
||||
status = CASE WHEN status = 'intake' THEN 'identity_verification' ELSE status END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $3
|
||||
`, verifiedBy, method, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to verify identity: %w", err)
|
||||
}
|
||||
|
||||
s.recordStatusChange(ctx, id, nil, models.DSRStatusIdentityVerification, &verifiedBy, "Identität verifiziert via "+method)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AssignRequest assigns a DSR to a handler
|
||||
func (s *DSRService) AssignRequest(ctx context.Context, id uuid.UUID, assigneeID uuid.UUID, assignedBy uuid.UUID) error {
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE data_subject_requests SET assigned_to = $1, updated_at = NOW() WHERE id = $2
|
||||
`, assigneeID, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to assign DSR: %w", err)
|
||||
}
|
||||
|
||||
// Get assignee name for comment
|
||||
var assigneeName string
|
||||
s.pool.QueryRow(ctx, "SELECT COALESCE(name, email) FROM users WHERE id = $1", assigneeID).Scan(&assigneeName)
|
||||
|
||||
s.recordStatusChange(ctx, id, nil, "", &assignedBy, "Zugewiesen an "+assigneeName)
|
||||
|
||||
// Notify assignee
|
||||
go s.notifyAssignment(context.Background(), id, assigneeID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExtendDeadline extends the deadline for a DSR
|
||||
func (s *DSRService) ExtendDeadline(ctx context.Context, id uuid.UUID, reason string, days int, extendedBy uuid.UUID) error {
|
||||
// Default extension is 2 months (60 days) per Art. 12(3)
|
||||
if days <= 0 {
|
||||
days = 60
|
||||
}
|
||||
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE data_subject_requests
|
||||
SET extended_deadline_at = deadline_at + ($1 || ' days')::INTERVAL,
|
||||
extension_reason = $2,
|
||||
updated_at = NOW()
|
||||
WHERE id = $3
|
||||
`, days, reason, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to extend deadline: %w", err)
|
||||
}
|
||||
|
||||
s.recordStatusChange(ctx, id, nil, "", &extendedBy, fmt.Sprintf("Frist um %d Tage verlängert: %s", days, reason))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CompleteRequest marks a DSR as completed
|
||||
func (s *DSRService) CompleteRequest(ctx context.Context, id uuid.UUID, summary string, resultData map[string]interface{}, completedBy uuid.UUID) error {
|
||||
resultJSON, _ := json.Marshal(resultData)
|
||||
|
||||
// Get current status
|
||||
var currentStatus models.DSRStatus
|
||||
s.pool.QueryRow(ctx, "SELECT status FROM data_subject_requests WHERE id = $1", id).Scan(¤tStatus)
|
||||
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE data_subject_requests
|
||||
SET status = 'completed',
|
||||
completed_at = NOW(),
|
||||
completed_by = $1,
|
||||
result_summary = $2,
|
||||
result_data = $3,
|
||||
updated_at = NOW()
|
||||
WHERE id = $4
|
||||
`, completedBy, summary, resultJSON, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to complete DSR: %w", err)
|
||||
}
|
||||
|
||||
s.recordStatusChange(ctx, id, ¤tStatus, models.DSRStatusCompleted, &completedBy, summary)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RejectRequest rejects a DSR with legal basis
|
||||
func (s *DSRService) RejectRequest(ctx context.Context, id uuid.UUID, reason, legalBasis string, rejectedBy uuid.UUID) error {
|
||||
// Get current status
|
||||
var currentStatus models.DSRStatus
|
||||
s.pool.QueryRow(ctx, "SELECT status FROM data_subject_requests WHERE id = $1", id).Scan(¤tStatus)
|
||||
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE data_subject_requests
|
||||
SET status = 'rejected',
|
||||
rejected_at = NOW(),
|
||||
rejected_by = $1,
|
||||
rejection_reason = $2,
|
||||
rejection_legal_basis = $3,
|
||||
updated_at = NOW()
|
||||
WHERE id = $4
|
||||
`, rejectedBy, reason, legalBasis, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to reject DSR: %w", err)
|
||||
}
|
||||
|
||||
s.recordStatusChange(ctx, id, ¤tStatus, models.DSRStatusRejected, &rejectedBy, fmt.Sprintf("Abgelehnt (%s): %s", legalBasis, reason))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CancelRequest cancels a DSR (by user)
|
||||
func (s *DSRService) CancelRequest(ctx context.Context, id uuid.UUID, cancelledBy uuid.UUID) error {
|
||||
// Verify ownership
|
||||
var userID *uuid.UUID
|
||||
err := s.pool.QueryRow(ctx, "SELECT user_id FROM data_subject_requests WHERE id = $1", id).Scan(&userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("DSR not found: %w", err)
|
||||
}
|
||||
if userID == nil || *userID != cancelledBy {
|
||||
return fmt.Errorf("unauthorized: can only cancel own requests")
|
||||
}
|
||||
|
||||
// Get current status
|
||||
var currentStatus models.DSRStatus
|
||||
s.pool.QueryRow(ctx, "SELECT status FROM data_subject_requests WHERE id = $1", id).Scan(¤tStatus)
|
||||
|
||||
_, err = s.pool.Exec(ctx, `
|
||||
UPDATE data_subject_requests SET status = 'cancelled', updated_at = NOW() WHERE id = $1
|
||||
`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to cancel DSR: %w", err)
|
||||
}
|
||||
|
||||
s.recordStatusChange(ctx, id, ¤tStatus, models.DSRStatusCancelled, &cancelledBy, "Vom Antragsteller storniert")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDashboardStats returns statistics for the admin dashboard
|
||||
func (s *DSRService) GetDashboardStats(ctx context.Context) (*models.DSRDashboardStats, error) {
|
||||
stats := &models.DSRDashboardStats{
|
||||
ByType: make(map[string]int),
|
||||
ByStatus: make(map[string]int),
|
||||
}
|
||||
|
||||
// Total requests
|
||||
s.pool.QueryRow(ctx, "SELECT COUNT(*) FROM data_subject_requests").Scan(&stats.TotalRequests)
|
||||
|
||||
// Pending requests (not completed, rejected, or cancelled)
|
||||
s.pool.QueryRow(ctx, `
|
||||
SELECT COUNT(*) FROM data_subject_requests
|
||||
WHERE status NOT IN ('completed', 'rejected', 'cancelled')
|
||||
`).Scan(&stats.PendingRequests)
|
||||
|
||||
// Overdue requests
|
||||
s.pool.QueryRow(ctx, `
|
||||
SELECT COUNT(*) FROM data_subject_requests
|
||||
WHERE COALESCE(extended_deadline_at, deadline_at) < NOW()
|
||||
AND status NOT IN ('completed', 'rejected', 'cancelled')
|
||||
`).Scan(&stats.OverdueRequests)
|
||||
|
||||
// Completed this month
|
||||
s.pool.QueryRow(ctx, `
|
||||
SELECT COUNT(*) FROM data_subject_requests
|
||||
WHERE status = 'completed'
|
||||
AND completed_at >= DATE_TRUNC('month', NOW())
|
||||
`).Scan(&stats.CompletedThisMonth)
|
||||
|
||||
// Average processing days
|
||||
s.pool.QueryRow(ctx, `
|
||||
SELECT COALESCE(AVG(EXTRACT(EPOCH FROM (completed_at - created_at)) / 86400), 0)
|
||||
FROM data_subject_requests WHERE status = 'completed'
|
||||
`).Scan(&stats.AverageProcessingDays)
|
||||
|
||||
// Count by type
|
||||
rows, _ := s.pool.Query(ctx, `
|
||||
SELECT request_type, COUNT(*) FROM data_subject_requests GROUP BY request_type
|
||||
`)
|
||||
for rows.Next() {
|
||||
var t string
|
||||
var count int
|
||||
rows.Scan(&t, &count)
|
||||
stats.ByType[t] = count
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
// Count by status
|
||||
rows, _ = s.pool.Query(ctx, `
|
||||
SELECT status, COUNT(*) FROM data_subject_requests GROUP BY status
|
||||
`)
|
||||
for rows.Next() {
|
||||
var s string
|
||||
var count int
|
||||
rows.Scan(&s, &count)
|
||||
stats.ByStatus[s] = count
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
// Upcoming deadlines (next 7 days)
|
||||
rows, _ = s.pool.Query(ctx, `
|
||||
SELECT id, request_number, request_type, status, requester_email, deadline_at
|
||||
FROM data_subject_requests
|
||||
WHERE COALESCE(extended_deadline_at, deadline_at) BETWEEN NOW() AND NOW() + INTERVAL '7 days'
|
||||
AND status NOT IN ('completed', 'rejected', 'cancelled')
|
||||
ORDER BY deadline_at ASC LIMIT 10
|
||||
`)
|
||||
for rows.Next() {
|
||||
var dsr models.DataSubjectRequest
|
||||
rows.Scan(&dsr.ID, &dsr.RequestNumber, &dsr.RequestType, &dsr.Status, &dsr.RequesterEmail, &dsr.DeadlineAt)
|
||||
stats.UpcomingDeadlines = append(stats.UpcomingDeadlines, dsr)
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetStatusHistory retrieves the status history for a DSR
|
||||
func (s *DSRService) GetStatusHistory(ctx context.Context, requestID uuid.UUID) ([]models.DSRStatusHistory, error) {
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT id, request_id, from_status, to_status, changed_by, comment, metadata, created_at
|
||||
FROM dsr_status_history WHERE request_id = $1 ORDER BY created_at DESC
|
||||
`, requestID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query status history: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var history []models.DSRStatusHistory
|
||||
for rows.Next() {
|
||||
var h models.DSRStatusHistory
|
||||
var metadataJSON []byte
|
||||
err := rows.Scan(&h.ID, &h.RequestID, &h.FromStatus, &h.ToStatus, &h.ChangedBy, &h.Comment, &metadataJSON, &h.CreatedAt)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
json.Unmarshal(metadataJSON, &h.Metadata)
|
||||
history = append(history, h)
|
||||
}
|
||||
|
||||
return history, nil
|
||||
}
|
||||
|
||||
// GetCommunications retrieves communications for a DSR
|
||||
func (s *DSRService) GetCommunications(ctx context.Context, requestID uuid.UUID) ([]models.DSRCommunication, error) {
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT id, request_id, direction, channel, communication_type, template_version_id,
|
||||
subject, body_html, body_text, recipient_email, sent_at, error_message,
|
||||
attachments, created_at, created_by
|
||||
FROM dsr_communications WHERE request_id = $1 ORDER BY created_at DESC
|
||||
`, requestID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query communications: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var comms []models.DSRCommunication
|
||||
for rows.Next() {
|
||||
var c models.DSRCommunication
|
||||
var attachmentsJSON []byte
|
||||
err := rows.Scan(&c.ID, &c.RequestID, &c.Direction, &c.Channel, &c.CommunicationType,
|
||||
&c.TemplateVersionID, &c.Subject, &c.BodyHTML, &c.BodyText, &c.RecipientEmail,
|
||||
&c.SentAt, &c.ErrorMessage, &attachmentsJSON, &c.CreatedAt, &c.CreatedBy)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
json.Unmarshal(attachmentsJSON, &c.Attachments)
|
||||
comms = append(comms, c)
|
||||
}
|
||||
|
||||
return comms, nil
|
||||
}
|
||||
|
||||
// SendCommunication sends a communication for a DSR
|
||||
func (s *DSRService) SendCommunication(ctx context.Context, requestID uuid.UUID, req models.SendDSRCommunicationRequest, sentBy uuid.UUID) error {
|
||||
// Get DSR details
|
||||
dsr, err := s.GetByID(ctx, requestID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get template if specified
|
||||
var subject, bodyHTML, bodyText string
|
||||
if req.TemplateVersionID != nil {
|
||||
templateVersionID, _ := uuid.Parse(*req.TemplateVersionID)
|
||||
err := s.pool.QueryRow(ctx, `
|
||||
SELECT subject, body_html, body_text FROM dsr_template_versions WHERE id = $1 AND status = 'published'
|
||||
`, templateVersionID).Scan(&subject, &bodyHTML, &bodyText)
|
||||
if err != nil {
|
||||
return fmt.Errorf("template version not found or not published: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Use custom content if provided
|
||||
if req.CustomSubject != nil {
|
||||
subject = *req.CustomSubject
|
||||
}
|
||||
if req.CustomBody != nil {
|
||||
bodyHTML = *req.CustomBody
|
||||
bodyText = stripHTML(*req.CustomBody)
|
||||
}
|
||||
|
||||
// Replace variables
|
||||
variables := map[string]string{
|
||||
"requester_name": stringOrDefault(dsr.RequesterName, "Antragsteller/in"),
|
||||
"request_number": dsr.RequestNumber,
|
||||
"request_type_de": dsr.RequestType.Label(),
|
||||
"request_date": dsr.CreatedAt.Format("02.01.2006"),
|
||||
"deadline_date": dsr.DeadlineAt.Format("02.01.2006"),
|
||||
}
|
||||
for k, v := range req.Variables {
|
||||
variables[k] = v
|
||||
}
|
||||
subject = replaceVariables(subject, variables)
|
||||
bodyHTML = replaceVariables(bodyHTML, variables)
|
||||
bodyText = replaceVariables(bodyText, variables)
|
||||
|
||||
// Send email
|
||||
if s.emailService != nil {
|
||||
err = s.emailService.SendEmail(dsr.RequesterEmail, subject, bodyHTML, bodyText)
|
||||
if err != nil {
|
||||
// Log error but continue
|
||||
_, _ = s.pool.Exec(ctx, `
|
||||
INSERT INTO dsr_communications (request_id, direction, channel, communication_type,
|
||||
template_version_id, subject, body_html, body_text, recipient_email, error_message, created_by)
|
||||
VALUES ($1, 'outbound', 'email', $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
`, requestID, req.CommunicationType, req.TemplateVersionID, subject, bodyHTML, bodyText,
|
||||
dsr.RequesterEmail, err.Error(), sentBy)
|
||||
return fmt.Errorf("failed to send email: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Log communication
|
||||
now := time.Now()
|
||||
_, err = s.pool.Exec(ctx, `
|
||||
INSERT INTO dsr_communications (request_id, direction, channel, communication_type,
|
||||
template_version_id, subject, body_html, body_text, recipient_email, sent_at, created_by)
|
||||
VALUES ($1, 'outbound', 'email', $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
`, requestID, req.CommunicationType, req.TemplateVersionID, subject, bodyHTML, bodyText,
|
||||
dsr.RequesterEmail, now, sentBy)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// InitErasureExceptionChecks initializes exception checks for an erasure request
|
||||
func (s *DSRService) InitErasureExceptionChecks(ctx context.Context, requestID uuid.UUID) error {
|
||||
exceptions := []struct {
|
||||
Type string
|
||||
Description string
|
||||
}{
|
||||
{models.DSRExceptionFreedomExpression, "Ausübung des Rechts auf freie Meinungsäußerung und Information (Art. 17 Abs. 3 lit. a)"},
|
||||
{models.DSRExceptionLegalObligation, "Erfüllung einer rechtlichen Verpflichtung oder öffentlichen Aufgabe (Art. 17 Abs. 3 lit. b)"},
|
||||
{models.DSRExceptionPublicHealth, "Gründe des öffentlichen Interesses im Bereich der öffentlichen Gesundheit (Art. 17 Abs. 3 lit. c)"},
|
||||
{models.DSRExceptionArchiving, "Im öffentlichen Interesse liegende Archivzwecke, Forschung oder Statistik (Art. 17 Abs. 3 lit. d)"},
|
||||
{models.DSRExceptionLegalClaims, "Geltendmachung, Ausübung oder Verteidigung von Rechtsansprüchen (Art. 17 Abs. 3 lit. e)"},
|
||||
}
|
||||
|
||||
for _, exc := range exceptions {
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
INSERT INTO dsr_exception_checks (request_id, exception_type, description)
|
||||
VALUES ($1, $2, $3) ON CONFLICT DO NOTHING
|
||||
`, requestID, exc.Type, exc.Description)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create exception check: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetExceptionChecks retrieves exception checks for a DSR
|
||||
func (s *DSRService) GetExceptionChecks(ctx context.Context, requestID uuid.UUID) ([]models.DSRExceptionCheck, error) {
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT id, request_id, exception_type, description, applies, checked_by, checked_at, notes, created_at
|
||||
FROM dsr_exception_checks WHERE request_id = $1 ORDER BY created_at
|
||||
`, requestID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query exception checks: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var checks []models.DSRExceptionCheck
|
||||
for rows.Next() {
|
||||
var c models.DSRExceptionCheck
|
||||
err := rows.Scan(&c.ID, &c.RequestID, &c.ExceptionType, &c.Description, &c.Applies,
|
||||
&c.CheckedBy, &c.CheckedAt, &c.Notes, &c.CreatedAt)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
checks = append(checks, c)
|
||||
}
|
||||
|
||||
return checks, nil
|
||||
}
|
||||
|
||||
// UpdateExceptionCheck updates an exception check
|
||||
func (s *DSRService) UpdateExceptionCheck(ctx context.Context, checkID uuid.UUID, applies bool, notes *string, checkedBy uuid.UUID) error {
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE dsr_exception_checks
|
||||
SET applies = $1, notes = $2, checked_by = $3, checked_at = NOW()
|
||||
WHERE id = $4
|
||||
`, applies, notes, checkedBy, checkID)
|
||||
return err
|
||||
}
|
||||
|
||||
// ProcessDeadlines checks for approaching and overdue deadlines
|
||||
func (s *DSRService) ProcessDeadlines(ctx context.Context) error {
|
||||
now := time.Now()
|
||||
|
||||
// Find requests with deadlines in 3 days
|
||||
threeDaysAhead := now.AddDate(0, 0, 3)
|
||||
rows, _ := s.pool.Query(ctx, `
|
||||
SELECT id, request_number, request_type, assigned_to, deadline_at
|
||||
FROM data_subject_requests
|
||||
WHERE COALESCE(extended_deadline_at, deadline_at) BETWEEN $1 AND $2
|
||||
AND status NOT IN ('completed', 'rejected', 'cancelled')
|
||||
`, now, threeDaysAhead)
|
||||
for rows.Next() {
|
||||
var id uuid.UUID
|
||||
var requestNumber, requestType string
|
||||
var assignedTo *uuid.UUID
|
||||
var deadline time.Time
|
||||
rows.Scan(&id, &requestNumber, &requestType, &assignedTo, &deadline)
|
||||
|
||||
// Notify assigned user or all DPOs
|
||||
if assignedTo != nil {
|
||||
s.notifyDeadlineWarning(ctx, id, *assignedTo, requestNumber, deadline, 3)
|
||||
} else {
|
||||
s.notifyAllDPOs(ctx, id, requestNumber, "Frist in 3 Tagen", deadline)
|
||||
}
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
// Find requests with deadlines in 1 day
|
||||
oneDayAhead := now.AddDate(0, 0, 1)
|
||||
rows, _ = s.pool.Query(ctx, `
|
||||
SELECT id, request_number, request_type, assigned_to, deadline_at
|
||||
FROM data_subject_requests
|
||||
WHERE COALESCE(extended_deadline_at, deadline_at) BETWEEN $1 AND $2
|
||||
AND status NOT IN ('completed', 'rejected', 'cancelled')
|
||||
`, now, oneDayAhead)
|
||||
for rows.Next() {
|
||||
var id uuid.UUID
|
||||
var requestNumber, requestType string
|
||||
var assignedTo *uuid.UUID
|
||||
var deadline time.Time
|
||||
rows.Scan(&id, &requestNumber, &requestType, &assignedTo, &deadline)
|
||||
|
||||
if assignedTo != nil {
|
||||
s.notifyDeadlineWarning(ctx, id, *assignedTo, requestNumber, deadline, 1)
|
||||
} else {
|
||||
s.notifyAllDPOs(ctx, id, requestNumber, "Frist morgen!", deadline)
|
||||
}
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
// Find overdue requests
|
||||
rows, _ = s.pool.Query(ctx, `
|
||||
SELECT id, request_number, request_type, assigned_to, deadline_at
|
||||
FROM data_subject_requests
|
||||
WHERE COALESCE(extended_deadline_at, deadline_at) < $1
|
||||
AND status NOT IN ('completed', 'rejected', 'cancelled')
|
||||
`, now)
|
||||
for rows.Next() {
|
||||
var id uuid.UUID
|
||||
var requestNumber, requestType string
|
||||
var assignedTo *uuid.UUID
|
||||
var deadline time.Time
|
||||
rows.Scan(&id, &requestNumber, &requestType, &assignedTo, &deadline)
|
||||
|
||||
// Notify all DPOs for overdue
|
||||
s.notifyAllDPOs(ctx, id, requestNumber, "ÜBERFÄLLIG!", deadline)
|
||||
|
||||
// Log to audit
|
||||
s.pool.Exec(ctx, `
|
||||
INSERT INTO consent_audit_log (action, entity_type, entity_id, details)
|
||||
VALUES ('dsr_overdue', 'dsr', $1, $2)
|
||||
`, id, fmt.Sprintf(`{"request_number": "%s", "deadline": "%s"}`, requestNumber, deadline.Format(time.RFC3339)))
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func (s *DSRService) recordStatusChange(ctx context.Context, requestID uuid.UUID, fromStatus *models.DSRStatus, toStatus models.DSRStatus, changedBy *uuid.UUID, comment string) {
|
||||
s.pool.Exec(ctx, `
|
||||
INSERT INTO dsr_status_history (request_id, from_status, to_status, changed_by, comment)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
`, requestID, fromStatus, toStatus, changedBy, comment)
|
||||
}
|
||||
|
||||
func (s *DSRService) notifyNewRequest(ctx context.Context, dsr *models.DataSubjectRequest) {
|
||||
if s.notificationService == nil {
|
||||
return
|
||||
}
|
||||
// Notify all DPOs
|
||||
rows, _ := s.pool.Query(ctx, "SELECT id FROM users WHERE role = 'data_protection_officer'")
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var userID uuid.UUID
|
||||
rows.Scan(&userID)
|
||||
s.notificationService.CreateNotification(ctx, userID, NotificationTypeDSRReceived,
|
||||
"Neue Betroffenenanfrage",
|
||||
fmt.Sprintf("Neue %s eingegangen: %s", dsr.RequestType.Label(), dsr.RequestNumber),
|
||||
map[string]interface{}{"dsr_id": dsr.ID, "request_number": dsr.RequestNumber})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DSRService) notifyAssignment(ctx context.Context, dsrID, assigneeID uuid.UUID) {
|
||||
if s.notificationService == nil {
|
||||
return
|
||||
}
|
||||
dsr, _ := s.GetByID(ctx, dsrID)
|
||||
if dsr != nil {
|
||||
s.notificationService.CreateNotification(ctx, assigneeID, NotificationTypeDSRAssigned,
|
||||
"Betroffenenanfrage zugewiesen",
|
||||
fmt.Sprintf("Ihnen wurde die Anfrage %s zugewiesen", dsr.RequestNumber),
|
||||
map[string]interface{}{"dsr_id": dsrID, "request_number": dsr.RequestNumber})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DSRService) notifyDeadlineWarning(ctx context.Context, dsrID, userID uuid.UUID, requestNumber string, deadline time.Time, daysLeft int) {
|
||||
if s.notificationService == nil {
|
||||
return
|
||||
}
|
||||
s.notificationService.CreateNotification(ctx, userID, NotificationTypeDSRDeadline,
|
||||
fmt.Sprintf("Fristwarnung: %s", requestNumber),
|
||||
fmt.Sprintf("Die Frist für %s läuft in %d Tag(en) ab (%s)", requestNumber, daysLeft, deadline.Format("02.01.2006")),
|
||||
map[string]interface{}{"dsr_id": dsrID, "deadline": deadline, "days_left": daysLeft})
|
||||
}
|
||||
|
||||
func (s *DSRService) notifyAllDPOs(ctx context.Context, dsrID uuid.UUID, requestNumber, message string, deadline time.Time) {
|
||||
if s.notificationService == nil {
|
||||
return
|
||||
}
|
||||
rows, _ := s.pool.Query(ctx, "SELECT id FROM users WHERE role = 'data_protection_officer'")
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var userID uuid.UUID
|
||||
rows.Scan(&userID)
|
||||
s.notificationService.CreateNotification(ctx, userID, NotificationTypeDSRDeadline,
|
||||
fmt.Sprintf("%s: %s", message, requestNumber),
|
||||
fmt.Sprintf("Anfrage %s: %s (Frist: %s)", requestNumber, message, deadline.Format("02.01.2006")),
|
||||
map[string]interface{}{"dsr_id": dsrID, "deadline": deadline})
|
||||
}
|
||||
}
|
||||
|
||||
func isValidRequestType(rt models.DSRRequestType) bool {
|
||||
switch rt {
|
||||
case models.DSRTypeAccess, models.DSRTypeRectification, models.DSRTypeErasure,
|
||||
models.DSRTypeRestriction, models.DSRTypePortability:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isValidStatusTransition(from, to models.DSRStatus) bool {
|
||||
validTransitions := map[models.DSRStatus][]models.DSRStatus{
|
||||
models.DSRStatusIntake: {models.DSRStatusIdentityVerification, models.DSRStatusProcessing, models.DSRStatusRejected, models.DSRStatusCancelled},
|
||||
models.DSRStatusIdentityVerification: {models.DSRStatusProcessing, models.DSRStatusRejected, models.DSRStatusCancelled},
|
||||
models.DSRStatusProcessing: {models.DSRStatusCompleted, models.DSRStatusRejected, models.DSRStatusCancelled},
|
||||
models.DSRStatusCompleted: {},
|
||||
models.DSRStatusRejected: {},
|
||||
models.DSRStatusCancelled: {},
|
||||
}
|
||||
|
||||
allowed, exists := validTransitions[from]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
for _, s := range allowed {
|
||||
if s == to {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func stringOrDefault(s *string, def string) string {
|
||||
if s != nil {
|
||||
return *s
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func replaceVariables(text string, variables map[string]string) string {
|
||||
for k, v := range variables {
|
||||
text = strings.ReplaceAll(text, "{{"+k+"}}", v)
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func stripHTML(html string) string {
|
||||
// Simple HTML stripping - in production use a proper library
|
||||
text := strings.ReplaceAll(html, "<br>", "\n")
|
||||
text = strings.ReplaceAll(text, "<br/>", "\n")
|
||||
text = strings.ReplaceAll(text, "<br />", "\n")
|
||||
text = strings.ReplaceAll(text, "</p>", "\n\n")
|
||||
// Remove all remaining tags
|
||||
for {
|
||||
start := strings.Index(text, "<")
|
||||
if start == -1 {
|
||||
break
|
||||
}
|
||||
end := strings.Index(text[start:], ">")
|
||||
if end == -1 {
|
||||
break
|
||||
}
|
||||
text = text[:start] + text[start+end+1:]
|
||||
}
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
@@ -0,0 +1,420 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/breakpilot/consent-service/internal/models"
|
||||
)
|
||||
|
||||
// TestDSRRequestTypeLabel tests label generation for request types
|
||||
func TestDSRRequestTypeLabel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reqType models.DSRRequestType
|
||||
expected string
|
||||
}{
|
||||
{"access type", models.DSRTypeAccess, "Auskunftsanfrage (Art. 15)"},
|
||||
{"rectification type", models.DSRTypeRectification, "Berichtigungsanfrage (Art. 16)"},
|
||||
{"erasure type", models.DSRTypeErasure, "Löschanfrage (Art. 17)"},
|
||||
{"restriction type", models.DSRTypeRestriction, "Einschränkungsanfrage (Art. 18)"},
|
||||
{"portability type", models.DSRTypePortability, "Datenübertragung (Art. 20)"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.reqType.Label()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDSRRequestTypeDeadlineDays tests deadline calculation for different request types
|
||||
func TestDSRRequestTypeDeadlineDays(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reqType models.DSRRequestType
|
||||
expectedDays int
|
||||
}{
|
||||
{"access has 30 days", models.DSRTypeAccess, 30},
|
||||
{"portability has 30 days", models.DSRTypePortability, 30},
|
||||
{"rectification has 14 days", models.DSRTypeRectification, 14},
|
||||
{"erasure has 14 days", models.DSRTypeErasure, 14},
|
||||
{"restriction has 14 days", models.DSRTypeRestriction, 14},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.reqType.DeadlineDays()
|
||||
if result != tt.expectedDays {
|
||||
t.Errorf("Expected %d days, got %d", tt.expectedDays, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDSRRequestTypeIsExpedited tests expedited flag for request types
|
||||
func TestDSRRequestTypeIsExpedited(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reqType models.DSRRequestType
|
||||
isExpedited bool
|
||||
}{
|
||||
{"access not expedited", models.DSRTypeAccess, false},
|
||||
{"portability not expedited", models.DSRTypePortability, false},
|
||||
{"rectification is expedited", models.DSRTypeRectification, true},
|
||||
{"erasure is expedited", models.DSRTypeErasure, true},
|
||||
{"restriction is expedited", models.DSRTypeRestriction, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.reqType.IsExpedited()
|
||||
if result != tt.isExpedited {
|
||||
t.Errorf("Expected IsExpedited=%v, got %v", tt.isExpedited, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDSRStatusLabel tests label generation for statuses
|
||||
func TestDSRStatusLabel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status models.DSRStatus
|
||||
expected string
|
||||
}{
|
||||
{"intake status", models.DSRStatusIntake, "Eingang"},
|
||||
{"identity verification", models.DSRStatusIdentityVerification, "Identitätsprüfung"},
|
||||
{"processing status", models.DSRStatusProcessing, "In Bearbeitung"},
|
||||
{"completed status", models.DSRStatusCompleted, "Abgeschlossen"},
|
||||
{"rejected status", models.DSRStatusRejected, "Abgelehnt"},
|
||||
{"cancelled status", models.DSRStatusCancelled, "Storniert"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.status.Label()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidDSRRequestType tests request type validation
|
||||
func TestValidDSRRequestType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reqType string
|
||||
valid bool
|
||||
}{
|
||||
{"valid access", "access", true},
|
||||
{"valid rectification", "rectification", true},
|
||||
{"valid erasure", "erasure", true},
|
||||
{"valid restriction", "restriction", true},
|
||||
{"valid portability", "portability", true},
|
||||
{"invalid type", "invalid", false},
|
||||
{"empty type", "", false},
|
||||
{"random string", "delete_everything", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := models.IsValidDSRRequestType(tt.reqType)
|
||||
if result != tt.valid {
|
||||
t.Errorf("Expected IsValidDSRRequestType=%v for %s, got %v", tt.valid, tt.reqType, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidDSRStatus tests status validation
|
||||
func TestValidDSRStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status string
|
||||
valid bool
|
||||
}{
|
||||
{"valid intake", "intake", true},
|
||||
{"valid identity_verification", "identity_verification", true},
|
||||
{"valid processing", "processing", true},
|
||||
{"valid completed", "completed", true},
|
||||
{"valid rejected", "rejected", true},
|
||||
{"valid cancelled", "cancelled", true},
|
||||
{"invalid status", "invalid", false},
|
||||
{"empty status", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := models.IsValidDSRStatus(tt.status)
|
||||
if result != tt.valid {
|
||||
t.Errorf("Expected IsValidDSRStatus=%v for %s, got %v", tt.valid, tt.status, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDSRStatusTransitionValidation tests allowed status transitions
|
||||
func TestDSRStatusTransitionValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fromStatus models.DSRStatus
|
||||
toStatus models.DSRStatus
|
||||
allowed bool
|
||||
}{
|
||||
// From intake
|
||||
{"intake to identity_verification", models.DSRStatusIntake, models.DSRStatusIdentityVerification, true},
|
||||
{"intake to processing", models.DSRStatusIntake, models.DSRStatusProcessing, true},
|
||||
{"intake to rejected", models.DSRStatusIntake, models.DSRStatusRejected, true},
|
||||
{"intake to cancelled", models.DSRStatusIntake, models.DSRStatusCancelled, true},
|
||||
{"intake to completed invalid", models.DSRStatusIntake, models.DSRStatusCompleted, false},
|
||||
|
||||
// From identity_verification
|
||||
{"identity to processing", models.DSRStatusIdentityVerification, models.DSRStatusProcessing, true},
|
||||
{"identity to rejected", models.DSRStatusIdentityVerification, models.DSRStatusRejected, true},
|
||||
{"identity to cancelled", models.DSRStatusIdentityVerification, models.DSRStatusCancelled, true},
|
||||
|
||||
// From processing
|
||||
{"processing to completed", models.DSRStatusProcessing, models.DSRStatusCompleted, true},
|
||||
{"processing to rejected", models.DSRStatusProcessing, models.DSRStatusRejected, true},
|
||||
{"processing to intake invalid", models.DSRStatusProcessing, models.DSRStatusIntake, false},
|
||||
|
||||
// From completed
|
||||
{"completed to anything invalid", models.DSRStatusCompleted, models.DSRStatusProcessing, false},
|
||||
|
||||
// From rejected
|
||||
{"rejected to anything invalid", models.DSRStatusRejected, models.DSRStatusProcessing, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := testIsValidStatusTransition(tt.fromStatus, tt.toStatus)
|
||||
if result != tt.allowed {
|
||||
t.Errorf("Expected transition %s->%s allowed=%v, got %v",
|
||||
tt.fromStatus, tt.toStatus, tt.allowed, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testIsValidStatusTransition is a test helper for validating status transitions
|
||||
// This mirrors the logic in dsr_service.go for testing purposes
|
||||
func testIsValidStatusTransition(from, to models.DSRStatus) bool {
|
||||
validTransitions := map[models.DSRStatus][]models.DSRStatus{
|
||||
models.DSRStatusIntake: {
|
||||
models.DSRStatusIdentityVerification,
|
||||
models.DSRStatusProcessing,
|
||||
models.DSRStatusRejected,
|
||||
models.DSRStatusCancelled,
|
||||
},
|
||||
models.DSRStatusIdentityVerification: {
|
||||
models.DSRStatusProcessing,
|
||||
models.DSRStatusRejected,
|
||||
models.DSRStatusCancelled,
|
||||
},
|
||||
models.DSRStatusProcessing: {
|
||||
models.DSRStatusCompleted,
|
||||
models.DSRStatusRejected,
|
||||
models.DSRStatusCancelled,
|
||||
},
|
||||
models.DSRStatusCompleted: {},
|
||||
models.DSRStatusRejected: {},
|
||||
models.DSRStatusCancelled: {},
|
||||
}
|
||||
|
||||
allowed, exists := validTransitions[from]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, s := range allowed {
|
||||
if s == to {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// TestCalculateDeadline tests deadline calculation
|
||||
func TestCalculateDeadline(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reqType models.DSRRequestType
|
||||
expectedDays int
|
||||
}{
|
||||
{"access 30 days", models.DSRTypeAccess, 30},
|
||||
{"erasure 14 days", models.DSRTypeErasure, 14},
|
||||
{"rectification 14 days", models.DSRTypeRectification, 14},
|
||||
{"restriction 14 days", models.DSRTypeRestriction, 14},
|
||||
{"portability 30 days", models.DSRTypePortability, 30},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
now := time.Now()
|
||||
deadline := now.AddDate(0, 0, tt.expectedDays)
|
||||
days := tt.reqType.DeadlineDays()
|
||||
|
||||
if days != tt.expectedDays {
|
||||
t.Errorf("Expected %d days, got %d", tt.expectedDays, days)
|
||||
}
|
||||
|
||||
// Verify deadline is approximately correct (within 1 day due to test timing)
|
||||
calculatedDeadline := now.AddDate(0, 0, days)
|
||||
diff := calculatedDeadline.Sub(deadline)
|
||||
if diff > time.Hour*24 || diff < -time.Hour*24 {
|
||||
t.Errorf("Deadline calculation off by more than a day")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCreateDSRRequest_Validation tests validation of create request
|
||||
func TestCreateDSRRequest_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request models.CreateDSRRequest
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid access request",
|
||||
request: models.CreateDSRRequest{
|
||||
RequestType: "access",
|
||||
RequesterEmail: "test@example.com",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid erasure request with name",
|
||||
request: models.CreateDSRRequest{
|
||||
RequestType: "erasure",
|
||||
RequesterEmail: "test@example.com",
|
||||
RequesterName: stringPtr("Max Mustermann"),
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing email",
|
||||
request: models.CreateDSRRequest{
|
||||
RequestType: "access",
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid request type",
|
||||
request: models.CreateDSRRequest{
|
||||
RequestType: "invalid_type",
|
||||
RequesterEmail: "test@example.com",
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "empty request type",
|
||||
request: models.CreateDSRRequest{
|
||||
RequestType: "",
|
||||
RequesterEmail: "test@example.com",
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := testValidateCreateDSRRequest(tt.request)
|
||||
hasError := err != nil
|
||||
|
||||
if hasError != tt.expectError {
|
||||
t.Errorf("Expected error=%v, got error=%v (err: %v)", tt.expectError, hasError, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testValidateCreateDSRRequest is a test helper for validating create DSR requests
|
||||
func testValidateCreateDSRRequest(req models.CreateDSRRequest) error {
|
||||
if req.RequesterEmail == "" {
|
||||
return &dsrValidationError{"requester_email is required"}
|
||||
}
|
||||
if !models.IsValidDSRRequestType(req.RequestType) {
|
||||
return &dsrValidationError{"invalid request_type"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type dsrValidationError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *dsrValidationError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// TestDSRTemplateTypes tests the template types
|
||||
func TestDSRTemplateTypes(t *testing.T) {
|
||||
expectedTemplates := []string{
|
||||
"dsr_receipt_access",
|
||||
"dsr_receipt_rectification",
|
||||
"dsr_receipt_erasure",
|
||||
"dsr_receipt_restriction",
|
||||
"dsr_receipt_portability",
|
||||
"dsr_identity_request",
|
||||
"dsr_processing_started",
|
||||
"dsr_processing_update",
|
||||
"dsr_clarification_request",
|
||||
"dsr_completed_access",
|
||||
"dsr_completed_rectification",
|
||||
"dsr_completed_erasure",
|
||||
"dsr_completed_restriction",
|
||||
"dsr_completed_portability",
|
||||
"dsr_restriction_lifted",
|
||||
"dsr_rejected_identity",
|
||||
"dsr_rejected_exception",
|
||||
"dsr_rejected_unfounded",
|
||||
"dsr_deadline_warning",
|
||||
}
|
||||
|
||||
// This test documents the expected template types
|
||||
// The actual templates are created in database migration
|
||||
for _, template := range expectedTemplates {
|
||||
if template == "" {
|
||||
t.Error("Template type should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
if len(expectedTemplates) != 19 {
|
||||
t.Errorf("Expected 19 template types, got %d", len(expectedTemplates))
|
||||
}
|
||||
}
|
||||
|
||||
// TestErasureExceptionTypes tests Art. 17(3) exception types
|
||||
func TestErasureExceptionTypes(t *testing.T) {
|
||||
exceptions := []struct {
|
||||
code string
|
||||
description string
|
||||
}{
|
||||
{"art_17_3_a", "Meinungs- und Informationsfreiheit"},
|
||||
{"art_17_3_b", "Rechtliche Verpflichtung"},
|
||||
{"art_17_3_c", "Öffentliches Interesse im Gesundheitsbereich"},
|
||||
{"art_17_3_d", "Archivzwecke, wissenschaftliche/historische Forschung"},
|
||||
{"art_17_3_e", "Geltendmachung, Ausübung oder Verteidigung von Rechtsansprüchen"},
|
||||
}
|
||||
|
||||
if len(exceptions) != 5 {
|
||||
t.Errorf("Expected 5 Art. 17(3) exceptions, got %d", len(exceptions))
|
||||
}
|
||||
|
||||
for _, ex := range exceptions {
|
||||
if ex.code == "" || ex.description == "" {
|
||||
t.Error("Exception code and description should not be empty")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// stringPtr returns a pointer to the given string
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
@@ -0,0 +1,554 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/smtp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// EmailConfig holds SMTP configuration
|
||||
type EmailConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
FromName string
|
||||
FromAddr string
|
||||
BaseURL string // Frontend URL for links
|
||||
}
|
||||
|
||||
// EmailService handles sending emails
|
||||
type EmailService struct {
|
||||
config EmailConfig
|
||||
}
|
||||
|
||||
// NewEmailService creates a new EmailService
|
||||
func NewEmailService(config EmailConfig) *EmailService {
|
||||
return &EmailService{config: config}
|
||||
}
|
||||
|
||||
// SendEmail sends an email
|
||||
func (s *EmailService) SendEmail(to, subject, htmlBody, textBody string) error {
|
||||
// Build MIME message
|
||||
var msg bytes.Buffer
|
||||
|
||||
msg.WriteString(fmt.Sprintf("From: %s <%s>\r\n", s.config.FromName, s.config.FromAddr))
|
||||
msg.WriteString(fmt.Sprintf("To: %s\r\n", to))
|
||||
msg.WriteString(fmt.Sprintf("Subject: %s\r\n", subject))
|
||||
msg.WriteString("MIME-Version: 1.0\r\n")
|
||||
msg.WriteString("Content-Type: multipart/alternative; boundary=\"boundary42\"\r\n")
|
||||
msg.WriteString("\r\n")
|
||||
|
||||
// Text part
|
||||
msg.WriteString("--boundary42\r\n")
|
||||
msg.WriteString("Content-Type: text/plain; charset=\"UTF-8\"\r\n")
|
||||
msg.WriteString("\r\n")
|
||||
msg.WriteString(textBody)
|
||||
msg.WriteString("\r\n")
|
||||
|
||||
// HTML part
|
||||
msg.WriteString("--boundary42\r\n")
|
||||
msg.WriteString("Content-Type: text/html; charset=\"UTF-8\"\r\n")
|
||||
msg.WriteString("\r\n")
|
||||
msg.WriteString(htmlBody)
|
||||
msg.WriteString("\r\n")
|
||||
msg.WriteString("--boundary42--\r\n")
|
||||
|
||||
// Send email
|
||||
addr := fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)
|
||||
auth := smtp.PlainAuth("", s.config.Username, s.config.Password, s.config.Host)
|
||||
|
||||
err := smtp.SendMail(addr, auth, s.config.FromAddr, []string{to}, msg.Bytes())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendVerificationEmail sends an email verification email
|
||||
func (s *EmailService) SendVerificationEmail(to, name, token string) error {
|
||||
verifyLink := fmt.Sprintf("%s/verify-email?token=%s", s.config.BaseURL, token)
|
||||
|
||||
subject := "Bitte bestätigen Sie Ihre E-Mail-Adresse - BreakPilot"
|
||||
|
||||
textBody := fmt.Sprintf(`Hallo %s,
|
||||
|
||||
Willkommen bei BreakPilot!
|
||||
|
||||
Bitte bestätigen Sie Ihre E-Mail-Adresse, indem Sie den folgenden Link öffnen:
|
||||
%s
|
||||
|
||||
Dieser Link ist 24 Stunden gültig.
|
||||
|
||||
Falls Sie sich nicht bei BreakPilot registriert haben, können Sie diese E-Mail ignorieren.
|
||||
|
||||
Mit freundlichen Grüßen,
|
||||
Ihr BreakPilot Team`, getDisplayName(name), verifyLink)
|
||||
|
||||
htmlBody := s.renderTemplate("verification", map[string]interface{}{
|
||||
"Name": getDisplayName(name),
|
||||
"VerifyLink": verifyLink,
|
||||
})
|
||||
|
||||
return s.SendEmail(to, subject, htmlBody, textBody)
|
||||
}
|
||||
|
||||
// SendPasswordResetEmail sends a password reset email
|
||||
func (s *EmailService) SendPasswordResetEmail(to, name, token string) error {
|
||||
resetLink := fmt.Sprintf("%s/reset-password?token=%s", s.config.BaseURL, token)
|
||||
|
||||
subject := "Passwort zurücksetzen - BreakPilot"
|
||||
|
||||
textBody := fmt.Sprintf(`Hallo %s,
|
||||
|
||||
Sie haben eine Anfrage zum Zurücksetzen Ihres Passworts gestellt.
|
||||
|
||||
Klicken Sie auf den folgenden Link, um Ihr Passwort zurückzusetzen:
|
||||
%s
|
||||
|
||||
Dieser Link ist 1 Stunde gültig.
|
||||
|
||||
Falls Sie keine Passwort-Zurücksetzung angefordert haben, können Sie diese E-Mail ignorieren.
|
||||
|
||||
Mit freundlichen Grüßen,
|
||||
Ihr BreakPilot Team`, getDisplayName(name), resetLink)
|
||||
|
||||
htmlBody := s.renderTemplate("password_reset", map[string]interface{}{
|
||||
"Name": getDisplayName(name),
|
||||
"ResetLink": resetLink,
|
||||
})
|
||||
|
||||
return s.SendEmail(to, subject, htmlBody, textBody)
|
||||
}
|
||||
|
||||
// SendNewVersionNotification sends a notification about new document version
|
||||
func (s *EmailService) SendNewVersionNotification(to, name, documentName, documentType string, deadlineDays int) error {
|
||||
consentLink := fmt.Sprintf("%s/app?consent=pending", s.config.BaseURL)
|
||||
|
||||
subject := fmt.Sprintf("Neue Version: %s - Bitte bestätigen Sie innerhalb von %d Tagen", documentName, deadlineDays)
|
||||
|
||||
textBody := fmt.Sprintf(`Hallo %s,
|
||||
|
||||
Wir haben unsere %s aktualisiert.
|
||||
|
||||
Bitte lesen und bestätigen Sie die neuen Bedingungen innerhalb der nächsten %d Tage:
|
||||
%s
|
||||
|
||||
Falls Sie nicht innerhalb dieser Frist bestätigen, wird Ihr Account vorübergehend gesperrt.
|
||||
|
||||
Mit freundlichen Grüßen,
|
||||
Ihr BreakPilot Team`, getDisplayName(name), documentName, deadlineDays, consentLink)
|
||||
|
||||
htmlBody := s.renderTemplate("new_version", map[string]interface{}{
|
||||
"Name": getDisplayName(name),
|
||||
"DocumentName": documentName,
|
||||
"DeadlineDays": deadlineDays,
|
||||
"ConsentLink": consentLink,
|
||||
})
|
||||
|
||||
return s.SendEmail(to, subject, htmlBody, textBody)
|
||||
}
|
||||
|
||||
// SendConsentReminder sends a consent reminder email
|
||||
func (s *EmailService) SendConsentReminder(to, name string, documents []string, daysLeft int) error {
|
||||
consentLink := fmt.Sprintf("%s/app?consent=pending", s.config.BaseURL)
|
||||
|
||||
urgency := "Erinnerung"
|
||||
if daysLeft <= 7 {
|
||||
urgency = "Dringend"
|
||||
}
|
||||
if daysLeft <= 2 {
|
||||
urgency = "Letzte Warnung"
|
||||
}
|
||||
|
||||
subject := fmt.Sprintf("%s: Noch %d Tage um ausstehende Dokumente zu bestätigen", urgency, daysLeft)
|
||||
|
||||
docList := strings.Join(documents, "\n- ")
|
||||
|
||||
textBody := fmt.Sprintf(`Hallo %s,
|
||||
|
||||
Dies ist eine freundliche Erinnerung, dass Sie noch ausstehende rechtliche Dokumente bestätigen müssen.
|
||||
|
||||
Ausstehende Dokumente:
|
||||
- %s
|
||||
|
||||
Sie haben noch %d Tage Zeit. Nach Ablauf dieser Frist wird Ihr Account vorübergehend gesperrt.
|
||||
|
||||
Bitte bestätigen Sie hier:
|
||||
%s
|
||||
|
||||
Mit freundlichen Grüßen,
|
||||
Ihr BreakPilot Team`, getDisplayName(name), docList, daysLeft, consentLink)
|
||||
|
||||
htmlBody := s.renderTemplate("reminder", map[string]interface{}{
|
||||
"Name": getDisplayName(name),
|
||||
"Documents": documents,
|
||||
"DaysLeft": daysLeft,
|
||||
"Urgency": urgency,
|
||||
"ConsentLink": consentLink,
|
||||
})
|
||||
|
||||
return s.SendEmail(to, subject, htmlBody, textBody)
|
||||
}
|
||||
|
||||
// SendAccountSuspendedNotification sends notification when account is suspended
|
||||
func (s *EmailService) SendAccountSuspendedNotification(to, name string, documents []string) error {
|
||||
consentLink := fmt.Sprintf("%s/app?consent=pending", s.config.BaseURL)
|
||||
|
||||
subject := "Ihr Account wurde vorübergehend gesperrt - BreakPilot"
|
||||
|
||||
docList := strings.Join(documents, "\n- ")
|
||||
|
||||
textBody := fmt.Sprintf(`Hallo %s,
|
||||
|
||||
Ihr Account wurde vorübergehend gesperrt, da Sie die folgenden rechtlichen Dokumente nicht innerhalb der Frist bestätigt haben:
|
||||
|
||||
- %s
|
||||
|
||||
Um Ihren Account zu entsperren, bestätigen Sie bitte alle ausstehenden Dokumente:
|
||||
%s
|
||||
|
||||
Sobald Sie alle Dokumente bestätigt haben, wird Ihr Account automatisch entsperrt.
|
||||
|
||||
Mit freundlichen Grüßen,
|
||||
Ihr BreakPilot Team`, getDisplayName(name), docList, consentLink)
|
||||
|
||||
htmlBody := s.renderTemplate("suspended", map[string]interface{}{
|
||||
"Name": getDisplayName(name),
|
||||
"Documents": documents,
|
||||
"ConsentLink": consentLink,
|
||||
})
|
||||
|
||||
return s.SendEmail(to, subject, htmlBody, textBody)
|
||||
}
|
||||
|
||||
// SendAccountReactivatedNotification sends notification when account is reactivated
|
||||
func (s *EmailService) SendAccountReactivatedNotification(to, name string) error {
|
||||
appLink := fmt.Sprintf("%s/app", s.config.BaseURL)
|
||||
|
||||
subject := "Ihr Account wurde wieder aktiviert - BreakPilot"
|
||||
|
||||
textBody := fmt.Sprintf(`Hallo %s,
|
||||
|
||||
Vielen Dank für die Bestätigung der rechtlichen Dokumente!
|
||||
|
||||
Ihr Account wurde wieder aktiviert und Sie können BreakPilot wie gewohnt nutzen:
|
||||
%s
|
||||
|
||||
Mit freundlichen Grüßen,
|
||||
Ihr BreakPilot Team`, getDisplayName(name), appLink)
|
||||
|
||||
htmlBody := s.renderTemplate("reactivated", map[string]interface{}{
|
||||
"Name": getDisplayName(name),
|
||||
"AppLink": appLink,
|
||||
})
|
||||
|
||||
return s.SendEmail(to, subject, htmlBody, textBody)
|
||||
}
|
||||
|
||||
// renderTemplate renders an email HTML template
|
||||
func (s *EmailService) renderTemplate(templateName string, data map[string]interface{}) string {
|
||||
templates := map[string]string{
|
||||
"verification": `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
|
||||
.header { background: linear-gradient(135deg, #6366f1, #8b5cf6); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
|
||||
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
|
||||
.button { display: inline-block; background: #6366f1; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
|
||||
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>Willkommen bei BreakPilot!</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>Hallo {{.Name}},</p>
|
||||
<p>Vielen Dank für Ihre Registrierung! Bitte bestätigen Sie Ihre E-Mail-Adresse, um Ihr Konto zu aktivieren.</p>
|
||||
<p style="text-align: center;">
|
||||
<a href="{{.VerifyLink}}" class="button">E-Mail bestätigen</a>
|
||||
</p>
|
||||
<p>Dieser Link ist 24 Stunden gültig.</p>
|
||||
<p>Falls Sie sich nicht bei BreakPilot registriert haben, können Sie diese E-Mail ignorieren.</p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`,
|
||||
|
||||
"password_reset": `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
|
||||
.header { background: linear-gradient(135deg, #6366f1, #8b5cf6); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
|
||||
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
|
||||
.button { display: inline-block; background: #6366f1; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
|
||||
.warning { background: #fef3c7; border-left: 4px solid #f59e0b; padding: 12px; margin: 20px 0; }
|
||||
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>Passwort zurücksetzen</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>Hallo {{.Name}},</p>
|
||||
<p>Sie haben eine Anfrage zum Zurücksetzen Ihres Passworts gestellt.</p>
|
||||
<p style="text-align: center;">
|
||||
<a href="{{.ResetLink}}" class="button">Passwort zurücksetzen</a>
|
||||
</p>
|
||||
<div class="warning">
|
||||
<strong>Hinweis:</strong> Dieser Link ist nur 1 Stunde gültig.
|
||||
</div>
|
||||
<p>Falls Sie keine Passwort-Zurücksetzung angefordert haben, können Sie diese E-Mail ignorieren.</p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`,
|
||||
|
||||
"new_version": `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
|
||||
.header { background: linear-gradient(135deg, #6366f1, #8b5cf6); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
|
||||
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
|
||||
.button { display: inline-block; background: #6366f1; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
|
||||
.info-box { background: #e0e7ff; border-left: 4px solid #6366f1; padding: 12px; margin: 20px 0; }
|
||||
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>Neue Version: {{.DocumentName}}</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>Hallo {{.Name}},</p>
|
||||
<p>Wir haben unsere <strong>{{.DocumentName}}</strong> aktualisiert.</p>
|
||||
<div class="info-box">
|
||||
<strong>Wichtig:</strong> Bitte bestätigen Sie die neuen Bedingungen innerhalb der nächsten <strong>{{.DeadlineDays}} Tage</strong>.
|
||||
</div>
|
||||
<p style="text-align: center;">
|
||||
<a href="{{.ConsentLink}}" class="button">Dokument ansehen & bestätigen</a>
|
||||
</p>
|
||||
<p>Falls Sie nicht innerhalb dieser Frist bestätigen, wird Ihr Account vorübergehend gesperrt.</p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`,
|
||||
|
||||
"reminder": `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
|
||||
.header { background: linear-gradient(135deg, #f59e0b, #d97706); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
|
||||
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
|
||||
.button { display: inline-block; background: #f59e0b; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
|
||||
.warning { background: #fef3c7; border-left: 4px solid #f59e0b; padding: 12px; margin: 20px 0; }
|
||||
.doc-list { background: white; padding: 15px; border-radius: 8px; margin: 15px 0; }
|
||||
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>{{.Urgency}}: Ausstehende Bestätigungen</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>Hallo {{.Name}},</p>
|
||||
<p>Dies ist eine freundliche Erinnerung, dass Sie noch ausstehende rechtliche Dokumente bestätigen müssen.</p>
|
||||
<div class="doc-list">
|
||||
<strong>Ausstehende Dokumente:</strong>
|
||||
<ul>
|
||||
{{range .Documents}}<li>{{.}}</li>{{end}}
|
||||
</ul>
|
||||
</div>
|
||||
<div class="warning">
|
||||
<strong>Sie haben noch {{.DaysLeft}} Tage Zeit.</strong> Nach Ablauf dieser Frist wird Ihr Account vorübergehend gesperrt.
|
||||
</div>
|
||||
<p style="text-align: center;">
|
||||
<a href="{{.ConsentLink}}" class="button">Jetzt bestätigen</a>
|
||||
</p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`,
|
||||
|
||||
"suspended": `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
|
||||
.header { background: linear-gradient(135deg, #ef4444, #dc2626); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
|
||||
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
|
||||
.button { display: inline-block; background: #6366f1; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
|
||||
.alert { background: #fee2e2; border-left: 4px solid #ef4444; padding: 12px; margin: 20px 0; }
|
||||
.doc-list { background: white; padding: 15px; border-radius: 8px; margin: 15px 0; }
|
||||
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>Account vorübergehend gesperrt</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>Hallo {{.Name}},</p>
|
||||
<div class="alert">
|
||||
<strong>Ihr Account wurde vorübergehend gesperrt</strong>, da Sie die folgenden rechtlichen Dokumente nicht innerhalb der Frist bestätigt haben.
|
||||
</div>
|
||||
<div class="doc-list">
|
||||
<strong>Nicht bestätigte Dokumente:</strong>
|
||||
<ul>
|
||||
{{range .Documents}}<li>{{.}}</li>{{end}}
|
||||
</ul>
|
||||
</div>
|
||||
<p>Um Ihren Account zu entsperren, bestätigen Sie bitte alle ausstehenden Dokumente:</p>
|
||||
<p style="text-align: center;">
|
||||
<a href="{{.ConsentLink}}" class="button">Dokumente bestätigen & Account entsperren</a>
|
||||
</p>
|
||||
<p>Sobald Sie alle Dokumente bestätigt haben, wird Ihr Account automatisch entsperrt.</p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`,
|
||||
|
||||
"reactivated": `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
|
||||
.header { background: linear-gradient(135deg, #22c55e, #16a34a); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
|
||||
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
|
||||
.button { display: inline-block; background: #22c55e; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
|
||||
.success { background: #dcfce7; border-left: 4px solid #22c55e; padding: 12px; margin: 20px 0; }
|
||||
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>Account wieder aktiviert!</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>Hallo {{.Name}},</p>
|
||||
<div class="success">
|
||||
<strong>Vielen Dank!</strong> Ihr Account wurde erfolgreich wieder aktiviert.
|
||||
</div>
|
||||
<p>Sie können BreakPilot ab sofort wieder wie gewohnt nutzen.</p>
|
||||
<p style="text-align: center;">
|
||||
<a href="{{.AppLink}}" class="button">Zu BreakPilot</a>
|
||||
</p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`,
|
||||
|
||||
"generic_notification": `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
|
||||
.header { background: linear-gradient(135deg, #6366f1, #8b5cf6); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
|
||||
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
|
||||
.button { display: inline-block; background: #6366f1; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
|
||||
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>{{.Title}}</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>{{.Body}}</p>
|
||||
<p style="text-align: center;">
|
||||
<a href="{{.BaseURL}}/app" class="button">Zu BreakPilot</a>
|
||||
</p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`,
|
||||
}
|
||||
|
||||
tmplStr, ok := templates[templateName]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
tmpl, err := template.New(templateName).Parse(tmplStr)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := tmpl.Execute(&buf, data); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// SendConsentReminderEmail sends a simplified consent reminder email
|
||||
func (s *EmailService) SendConsentReminderEmail(to, title, body string) error {
|
||||
subject := title
|
||||
|
||||
htmlBody := s.renderTemplate("generic_notification", map[string]interface{}{
|
||||
"Title": title,
|
||||
"Body": body,
|
||||
"BaseURL": s.config.BaseURL,
|
||||
})
|
||||
|
||||
return s.SendEmail(to, subject, htmlBody, body)
|
||||
}
|
||||
|
||||
// SendGenericNotificationEmail sends a generic notification email
|
||||
func (s *EmailService) SendGenericNotificationEmail(to, title, body string) error {
|
||||
subject := title
|
||||
|
||||
htmlBody := s.renderTemplate("generic_notification", map[string]interface{}{
|
||||
"Title": title,
|
||||
"Body": body,
|
||||
"BaseURL": s.config.BaseURL,
|
||||
})
|
||||
|
||||
return s.SendEmail(to, subject, htmlBody, body)
|
||||
}
|
||||
|
||||
// getDisplayName returns display name or fallback
|
||||
func getDisplayName(name string) string {
|
||||
if name != "" {
|
||||
return name
|
||||
}
|
||||
return "Nutzer"
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user