Source code for ETIA.CausalLearning.CausalModel.PAG

from .GraphWrapperBase import GraphWrapperBase
from pywhy_graphs import PAG

class PAGWrapper(GraphWrapperBase):
    """
    Thin adapter around a ``pywhy_graphs.PAG`` instance.

    The wrapped graph is stored in ``self.pag``; every method of this class
    forwards its arguments to the corresponding attribute or method of that
    object and hands back whatever it returns, without copying or converting.
    """

    def __init__(self, incoming_directed_edges=None, incoming_undirected_edges=None,
                 incoming_bidirected_edges=None, incoming_circle_edges=None):
        """
        Build the underlying PAG from optional initial edge collections.

        Parameters
        ----------
        incoming_directed_edges : optional
            Initial directed edges, forwarded to ``PAG`` under the same
            keyword. Defaults to ``None``.
        incoming_undirected_edges : optional
            Initial undirected edges, forwarded to ``PAG`` under the same
            keyword. Defaults to ``None``.
        incoming_bidirected_edges : optional
            Initial bidirected edges, forwarded to ``PAG`` under the same
            keyword. Defaults to ``None``.
        incoming_circle_edges : optional
            Initial circle edges, forwarded to ``PAG`` under the same
            keyword. Defaults to ``None``.

        Notes
        -----
        The accepted container types are whatever ``PAG`` accepts; this
        wrapper performs no validation of its own.
        """
        # Collect the four edge groups by keyword so they reach PAG exactly
        # as named keyword arguments.
        initial_edges = {
            "incoming_directed_edges": incoming_directed_edges,
            "incoming_undirected_edges": incoming_undirected_edges,
            "incoming_bidirected_edges": incoming_bidirected_edges,
            "incoming_circle_edges": incoming_circle_edges,
        }
        self.pag = PAG(**initial_edges)

    def add_node(self, node):
        """
        Insert a node into the wrapped PAG.

        Parameters
        ----------
        node : hashable object
            Node to insert; passed straight to ``PAG.add_node``.
        """
        graph = self.pag
        graph.add_node(node)

    def remove_node(self, node):
        """
        Delete a node from the wrapped PAG.

        Parameters
        ----------
        node : hashable object
            Node to delete; passed straight to ``PAG.remove_node``.
        """
        graph = self.pag
        graph.remove_node(node)

    def add_directed_edge(self, source, target):
        """
        Insert an edge tagged with the PAG's directed edge type.

        Parameters
        ----------
        source : hashable object
            First endpoint handed to ``PAG.add_edge``.
        target : hashable object
            Second endpoint handed to ``PAG.add_edge``.
        """
        kind = self.pag.directed_edge_name
        self.pag.add_edge(source, target, edge_type=kind)

    def add_bidirected_edge(self, source, target):
        """
        Insert an edge tagged with the PAG's bidirected edge type.

        Parameters
        ----------
        source : hashable object
            First endpoint handed to ``PAG.add_edge``.
        target : hashable object
            Second endpoint handed to ``PAG.add_edge``.
        """
        kind = self.pag.bidirected_edge_name
        self.pag.add_edge(source, target, edge_type=kind)

    def add_undirected_edge(self, source, target):
        """
        Insert an edge tagged with the PAG's undirected edge type.

        Parameters
        ----------
        source : hashable object
            First endpoint handed to ``PAG.add_edge``.
        target : hashable object
            Second endpoint handed to ``PAG.add_edge``.
        """
        kind = self.pag.undirected_edge_name
        self.pag.add_edge(source, target, edge_type=kind)

    def add_circle_edge(self, source, target):
        """
        Insert an edge tagged with the PAG's circle edge type.

        Parameters
        ----------
        source : hashable object
            First endpoint handed to ``PAG.add_edge``.
        target : hashable object
            Second endpoint handed to ``PAG.add_edge``.

        Notes
        -----
        Which endpoint carries the circle mark is decided by the underlying
        ``PAG`` implementation, not by this wrapper -- TODO confirm against
        the pywhy_graphs documentation.
        """
        kind = self.pag.circle_edge_name
        self.pag.add_edge(source, target, edge_type=kind)

    def remove_edge(self, source, target):
        """
        Delete the edge between two nodes of the wrapped PAG.

        Parameters
        ----------
        source : hashable object
            First endpoint handed to ``PAG.remove_edge``.
        target : hashable object
            Second endpoint handed to ``PAG.remove_edge``.

        Notes
        -----
        No edge type is specified here, so the PAG's own default for
        ``remove_edge`` applies.
        """
        graph = self.pag
        graph.remove_edge(source, target)

    def get_nodes(self):
        """
        Expose the nodes of the wrapped PAG.

        Returns
        -------
        object
            The value of ``self.pag.nodes``, returned as-is (not copied or
            converted).
        """
        node_collection = self.pag.nodes
        return node_collection

    def get_edges(self):
        """
        Expose the edges of the wrapped PAG, grouped by edge type.

        Returns
        -------
        dict
            Mapping with the keys ``"directed"``, ``"bidirected"``,
            ``"undirected"`` and ``"circle"``; each value is the matching
            ``*_edges`` attribute of the PAG, returned as-is.
        """
        graph = self.pag
        grouped = {}
        grouped["directed"] = graph.directed_edges
        grouped["bidirected"] = graph.bidirected_edges
        grouped["undirected"] = graph.undirected_edges
        grouped["circle"] = graph.circle_edges
        return grouped

    def possible_children(self, node):
        """
        Look up the possible children of a node.

        Parameters
        ----------
        node : hashable object
            Node to query.

        Returns
        -------
        object
            Result of ``PAG.possible_children(node)``, passed through
            unchanged.
        """
        candidates = self.pag.possible_children(node)
        return candidates

    def possible_parents(self, node):
        """
        Look up the possible parents of a node.

        Parameters
        ----------
        node : hashable object
            Node to query.

        Returns
        -------
        object
            Result of ``PAG.possible_parents(node)``, passed through
            unchanged.
        """
        candidates = self.pag.possible_parents(node)
        return candidates

    def children(self, node):
        """
        Look up the children of a node.

        Parameters
        ----------
        node : hashable object
            Node to query.

        Returns
        -------
        object
            Result of ``PAG.children(node)``, passed through unchanged.
        """
        found = self.pag.children(node)
        return found

    def parents(self, node):
        """
        Look up the parents of a node.

        Parameters
        ----------
        node : hashable object
            Node to query.

        Returns
        -------
        object
            Result of ``PAG.parents(node)``, passed through unchanged.
        """
        found = self.pag.parents(node)
        return found