"""
Knowledge Graph for Breakpilot Agents

Provides entity and relationship management:
- Entity storage with properties
- Relationship definitions
- Graph traversal
- Optional Qdrant client handle (stored for semantic search; the in-memory
  operations in this module do not use it)
"""

from typing import Dict, Any, List, Optional, Set, Tuple
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from collections import deque
import json
import logging

logger = logging.getLogger(__name__)


class EntityType(Enum):
    """Types of entities in the knowledge graph"""
    STUDENT = "student"
    TEACHER = "teacher"
    CLASS = "class"
    SUBJECT = "subject"
    ASSIGNMENT = "assignment"
    EXAM = "exam"
    TOPIC = "topic"
    CONCEPT = "concept"
    RESOURCE = "resource"
    CUSTOM = "custom"


class RelationshipType(Enum):
    """Types of relationships between entities"""
    BELONGS_TO = "belongs_to"  # Student belongs to class
    TEACHES = "teaches"  # Teacher teaches subject
    ASSIGNED_TO = "assigned_to"  # Assignment assigned to student
    COVERS = "covers"  # Exam covers topic
    REQUIRES = "requires"  # Topic requires concept
    RELATED_TO = "related_to"  # General relationship
    PARENT_OF = "parent_of"  # Hierarchical relationship
    CREATED_BY = "created_by"  # Creator relationship
    GRADED_BY = "graded_by"  # Grading relationship


@dataclass
class Entity:
    """Represents an entity (node) in the knowledge graph."""
    id: str
    entity_type: EntityType
    name: str
    properties: Dict[str, Any] = field(default_factory=dict)
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict (timestamps as ISO-8601 strings)."""
        return {
            "id": self.id,
            "entity_type": self.entity_type.value,
            "name": self.name,
            # Shallow copy so callers cannot mutate the entity through the result.
            "properties": dict(self.properties),
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Entity":
        """Rebuild an Entity from the output of ``to_dict``.

        Raises:
            KeyError: if a required key is missing.
            ValueError: if ``entity_type`` or a timestamp cannot be parsed.
        """
        return cls(
            id=data["id"],
            entity_type=EntityType(data["entity_type"]),
            name=data["name"],
            # Copy so the entity does not share state with the input mapping.
            properties=dict(data.get("properties") or {}),
            created_at=datetime.fromisoformat(data["created_at"]),
            updated_at=datetime.fromisoformat(data["updated_at"]),
        )


@dataclass
class Relationship:
    """Represents a directed relationship (edge) between two entities."""
    id: str
    source_id: str
    target_id: str
    relationship_type: RelationshipType
    properties: Dict[str, Any] = field(default_factory=dict)
    weight: float = 1.0
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict (timestamp as ISO-8601 string)."""
        return {
            "id": self.id,
            "source_id": self.source_id,
            "target_id": self.target_id,
            "relationship_type": self.relationship_type.value,
            # Shallow copy so callers cannot mutate the relationship through the result.
            "properties": dict(self.properties),
            "weight": self.weight,
            "created_at": self.created_at.isoformat(),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Relationship":
        """Rebuild a Relationship from the output of ``to_dict``.

        ``weight`` defaults to 1.0 when absent.

        Raises:
            KeyError: if a required key is missing.
            ValueError: if ``relationship_type`` or the timestamp cannot be parsed.
        """
        return cls(
            id=data["id"],
            source_id=data["source_id"],
            target_id=data["target_id"],
            relationship_type=RelationshipType(data["relationship_type"]),
            properties=dict(data.get("properties") or {}),
            weight=data.get("weight", 1.0),
            created_at=datetime.fromisoformat(data["created_at"]),
        )


class KnowledgeGraph:
    """
    In-memory knowledge graph for managing entity relationships.

    Provides:
    - Entity CRUD operations
    - Relationship management
    - Graph traversal (neighbors, paths, subgraphs)
    - JSON (de)serialization and statistics

    Invariant: every relationship in ``_relationships`` connects two entities
    in ``_entities`` and its id appears in the adjacency set of both endpoints.
    """

    def __init__(
        self,
        db_pool=None,
        qdrant_client=None,
        namespace: str = "breakpilot"
    ):
        """
        Initialize the knowledge graph.

        Args:
            db_pool: Presumably an async PostgreSQL connection pool; stored on
                the instance but not used by the in-memory operations here.
            qdrant_client: Optional Qdrant client for vector search; stored on
                the instance but not used by the in-memory operations here.
            namespace: Namespace for isolation
        """
        self.db_pool = db_pool
        self.qdrant = qdrant_client
        self.namespace = namespace
        self._entities: Dict[str, Entity] = {}
        self._relationships: Dict[str, Relationship] = {}
        # entity_id -> ids of relationships where the entity is source or target
        self._adjacency: Dict[str, Set[str]] = {}

    # Entity Operations

    def add_entity(
        self,
        entity_id: str,
        entity_type: EntityType,
        name: str,
        properties: Optional[Dict[str, Any]] = None
    ) -> Entity:
        """
        Adds an entity to the graph.

        Re-adding an existing id replaces the entity object but keeps all of
        its existing relationships.

        Args:
            entity_id: Unique entity identifier
            entity_type: Type of entity
            name: Human-readable name
            properties: Entity properties (shallow-copied)

        Returns:
            The created Entity
        """
        entity = Entity(
            id=entity_id,
            entity_type=entity_type,
            name=name,
            # Copy so later update_entity() calls don't mutate the caller's dict.
            properties=dict(properties) if properties else {}
        )
        self._entities[entity_id] = entity
        # setdefault: resetting to set() would orphan existing relationships.
        self._adjacency.setdefault(entity_id, set())
        logger.debug("Added entity: %s/%s", entity_type.value, entity_id)
        return entity

    def get_entity(self, entity_id: str) -> Optional[Entity]:
        """Gets an entity by ID, or None if it does not exist."""
        return self._entities.get(entity_id)

    def update_entity(
        self,
        entity_id: str,
        name: Optional[str] = None,
        properties: Optional[Dict[str, Any]] = None
    ) -> Optional[Entity]:
        """
        Updates an entity.

        Args:
            entity_id: Entity to update
            name: New name (ignored if falsy)
            properties: Properties to merge into the existing ones (ignored if falsy)

        Returns:
            Updated entity or None if not found. ``updated_at`` is refreshed
            whenever the entity exists, even if nothing else changed.
        """
        entity = self._entities.get(entity_id)
        if entity is None:
            return None

        if name:
            entity.name = name
        if properties:
            entity.properties.update(properties)
        entity.updated_at = datetime.now(timezone.utc)

        return entity

    def delete_entity(self, entity_id: str) -> bool:
        """
        Deletes an entity and all relationships touching it.

        Args:
            entity_id: Entity to delete

        Returns:
            True if deleted, False if the entity did not exist
        """
        if entity_id not in self._entities:
            return False

        # Copy: _delete_relationship_internal mutates this adjacency set.
        for rel_id in list(self._adjacency.get(entity_id, set())):
            self._delete_relationship_internal(rel_id)

        del self._entities[entity_id]
        del self._adjacency[entity_id]
        return True

    def get_entities_by_type(
        self,
        entity_type: EntityType
    ) -> List[Entity]:
        """Gets all entities of a specific type"""
        return [
            e for e in self._entities.values()
            if e.entity_type == entity_type
        ]

    def search_entities(
        self,
        query: str,
        entity_type: Optional[EntityType] = None,
        limit: int = 10
    ) -> List[Entity]:
        """
        Searches entities by name.

        Args:
            query: Search query (case-insensitive substring)
            entity_type: Optional type filter
            limit: Maximum results

        Returns:
            Matching entities, in insertion order
        """
        query_lower = query.lower()
        results: List[Entity] = []

        for entity in self._entities.values():
            if entity_type is not None and entity.entity_type != entity_type:
                continue
            if query_lower in entity.name.lower():
                results.append(entity)
                if len(results) >= limit:
                    break

        return results

    # Relationship Operations

    def add_relationship(
        self,
        relationship_id: str,
        source_id: str,
        target_id: str,
        relationship_type: RelationshipType,
        properties: Optional[Dict[str, Any]] = None,
        weight: float = 1.0
    ) -> Optional[Relationship]:
        """
        Adds a relationship between two entities.

        Re-using an existing relationship id replaces the old relationship,
        including removing it from its previous endpoints' adjacency.

        Args:
            relationship_id: Unique relationship identifier
            source_id: Source entity ID
            target_id: Target entity ID
            relationship_type: Type of relationship
            properties: Relationship properties (shallow-copied)
            weight: Relationship weight/strength

        Returns:
            The created Relationship or None if either entity doesn't exist
        """
        if source_id not in self._entities or target_id not in self._entities:
            logger.warning(
                "Cannot create relationship: entity not found "
                "(source=%s, target=%s)",
                source_id,
                target_id,
            )
            return None

        # Drop any previous relationship with this id so its old endpoints
        # don't keep a stale adjacency entry.
        self._delete_relationship_internal(relationship_id)

        relationship = Relationship(
            id=relationship_id,
            source_id=source_id,
            target_id=target_id,
            relationship_type=relationship_type,
            properties=dict(properties) if properties else {},
            weight=weight
        )

        self._relationships[relationship_id] = relationship
        self._adjacency[source_id].add(relationship_id)
        self._adjacency[target_id].add(relationship_id)

        logger.debug(
            "Added relationship: %s -[%s]-> %s",
            source_id,
            relationship_type.value,
            target_id,
        )
        return relationship

    def get_relationship(self, relationship_id: str) -> Optional[Relationship]:
        """Gets a relationship by ID, or None if it does not exist."""
        return self._relationships.get(relationship_id)

    def delete_relationship(self, relationship_id: str) -> bool:
        """Deletes a relationship; returns False if it did not exist."""
        return self._delete_relationship_internal(relationship_id)

    def _delete_relationship_internal(self, relationship_id: str) -> bool:
        """Remove a relationship and its adjacency entries; False if absent."""
        relationship = self._relationships.get(relationship_id)
        if relationship is None:
            return False

        for endpoint in (relationship.source_id, relationship.target_id):
            if endpoint in self._adjacency:
                self._adjacency[endpoint].discard(relationship_id)

        del self._relationships[relationship_id]
        return True

    # Graph Traversal

    def get_neighbors(
        self,
        entity_id: str,
        relationship_type: Optional[RelationshipType] = None,
        direction: str = "both"  # "outgoing", "incoming", "both"
    ) -> List[Tuple[Entity, Relationship]]:
        """
        Gets neighboring entities.

        Args:
            entity_id: Starting entity
            relationship_type: Optional filter by relationship type
            direction: "outgoing" (entity is source), "incoming" (entity is
                target) or "both". Any other value yields no neighbors.

        Returns:
            List of (entity, relationship) tuples; order follows iteration of
            the adjacency set and is not guaranteed.
        """
        if entity_id not in self._entities:
            return []

        results: List[Tuple[Entity, Relationship]] = []

        for rel_id in self._adjacency.get(entity_id, set()):
            rel = self._relationships.get(rel_id)
            if rel is None:
                continue
            if relationship_type is not None and rel.relationship_type != relationship_type:
                continue

            neighbor_id: Optional[str] = None
            if direction == "outgoing":
                if rel.source_id == entity_id:
                    neighbor_id = rel.target_id
            elif direction == "incoming":
                if rel.target_id == entity_id:
                    neighbor_id = rel.source_id
            elif direction == "both":
                neighbor_id = rel.target_id if rel.source_id == entity_id else rel.source_id

            # Compare with None explicitly: "" is a valid entity id.
            if neighbor_id is None:
                continue
            neighbor = self._entities.get(neighbor_id)
            if neighbor is not None:
                results.append((neighbor, rel))

        return results

    def get_path(
        self,
        source_id: str,
        target_id: str,
        max_depth: int = 5
    ) -> Optional[List[Tuple[Entity, Optional[Relationship]]]]:
        """
        Finds a shortest path (by hop count) between two entities using BFS.

        Edges are traversed in both directions.

        Args:
            source_id: Starting entity
            target_id: Target entity
            max_depth: Maximum number of hops in the returned path

        Returns:
            Path as list of (entity, relationship) tuples, where the first
            element's relationship is None; or None if no path is found.
            When several shortest paths exist, which one is returned is not
            specified.
        """
        if source_id not in self._entities or target_id not in self._entities:
            return None

        if source_id == target_id:
            return [(self._entities[source_id], None)]

        visited: Set[str] = {source_id}
        # Queue items: (entity_id, path so far)
        queue = deque([(source_id, [(self._entities[source_id], None)])])

        while queue:
            current_id, path = queue.popleft()

            # A path of k entries has k-1 hops; extending it gives k hops,
            # so only extend while k <= max_depth.
            if len(path) > max_depth:
                continue

            for neighbor, rel in self.get_neighbors(current_id):
                extended = path + [(neighbor, rel)]
                if neighbor.id == target_id:
                    return extended
                if neighbor.id not in visited:
                    visited.add(neighbor.id)
                    queue.append((neighbor.id, extended))

        return None

    def get_subgraph(
        self,
        entity_id: str,
        depth: int = 2
    ) -> Tuple[List[Entity], List[Relationship]]:
        """
        Gets a subgraph around an entity.

        Args:
            entity_id: Center entity
            depth: How many hops to expand from the center

        Returns:
            Tuple of (entities, relationships). Relationships are those seen
            while expanding entities within ``depth - 1`` hops; edges between
            two entities that are both exactly ``depth`` hops away are not
            included.
        """
        if entity_id not in self._entities:
            return [], []

        entities_set: Set[str] = {entity_id}
        relationships_set: Set[str] = set()
        frontier: Set[str] = {entity_id}

        for _ in range(depth):
            next_frontier: Set[str] = set()
            for e_id in frontier:
                for neighbor, rel in self.get_neighbors(e_id):
                    if neighbor.id not in entities_set:
                        entities_set.add(neighbor.id)
                        next_frontier.add(neighbor.id)
                    relationships_set.add(rel.id)
            frontier = next_frontier

        entities = [self._entities[e_id] for e_id in entities_set]
        relationships = [self._relationships[r_id] for r_id in relationships_set]

        return entities, relationships

    # Serialization

    def to_dict(self) -> Dict[str, Any]:
        """Serializes the graph to a dictionary"""
        return {
            "entities": [e.to_dict() for e in self._entities.values()],
            "relationships": [r.to_dict() for r in self._relationships.values()]
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs) -> "KnowledgeGraph":
        """
        Deserializes a graph from a dictionary.

        Relationships whose endpoints are not among the loaded entities are
        skipped with a warning, matching ``add_relationship``. ``kwargs`` are
        forwarded to the constructor.
        """
        graph = cls(**kwargs)

        # Load entities first so relationships can be validated against them.
        for e_data in data.get("entities", []):
            entity = Entity.from_dict(e_data)
            graph._entities[entity.id] = entity
            graph._adjacency.setdefault(entity.id, set())

        for r_data in data.get("relationships", []):
            rel = Relationship.from_dict(r_data)
            if rel.source_id not in graph._entities or rel.target_id not in graph._entities:
                logger.warning(
                    "Skipping relationship %s: entity not found (source=%s, target=%s)",
                    rel.id,
                    rel.source_id,
                    rel.target_id,
                )
                continue
            # A duplicate id replaces the earlier one without leaving stale adjacency.
            graph._delete_relationship_internal(rel.id)
            graph._relationships[rel.id] = rel
            graph._adjacency[rel.source_id].add(rel.id)
            graph._adjacency[rel.target_id].add(rel.id)

        return graph

    def export_json(self) -> str:
        """Exports graph to JSON string"""
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def import_json(cls, json_str: str, **kwargs) -> "KnowledgeGraph":
        """Imports graph from JSON string; ``kwargs`` go to the constructor."""
        return cls.from_dict(json.loads(json_str), **kwargs)

    # Statistics

    def get_statistics(self) -> Dict[str, Any]:
        """Gets graph statistics (counts per type and mean adjacency size)."""
        entity_types: Dict[str, int] = {}
        for entity in self._entities.values():
            key = entity.entity_type.value
            entity_types[key] = entity_types.get(key, 0) + 1

        rel_types: Dict[str, int] = {}
        for rel in self._relationships.values():
            key = rel.relationship_type.value
            rel_types[key] = rel_types.get(key, 0) + 1

        return {
            "total_entities": len(self._entities),
            "total_relationships": len(self._relationships),
            "entity_types": entity_types,
            "relationship_types": rel_types,
            # max(..., 1) avoids division by zero on an empty graph.
            "avg_connections": (
                sum(len(adj) for adj in self._adjacency.values()) /
                max(len(self._adjacency), 1)
            )
        }

    @property
    def entity_count(self) -> int:
        """Returns number of entities"""
        return len(self._entities)

    @property
    def relationship_count(self) -> int:
        """Returns number of relationships"""
        return len(self._relationships)