# Source code for ETIA.CausalLearning.CausalModel.utils

import numpy as np
import pandas as pd

from .DAG import DAGWrapper
from .CPDAG import CPDAGWrapper
from .PAG import PAGWrapper
from .MAG import MAGWrapper

def matrix_to_pywhy_graph(matrix, graph_type='DAG'):
    """
    Convert a matrix representation to a pywhy-graphs graph.

    Node ``i`` of the matrix becomes the graph node named ``str(i)``. Each
    ordered pair ``(matrix[i, j], matrix[j, i])`` is translated as follows:

    ========  ===========================================  ================
    pair      edge(s) added                                allowed types
    ========  ===========================================  ================
    (2, 3)    directed ``i -> j``                          all
    (1, 1)    circle ``i -> j`` and circle ``j -> i``      'PAG'
    (2, 1)    directed ``i -> j`` and circle ``j -> i``    'PAG'
    (2, 2)    bidirected ``i <-> j``                       'MAG', 'PAG'
    (3, 3)    undirected ``i - j``                         'CPDAG'
    ========  ===========================================  ================

    A (3, 3) pair is skipped without error for graph types other than
    'CPDAG'. Any other pair in which both entries are non-zero is skipped
    as well.

    Parameters
    ----------
    matrix : numpy.ndarray or pandas.DataFrame
        The matrix representation of the graph. Other array-likes are
        converted with ``numpy.asarray``. Only ``matrix.shape[0]`` is used
        as the number of nodes.
    graph_type : str, optional
        The type of graph to be created. Default is 'DAG'.
        Supported types are 'DAG', 'CPDAG', 'PAG', and 'MAG'.

    Returns
    -------
    GraphWrapperBase
        The pywhy-graphs graph object.

    Raises
    ------
    ValueError
        If ``graph_type`` is not one of the supported types.
    TypeError
        If a pair from the table above is used with a graph type that does
        not allow it, or if exactly one entry of a pair is zero.
    """
    if graph_type == 'DAG':
        graph = DAGWrapper()
    elif graph_type == 'CPDAG':
        graph = CPDAGWrapper()
    elif graph_type == 'PAG':
        graph = PAGWrapper()
    elif graph_type == 'MAG':
        graph = MAGWrapper()
    else:
        raise ValueError("Unsupported graph type")

    if isinstance(matrix, pd.DataFrame):
        matrix = matrix.to_numpy()
    else:
        matrix = np.asarray(matrix)

    incompatible_msg = (
        f"Unsupported edge type detected in the provided matrix for the graph type '{graph_type}'. "
        "The matrix contains edges that are not compatible with the specified graph type."
    )

    n = matrix.shape[0]
    labels = [str(i) for i in range(n)]

    # Register every node before any edge so that the graph's node order is
    # the matrix index order; adding an edge to a not-yet-registered node
    # could otherwise place that node ahead of lower-indexed ones.
    for label in labels:
        graph.add_node(label)

    for i in range(n):
        for j in range(n):
            mark_ij = matrix[i, j]
            mark_ji = matrix[j, i]
            # Symmetric pairs are seen at (i, j) and again at (j, i); only
            # the first visit inserts the edge.
            first_visit = i <= j

            if mark_ij == 2 and mark_ji == 3:
                graph.add_directed_edge(labels[i], labels[j])
            elif mark_ij == 1 and mark_ji == 1:
                if graph_type != 'PAG':
                    raise TypeError(incompatible_msg)
                if first_visit:
                    graph.add_circle_edge(labels[i], labels[j])
                    graph.add_circle_edge(labels[j], labels[i])
            elif mark_ij == 2 and mark_ji == 1:
                if graph_type != 'PAG':
                    raise TypeError(incompatible_msg)
                graph.add_directed_edge(labels[i], labels[j])
                graph.add_circle_edge(labels[j], labels[i])
            elif mark_ij == 2 and mark_ji == 2:
                if graph_type not in ('MAG', 'PAG'):
                    raise TypeError(incompatible_msg)
                if first_visit:
                    graph.add_bidirected_edge(labels[i], labels[j])
            elif mark_ij == 3 and mark_ji == 3:
                # Only CPDAGs get an undirected edge; for every other graph
                # type this pair is deliberately left out of the graph.
                if graph_type == 'CPDAG' and first_visit:
                    graph.add_undirected_edge(labels[i], labels[j])
            elif (mark_ij == 0) != (mark_ji == 0):
                # An edge must be marked at both ends or at neither.
                raise TypeError(incompatible_msg)

    return graph
def pywhy_graph_to_matrix(graph):
    """
    Convert a pywhy-graphs graph to a matrix representation.

    Rows and columns follow the order of ``graph.get_nodes()``. For an edge
    ``(source, target)`` the entries written are:

    * directed:   ``matrix[source, target] = 2``, ``matrix[target, source] = 3``
    * circle:     ``matrix[source, target] = 1``
    * bidirected: ``matrix[source, target] = 2``, ``matrix[target, source] = 2``
    * undirected: ``matrix[source, target] = 3``, ``matrix[target, source] = 3``

    Entries for node pairs without an edge stay 0.

    Parameters
    ----------
    graph : GraphWrapperBase
        The pywhy-graphs graph object. For a ``DAGWrapper``,
        ``graph.get_edges()`` is iterated as ``(source, target)`` pairs; for
        any other wrapper it is indexed by edge-type name and each entry is
        iterated as ``(source, target)`` pairs.

    Returns
    -------
    numpy.ndarray
        The matrix representation of the graph.

    Raises
    ------
    TypeError
        If a non-DAG graph contains an edge whose type is not one of
        'directed', 'circle', 'bidirected' or 'undirected'.
    """
    nodes = list(graph.get_nodes())
    node_indices = {node: idx for idx, node in enumerate(nodes)}
    n = len(nodes)
    matrix = np.zeros((n, n), dtype=int)
    edges = graph.get_edges()

    if isinstance(graph, DAGWrapper):
        for source, target in edges:
            src, dst = node_indices[source], node_indices[target]
            matrix[src, dst] = 2
            matrix[dst, src] = 3
        return matrix

    for edge_type in edges:
        for source, target in edges[edge_type]:
            src, dst = node_indices[source], node_indices[target]
            if edge_type == 'directed':
                matrix[src, dst] = 2
                # A circle mark may already sit on the reverse entry (a
                # directed edge paired with a reverse circle edge). Keep it
                # so the result does not depend on whether 'circle' or
                # 'directed' edges are visited first.
                if matrix[dst, src] != 1:
                    matrix[dst, src] = 3
            elif edge_type == 'circle':
                matrix[src, dst] = 1
            elif edge_type == 'bidirected':
                matrix[src, dst] = 2
                matrix[dst, src] = 2
            elif edge_type == 'undirected':
                matrix[src, dst] = 3
                matrix[dst, src] = 3
            else:
                raise TypeError(f"Unsupported edge type: {edge_type!r}.")

    return matrix
'''
import numpy as np

# Define the matrix
matrix = np.array([[0, 2, 1],
                   [3, 0, 2],
                   [2, 3, 0]])

# Convert the matrix to a PAG
pag = matrix_to_pywhy_graph(matrix, graph_type='PAG')
print(pag.get_edges())

# Convert the PAG back to a matrix
converted_matrix = pywhy_graph_to_matrix(pag)

# Print the converted matrix
print(converted_matrix)
'''