diff --git a/admin-lehrer/app/(admin)/ai/ocr-pipeline/types.ts b/admin-lehrer/app/(admin)/ai/ocr-pipeline/types.ts index 37fec08..60d4bb1 100644 --- a/admin-lehrer/app/(admin)/ai/ocr-pipeline/types.ts +++ b/admin-lehrer/app/(admin)/ai/ocr-pipeline/types.ts @@ -213,12 +213,25 @@ export interface RowGroundTruth { notes?: string } +export interface StructureGraphic { + x: number + y: number + w: number + h: number + area: number + shape: string // arrow, circle, line, exclamation, dot, icon, illustration + color_name: string + color_hex: string + confidence: number +} + export interface StructureResult { image_width: number image_height: number content_bounds: { x: number; y: number; w: number; h: number } boxes: StructureBox[] zones: StructureZone[] + graphics: StructureGraphic[] color_pixel_counts: Record has_words: boolean word_count: number diff --git a/admin-lehrer/components/ocr-pipeline/StepStructureDetection.tsx b/admin-lehrer/components/ocr-pipeline/StepStructureDetection.tsx index 2698150..e902d55 100644 --- a/admin-lehrer/components/ocr-pipeline/StepStructureDetection.tsx +++ b/admin-lehrer/components/ocr-pipeline/StepStructureDetection.tsx @@ -155,6 +155,11 @@ export function StepStructureDetection({ sessionId, onNext }: StepStructureDetec {result.boxes.length} Box(en) + {result.graphics && result.graphics.length > 0 && ( + + {result.graphics.length} Grafik(en) + + )} {result.has_words && ( {result.word_count} Woerter @@ -223,6 +228,60 @@ export function StepStructureDetection({ sessionId, onNext }: StepStructureDetec + {/* Graphics / visual elements */} + {result.graphics && result.graphics.length > 0 && ( +
+

+ Graphische Elemente ({result.graphics.length}) +

+ {/* Summary by shape */} + {(() => { + const shapeCounts: Record<string, number> = {} + for (const g of result.graphics) { + shapeCounts[g.shape] = (shapeCounts[g.shape] || 0) + 1 + } + return ( +
+ {Object.entries(shapeCounts) + .sort(([, a], [, b]) => b - a) + .map(([shape, count]) => ( + + {shape === 'arrow' ? '→' : shape === 'circle' ? '●' : shape === 'line' ? '─' : shape === 'exclamation' ? '❗' : shape === 'dot' ? '•' : shape === 'illustration' ? '🖼' : '◆'} + {' '}{shape} ×{count} + + ))} +
+ ) + })()} + {/* Individual graphics list */} +
+ {result.graphics.map((g, i) => ( +
+ + + {g.shape} + + + {g.w}x{g.h}px @ ({g.x}, {g.y}) + + + {g.color_name} + + + {Math.round(g.confidence * 100)}% + +
+ ))} +
+
+ )} + {/* Color regions */} {Object.keys(result.color_pixel_counts).length > 0 && (
diff --git a/klausur-service/backend/cv_graphic_detect.py b/klausur-service/backend/cv_graphic_detect.py new file mode 100644 index 0000000..5fadf9d --- /dev/null +++ b/klausur-service/backend/cv_graphic_detect.py @@ -0,0 +1,309 @@ +""" +Graphical element detection for OCR pages. + +Finds non-text visual elements (arrows, balloons, icons, illustrations) +by subtracting known OCR word regions from the page ink and analysing +remaining connected components via contour shape metrics. + +Works on both color and grayscale scans. + +Lizenz: Apache 2.0 (kommerziell nutzbar) +DATENSCHUTZ: Alle Verarbeitung erfolgt lokal. +""" + +import logging +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +import cv2 +import numpy as np + +logger = logging.getLogger(__name__) + +__all__ = ["detect_graphic_elements", "GraphicElement"] + + +@dataclass +class GraphicElement: + """A detected non-text graphical element.""" + x: int + y: int + width: int + height: int + area: int + shape: str # arrow, circle, line, icon, illustration + color_name: str # dominant color or 'black' + color_hex: str + confidence: float + contour: Any = field(default=None, repr=False) # numpy contour, excluded from repr + + +# --------------------------------------------------------------------------- +# Color helpers +# --------------------------------------------------------------------------- + +_COLOR_HEX = { + "black": "#000000", + "gray": "#6b7280", + "red": "#dc2626", + "orange": "#ea580c", + "yellow": "#ca8a04", + "green": "#16a34a", + "blue": "#2563eb", + "purple": "#9333ea", +} + + +def _dominant_color(hsv_roi: np.ndarray, sat_threshold: int = 50) -> tuple: + """Return (color_name, color_hex) for an HSV region.""" + if hsv_roi.size == 0: + return "black", _COLOR_HEX["black"] + + pixels = hsv_roi.reshape(-1, 3) + sat = pixels[:, 1] + sat_mask = sat > sat_threshold + sat_ratio = np.sum(sat_mask) / len(pixels) if len(pixels) > 0 else 0 + + if sat_ratio < 0.15: + 
return "black", _COLOR_HEX["black"] + + sat_pixels = pixels[sat_mask] + if len(sat_pixels) < 3: + return "black", _COLOR_HEX["black"] + + med_hue = float(np.median(sat_pixels[:, 0])) + + if med_hue < 10 or med_hue > 170: + name = "red" + elif med_hue < 25: + name = "orange" + elif med_hue < 35: + name = "yellow" + elif med_hue < 85: + name = "green" + elif med_hue < 130: + name = "blue" + else: + name = "purple" + + return name, _COLOR_HEX.get(name, _COLOR_HEX["black"]) + + +# --------------------------------------------------------------------------- +# Shape classification via contour analysis +# --------------------------------------------------------------------------- + +def _classify_shape( + contour: np.ndarray, + bw: int, + bh: int, + area: float, +) -> tuple: + """Classify contour shape → (shape_name, confidence). + + Uses circularity, aspect ratio, solidity, and vertex count. + """ + aspect = bw / bh if bh > 0 else 1.0 + perimeter = cv2.arcLength(contour, True) + circularity = (4 * np.pi * area) / (perimeter * perimeter) if perimeter > 0 else 0 + + hull = cv2.convexHull(contour) + hull_area = cv2.contourArea(hull) + solidity = area / hull_area if hull_area > 0 else 0 + + # Approximate polygon + epsilon = 0.03 * perimeter + approx = cv2.approxPolyDP(contour, epsilon, True) + vertices = len(approx) + + # --- Arrow detection --- + # Arrows typically have: vertices 5-8, moderate solidity (0.4-0.8), + # moderate aspect ratio, low circularity + if 4 <= vertices <= 9 and 0.3 < solidity < 0.85 and circularity < 0.5: + # Check for a pointed tip via convexity defects + hull_idx = cv2.convexHull(contour, returnPoints=False) + if len(hull_idx) >= 4: + try: + defects = cv2.convexityDefects(contour, hull_idx) + if defects is not None and len(defects) >= 1: + # Significant defect = pointed shape + max_depth = max(d[0][3] for d in defects) / 256.0 + if max_depth > min(bw, bh) * 0.15: + return "arrow", min(0.75, 0.5 + max_depth / max(bw, bh)) + except cv2.error: + pass + 
+ # --- Circle / balloon --- + if circularity > 0.65 and 0.5 < aspect < 2.0: + conf = min(0.95, circularity) + return "circle", conf + + # --- Line --- + if aspect > 4.0 or aspect < 0.25: + return "line", 0.7 + + # --- Exclamation mark (tall narrow + high solidity) --- + if aspect < 0.45 and bh > 12 and solidity > 0.5: + return "exclamation", 0.7 + + # --- Dot / bullet (small, roughly square, high solidity) --- + if max(bw, bh) < 20 and 0.5 < aspect < 2.0 and solidity > 0.6: + return "dot", 0.6 + + # --- Larger illustration --- + if area > 2000: + return "illustration", 0.5 + + # --- Generic icon --- + return "icon", 0.4 + + +# --------------------------------------------------------------------------- +# Main detection +# --------------------------------------------------------------------------- + +def detect_graphic_elements( + img_bgr: np.ndarray, + word_boxes: List[Dict], + detected_boxes: Optional[List[Dict]] = None, + min_area: int = 30, + max_area_ratio: float = 0.05, + word_pad: int = 3, + max_elements: int = 80, +) -> List[GraphicElement]: + """Find non-text graphical elements on the page. + + 1. Build ink mask (dark + colored pixels). + 2. Subtract OCR word regions and detected boxes. + 3. Find connected components and classify shapes. + + Args: + img_bgr: BGR color image. + word_boxes: List of OCR word dicts with left/top/width/height. + detected_boxes: Optional list of detected box dicts (x/y/w/h). + min_area: Minimum contour area to keep. + max_area_ratio: Maximum area as fraction of image area. + word_pad: Padding around word boxes for exclusion. + max_elements: Maximum number of elements to return. + + Returns: + List of GraphicElement, sorted by area descending. + """ + if img_bgr is None: + return [] + + h, w = img_bgr.shape[:2] + max_area = int(h * w * max_area_ratio) + + gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY) + hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV) + + # --- 1. 
Build ink mask: dark pixels + saturated colored pixels --- + # Adaptive threshold for dark ink + _, dark_mask = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) + + # Saturated colored pixels (catches colored arrows, markers) + sat_mask = (hsv[:, :, 1] > 40).astype(np.uint8) * 255 + # Only include saturated pixels that are also reasonably dark (not background) + val_mask = (hsv[:, :, 2] < 230).astype(np.uint8) * 255 + color_ink = cv2.bitwise_and(sat_mask, val_mask) + + ink_mask = cv2.bitwise_or(dark_mask, color_ink) + + # --- 2. Build exclusion mask from OCR words --- + exclusion = np.zeros((h, w), dtype=np.uint8) + + for wb in word_boxes: + x1 = max(0, int(wb.get("left", 0)) - word_pad) + y1 = max(0, int(wb.get("top", 0)) - word_pad) + x2 = min(w, int(wb.get("left", 0) + wb.get("width", 0)) + word_pad) + y2 = min(h, int(wb.get("top", 0) + wb.get("height", 0)) + word_pad) + exclusion[y1:y2, x1:x2] = 255 + + # Also exclude detected box interiors (they contain text, not graphics) + # But keep a border strip so arrows/icons at box edges are found + if detected_boxes: + box_inset = 8 + for box in detected_boxes: + bx = int(box.get("x", 0)) + by = int(box.get("y", 0)) + bbw = int(box.get("w", box.get("width", 0))) + bbh = int(box.get("h", box.get("height", 0))) + x1 = max(0, bx + box_inset) + y1 = max(0, by + box_inset) + x2 = min(w, bx + bbw - box_inset) + y2 = min(h, by + bbh - box_inset) + if x2 > x1 and y2 > y1: + exclusion[y1:y2, x1:x2] = 255 + + # Subtract exclusion from ink + graphic_mask = cv2.bitwise_and(ink_mask, cv2.bitwise_not(exclusion)) + + # --- 3. 
Morphological cleanup --- + # Close small gaps (connects arrow stroke + head) + kernel_close = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5)) + graphic_mask = cv2.morphologyEx(graphic_mask, cv2.MORPH_CLOSE, kernel_close) + # Remove tiny noise + kernel_open = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2)) + graphic_mask = cv2.morphologyEx(graphic_mask, cv2.MORPH_OPEN, kernel_open) + + # --- 4. Find contours --- + contours, _ = cv2.findContours( + graphic_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE, + ) + + # --- 5. Analyse and classify --- + candidates: List[GraphicElement] = [] + for cnt in contours: + area = cv2.contourArea(cnt) + if area < min_area or area > max_area: + continue + + bx, by, bw, bh = cv2.boundingRect(cnt) + if bw < 4 or bh < 4: + continue + + # Skip elements that are mostly inside the exclusion zone + # (partial overlap with a word) + roi_excl = exclusion[by:by + bh, bx:bx + bw] + excl_ratio = np.sum(roi_excl > 0) / (bw * bh) if bw * bh > 0 else 0 + if excl_ratio > 0.6: + continue + + # Classify shape + shape, conf = _classify_shape(cnt, bw, bh, area) + + # Determine dominant color + roi_hsv = hsv[by:by + bh, bx:bx + bw] + # Only sample pixels that are actually in the contour + cnt_mask = np.zeros((bh, bw), dtype=np.uint8) + shifted_cnt = cnt - np.array([bx, by]) + cv2.drawContours(cnt_mask, [shifted_cnt], -1, 255, -1) + masked_hsv = roi_hsv[cnt_mask > 0] + color_name, color_hex = _dominant_color(masked_hsv) + + candidates.append(GraphicElement( + x=bx, y=by, width=bw, height=bh, + area=int(area), + shape=shape, + color_name=color_name, + color_hex=color_hex, + confidence=conf, + contour=cnt, + )) + + # Sort by area descending, limit count + candidates.sort(key=lambda g: g.area, reverse=True) + result = candidates[:max_elements] + + if result: + shape_counts = {} + for g in result: + shape_counts[g.shape] = shape_counts.get(g.shape, 0) + 1 + logger.info( + "GraphicDetect: %d elements found (%s)", + len(result), + ", ".join(f"{s}: {c}" 
for s, c in sorted(shape_counts.items())), + ) + + return result diff --git a/klausur-service/backend/ocr_pipeline_api.py b/klausur-service/backend/ocr_pipeline_api.py index e20b84e..f03f44d 100644 --- a/klausur-service/backend/ocr_pipeline_api.py +++ b/klausur-service/backend/ocr_pipeline_api.py @@ -73,6 +73,7 @@ from cv_vocab_pipeline import ( ) from cv_box_detect import detect_boxes, split_page_into_zones from cv_color_detect import detect_word_colors, recover_colored_text, _COLOR_RANGES, _COLOR_HEX +from cv_graphic_detect import detect_graphic_elements from cv_words_first import build_grid_from_words from ocr_pipeline_session_store import ( create_session_db, @@ -1304,6 +1305,16 @@ async def detect_structure(session_id: str): if pixel_count > 50: # minimum threshold color_summary[color_name] = pixel_count + # --- Graphic element detection --- + box_dicts = [ + {"x": b.x, "y": b.y, "w": b.width, "h": b.height} + for b in boxes + ] + graphics = detect_graphic_elements( + img_bgr, words, + detected_boxes=box_dicts, + ) + duration = time.time() - t0 result_dict = { @@ -1332,6 +1343,17 @@ async def detect_structure(session_id: str): } for z in zones ], + "graphics": [ + { + "x": g.x, "y": g.y, "w": g.width, "h": g.height, + "area": g.area, + "shape": g.shape, + "color_name": g.color_name, + "color_hex": g.color_hex, + "confidence": round(g.confidence, 2), + } + for g in graphics + ], "color_pixel_counts": color_summary, "has_words": len(words) > 0, "word_count": len(words), @@ -1342,8 +1364,8 @@ async def detect_structure(session_id: str): await update_session_db(session_id, structure_result=result_dict) cached["structure_result"] = result_dict - logger.info("detect-structure session %s: %d boxes, %d zones, %.2fs", - session_id, len(boxes), len(zones), duration) + logger.info("detect-structure session %s: %d boxes, %d zones, %d graphics, %.2fs", + session_id, len(boxes), len(zones), len(graphics), duration) return {"session_id": session_id, **result_dict} @@ -1777,6 
+1799,48 @@ async def _get_structure_overlay(session_id: str) -> Response: continue cv2.drawContours(img, [cnt], -1, draw_color, 2) + # --- Draw graphic elements --- + graphics_data = structure.get("graphics", []) + shape_icons = { + "arrow": "ARROW", + "circle": "CIRCLE", + "line": "LINE", + "exclamation": "!", + "dot": "DOT", + "icon": "ICON", + "illustration": "ILLUST", + } + for gfx in graphics_data: + gx, gy = gfx["x"], gfx["y"] + gw, gh = gfx["w"], gfx["h"] + shape = gfx.get("shape", "icon") + color_hex = gfx.get("color_hex", "#6b7280") + conf = gfx.get("confidence", 0) + + # Pick draw color based on element color (BGR) + gfx_bgr = bg_hex_to_bgr.get(color_hex, (128, 114, 107)) + + # Draw bounding box (dashed style via short segments) + dash = 6 + for seg_x in range(gx, gx + gw, dash * 2): + end_x = min(seg_x + dash, gx + gw) + cv2.line(img, (seg_x, gy), (end_x, gy), gfx_bgr, 2) + cv2.line(img, (seg_x, gy + gh), (end_x, gy + gh), gfx_bgr, 2) + for seg_y in range(gy, gy + gh, dash * 2): + end_y = min(seg_y + dash, gy + gh) + cv2.line(img, (gx, seg_y), (gx, end_y), gfx_bgr, 2) + cv2.line(img, (gx + gw, seg_y), (gx + gw, end_y), gfx_bgr, 2) + + # Label + icon = shape_icons.get(shape, shape.upper()[:5]) + label = f"{icon} {int(conf * 100)}%" + # White background for readability + (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1) + lx = gx + 2 + ly = max(gy - 4, th + 4) + cv2.rectangle(img, (lx - 1, ly - th - 2), (lx + tw + 2, ly + 3), (255, 255, 255), -1) + cv2.putText(img, label, (lx, ly), cv2.FONT_HERSHEY_SIMPLEX, 0.4, gfx_bgr, 1) + # Encode result _, png_buf = cv2.imencode(".png", img) return Response(content=png_buf.tobytes(), media_type="image/png")