Source code for ETIA.CausalLearning.CausalModel.DAG

import networkx as nx
from .GraphWrapperBase import GraphWrapperBase

[docs] class DAGWrapper(GraphWrapperBase): def __init__(self): """ Initialize a Directed Acyclic Graph (DAG) wrapper. """ self.dag = nx.DiGraph()
[docs] def add_node(self, node): """ Add a node to the DAG. Parameters ---------- node : hashable object The node to be added. """ self.dag.add_node(node)
[docs] def remove_node(self, node): """ Remove a node from the DAG. Parameters ---------- node : hashable object The node to be removed. """ self.dag.remove_node(node)
[docs] def add_directed_edge(self, source, target): """ Add a directed edge to the DAG. Checks for cycles and raises an error if an edge creates a cycle. Parameters ---------- source : hashable object The source node of the directed edge. target : hashable object The target node of the directed edge. """ self.dag.add_edge(source, target, edge_type='directed') if not nx.is_directed_acyclic_graph(self.dag): self.dag.remove_edge(source, target) raise ValueError("Adding the edge ({}, {}) would create a cycle.".format(source, target))
[docs] def remove_edge(self, source, target): """ Remove an edge from the DAG. Parameters ---------- source : hashable object The source node of the edge to be removed. target : hashable object The target node of the edge to be removed. """ self.dag.remove_edge(source, target)
[docs] def get_nodes(self): """ Return the nodes of the DAG. Returns ------- list List of nodes in the DAG. """ return self.dag.nodes
[docs] def get_edges(self): """ Return the edges of the DAG. Returns ------- list List of edges in the DAG. """ return self.dag.edges
# Additional DAG-specific methods can be added as needed