import numpy as np
import pandas as pd
from .adjset_R import *
[docs]
def find_adjset(graph_pd, graph_type, target_name, exposure_names, r_path='R'):
    '''
    Run the dagitty R package to identify the adjustment set of X and Y.

    Author: kbiza@csd.uoc.gr

    Args:
        graph_pd (pandas DataFrame): the graph as adjacency matrix
        graph_type (str): the type of the graph: {'dag', 'cpdag', 'mag', 'pag'}
        target_name: list of one variable name
        exposure_names: list of one or more variable names
        r_path (str): forwarded unchanged to adjset_dagitty; defaults to 'R'
            (presumably the R executable to invoke -- confirm in adjset_dagitty)

    Returns:
        adj_set_can (list or None): the variable names of the canonical
            adj. set, or None when adjset_dagitty returns no list of sets
        adj_set_min (list or None): the variable names of the minimal
            adj. set, or None when adjset_dagitty returns no list of sets
    '''
    graph_np = graph_pd.to_numpy()

    if graph_type in ('dag', 'cpdag'):
        # Entries equal to 1 or 2 become 1, everything else 0, and the matrix
        # is transposed before being handed to R (presumably the pcalg
        # adjacency orientation -- confirm against adjset_dagitty).
        edge_mask = (graph_np == 1) | (graph_np == 2)
        pcalg_graph_pd = pd.DataFrame(
            edge_mask.astype(int).T,
            index=graph_pd.index,
            columns=graph_pd.columns,
        )
    else:
        pcalg_graph_pd = graph_pd.copy()

    # Exposure names are sent with ':' replaced by '.' (presumably because the
    # R side cannot keep ':' in variable names).
    # NOTE(review): target_name is forwarded without this replacement --
    # confirm whether targets containing ':' need the same treatment.
    exposure_names_ = [name.replace(':', '.') for name in exposure_names]

    canonical_dg, minimal_dg = adjset_dagitty(
        pcalg_graph_pd, graph_type, exposure_names_, target_name, r_path)

    # Blindly turning every '.' back into ':' corrupts names that contain a
    # genuine '.', so returned names are resolved against the real column
    # labels first. A sanitised name shared by several columns is marked
    # ambiguous (None) and falls back to the plain replacement.
    original_by_sanitised = {}
    for column in graph_pd.columns:
        if isinstance(column, str):
            sanitised = column.replace(':', '.')
            if sanitised in original_by_sanitised:
                original_by_sanitised[sanitised] = None
            else:
                original_by_sanitised[sanitised] = column

    adj_set_can = _restore_adjset_names(canonical_dg, original_by_sanitised)
    adj_set_min = _restore_adjset_names(minimal_dg, original_by_sanitised)
    return adj_set_can, adj_set_min


def _restore_adjset_names(dagitty_sets, original_by_sanitised):
    '''Return the first set in dagitty_sets with original names, or None.'''
    # Anything that is not a non-empty list is treated as "no adjustment set";
    # an empty list would otherwise raise IndexError on the [0] lookup.
    if not isinstance(dagitty_sets, list) or not dagitty_sets:
        return None

    restored = []
    for name in dagitty_sets[0]:
        original = original_by_sanitised.get(name)
        if original is None:
            # Unknown or ambiguous name: keep the historical '.' -> ':' mapping.
            original = name.replace('.', ':')
        restored.append(original)
    return restored