from pywhy_graphs import ADMG # Assuming ADMG is imported from pywhy_graphs
from .GraphWrapperBase import GraphWrapperBase
[docs]
class MAGWrapper(GraphWrapperBase):
def __init__(self, incoming_directed_edges=None, incoming_bidirected_edges=None):
"""
Initialize a Maximal Ancestral Graph (MAG) wrapper using the ADMG class.
Parameters
----------
incoming_directed_edges : list, optional
List of tuples representing directed edges.
incoming_bidirected_edges : list, optional
List of tuples representing bidirected edges.
"""
self.mag = ADMG(incoming_directed_edges=incoming_directed_edges,
incoming_bidirected_edges=incoming_bidirected_edges)
[docs]
def add_node(self, node):
"""
Add a node to the MAG.
Parameters
----------
node : hashable object
The node to be added.
"""
self.mag.add_node(node)
[docs]
def remove_node(self, node):
"""
Remove a node from the MAG.
Parameters
----------
node : hashable object
The node to be removed.
"""
self.mag.remove_node(node)
[docs]
def add_directed_edge(self, source, target):
"""
Add a directed edge to the MAG.
Parameters
----------
source : hashable object
The source node of the directed edge.
target : hashable object
The target node of the directed edge.
"""
self.mag.add_edge(source, target, edge_type=self.mag.directed_edge_name)
[docs]
def add_bidirected_edge(self, source, target):
"""
Add a bidirected edge to the MAG.
Parameters
----------
source : hashable object
The source node of the bidirected edge.
target : hashable object
The target node of the bidirected edge.
"""
self.mag.add_edge(source, target, edge_type=self.mag.bidirected_edge_name)
[docs]
def remove_edge(self, source, target):
"""
Remove an edge from the MAG.
Parameters
----------
source : hashable object
The source node of the edge to be removed.
target : hashable object
The target node of the edge to be removed.
"""
self.mag.remove_edge(source, target)
[docs]
def get_nodes(self):
"""
Return the nodes of the MAG.
Returns
-------
list
List of nodes in the MAG.
"""
return self.mag.nodes
[docs]
def get_edges(self):
"""
Return the edges of the MAG.
Returns
-------
dict
Dictionary containing directed and bidirected edges of the MAG.
"""
return {
"directed": self.mag.directed_edges,
"bidirected": self.mag.bidirected_edges
}
# Additional MAG-specific methods can be added as needed