from typing import Optional
from ..model_validation_protocols.available_mvp_protocols import available_mvp_protocols
from ..regressors import regressors
import copy
from ..algorithms.causaldiscoveryalgorithms import *
from ..algorithms.tetrad_algorithm import TetradAlgorithm
from ..algorithms.tigramite_algorithm import TigramiteAlgorithm
from ..algorithms.causalnex_algorithm import NoTears
from ..algorithms.cdt_algorithms import SAMAlgorithm
from ..utils.logger import get_logger
class MVP_Parameters:
    """
    Class to manage the out-of-sample (OOS) protocol for model validation.

    Parameters
    ----------
    protocol_name : str, optional
        The name of the OOS protocol; must be a key of
        ``available_mvp_protocols``. Default is 'KFoldCV'.
    parameters : dict or None, optional
        Parameters for the chosen OOS protocol. ``None`` (the default) means
        ``{'folds': 10, 'folds_to_run': 1}``. A fresh dict is built on every
        call, so instances never share one default dict object.
    verbose : bool, optional
        If True, enables detailed logging. Default is False.

    Attributes
    ----------
    protocol_name : str
        Name of the chosen OOS protocol.
    protocol : object
        The entry of ``available_mvp_protocols`` for ``protocol_name``, after
        ``set_params`` has been called on it.
    parameters : dict
        Parameters for the OOS protocol.
    verbose : bool
        Verbosity flag forwarded to the logger and to ``set_params``.
    logger : logging.Logger
        Logger obtained from ``get_logger``.

    Raises
    ------
    ValueError
        If ``protocol_name`` is not in ``available_mvp_protocols``.
        ``ValueError`` subclasses ``Exception``, so existing
        ``except Exception`` handlers still catch it.

    Methods
    -------
    __init__(protocol_name='KFoldCV', parameters=None, verbose=False)
        Initializes the OOS protocol with the given or default parameters.
    """

    def __init__(self, protocol_name='KFoldCV', parameters=None, verbose=False):
        # Build the default per call: a dict literal in the signature would be
        # one object shared by every instance (and by set_params below).
        if parameters is None:
            parameters = {'folds': 10, 'folds_to_run': 1}
        self.protocol_name = protocol_name
        self.parameters = parameters
        self.verbose = verbose
        self.logger = get_logger(name=__name__, verbose=self.verbose)

        if protocol_name not in available_mvp_protocols:
            self.logger.error('The protocol you chose is not available')
            raise ValueError('The protocol you chose is not available')

        # NOTE(review): the registry entry itself (not a copy) is stored and
        # configured, so instances using the same protocol_name presumably
        # share one protocol object. Confirm this is intended.
        self.protocol = available_mvp_protocols[protocol_name]
        self.protocol.set_params(parameters, self.verbose)
class Regressor_parameters:
    """
    Class to manage the configuration of regressors for the CDHPO process.

    Parameters
    ----------
    regressor_name : str, optional
        The name of the regressor; must be a key of
        ``regressors.available_regressors``.
        Default is 'RandomForestRegressor'.
    parameters : dict or None, optional
        Parameters for the regressor. ``None`` (the default) means
        ``{'n_trees': 100, 'min_samples_leaf': 0.1, 'max_depth': 10}``.
        A fresh dict is built on every call, so instances never share one
        default dict object.
    verbose : bool, optional
        If True, enables detailed logging. Default is False.

    Attributes
    ----------
    regressor_name : str
        Name of the chosen regressor.
    regressor : object
        The value returned by ``set_regressor_params`` on the registry entry.
    parameters : dict
        Parameters for the regressor.
    verbose : bool
        Verbosity flag forwarded to the logger.
    logger : logging.Logger
        Logger obtained from ``get_logger``.

    Raises
    ------
    ValueError
        If ``regressor_name`` is not in ``regressors.available_regressors``.
        ``ValueError`` subclasses ``Exception``, so existing
        ``except Exception`` handlers still catch it.

    Methods
    -------
    __init__(regressor_name='RandomForestRegressor', parameters=None, verbose=False)
        Initializes the regressor with the given or default parameters.
    """

    def __init__(self, regressor_name='RandomForestRegressor', parameters=None, verbose=False):
        # Build the default per call: a dict literal in the signature would be
        # one object shared by every instance.
        if parameters is None:
            parameters = {'n_trees': 100, 'min_samples_leaf': 0.1, 'max_depth': 10}
        self.regressor_name = regressor_name
        self.parameters = parameters
        self.verbose = verbose
        self.logger = get_logger(name=__name__, verbose=self.verbose)

        if regressor_name not in regressors.available_regressors:
            self.logger.error('The regressor you chose is not available')
            raise ValueError('The regressor you chose is not available')

        # NOTE(review): self.regressor is whatever set_regressor_params
        # returns. Confirm it returns the configured regressor rather than
        # None.
        registry_entry = regressors.available_regressors[regressor_name]
        self.regressor = registry_entry.set_regressor_params(parameters)
class CDHPO_Parameters:
    """
    Class to manage the configuration of the CDHPO (Causal Discovery with Hyperparameter Optimization) process.

    ``init_main_params`` must be called before the other methods. It is the
    only place that creates the ``logger``, ``verbose`` and ``configs``
    attributes the other methods read.

    Methods
    -------
    init_main_params(alpha=0.01, n_permutations=200, causal_sufficiency=True, variables_type='mixed', n_jobs: Optional[int] = 1, verbose=False)
        Initializes the main parameters for CDHPO.
    set_regressor(name, parameters)
        Sets the regressor used in the CDHPO process.
    set_oos_protocol(name, parameters)
        Sets the out-of-sample (OOS) protocol for validation.
    set_cd_algorithms(algorithms, data_info)
        Sets the causal discovery algorithms for CDHPO.
    check_configs(data_info)
        Verifies the configurations for the causal discovery algorithms.
    add_cd_algorithm(algorithm, parameters, data_info)
        Adds a new causal discovery algorithm with parameters.
    add_cd_algorithm_parameters(algorithm, parameters)
        Adds parameters to an existing causal discovery algorithm.
    """

    def init_main_params(self, alpha=0.01, n_permutations=200, causal_sufficiency=True,
                         variables_type='mixed', n_jobs: Optional[int] = 1, verbose=False):
        """
        Initializes the main parameters for the CDHPO process.

        Also installs a default OOS protocol (``MVP_Parameters()`` defaults)
        and a default regressor (``Regressor_parameters()`` defaults). It
        resets ``configs`` to an empty dict.

        Parameters
        ----------
        alpha : float, optional
            The significance level. Default is 0.01.
        n_permutations : int, optional
            The number of permutations for statistical tests. Default is 200.
        causal_sufficiency : bool, optional
            Whether to assume causal sufficiency. Default is True.
        variables_type : str, optional
            Type of variables in the data (e.g., 'mixed', 'discrete', 'continuous'). Default is 'mixed'.
        n_jobs : int, optional
            Number of parallel jobs. Default is 1.
        verbose : bool, optional
            If True, enables detailed logging. Default is False.

        Returns
        -------
        None
        """
        self.n_jobs = n_jobs
        self.alpha = alpha
        self.n_permutations = n_permutations
        self.causal_sufficiency = causal_sufficiency
        self.variables_type = variables_type
        self.verbose = verbose
        # Forward verbose so the defaults log like those installed later via
        # set_oos_protocol / set_regressor, which also pass self.verbose.
        self.oos_protocol = MVP_Parameters(verbose=verbose)
        self.regressor = Regressor_parameters(verbose=verbose)
        self.logger = get_logger(name=__name__, verbose=self.verbose)
        self.logger.info('CDHPO Parameters have been initialized')
        self.configs = {}

    def set_regressor(self, name, parameters):
        """
        Sets the regressor for the CDHPO process.

        Parameters
        ----------
        name : str
            The name of the regressor.
        parameters : dict
            The parameters for the regressor.

        Returns
        -------
        None
        """
        self.regressor = Regressor_parameters(name, parameters, self.verbose)
        self.logger.info('Regressor_parameters has been set')

    def set_oos_protocol(self, name, parameters):
        """
        Sets the out-of-sample (OOS) protocol for validation.

        Parameters
        ----------
        name : str
            The name of the OOS protocol.
        parameters : dict
            The parameters for the OOS protocol.

        Returns
        -------
        None
        """
        self.oos_protocol = MVP_Parameters(name, parameters, self.verbose)
        self.logger.info('OOS protocol has been set')

    def _resolve_algorithm(self, algo, data_info):
        """
        Build the model wrapper for ``algo`` from the first library listing it.

        Libraries are searched in the order tetrad, tigramite, causalnex, cdt.
        Only the Tetrad and Tigramite wrappers get ``init_algo(data_info)``.

        Returns
        -------
        tuple or None
            ``(model, library_name)``, or ``None`` if no library lists ``algo``.
        """
        if algo in cd_algorithms['tetrad']['algorithms']:
            model = TetradAlgorithm(algo, self.verbose)
            model.init_algo(data_info)
            return model, 'tetrad'
        if algo in cd_algorithms['tigramite']['algorithms']:
            model = TigramiteAlgorithm(algo, self.verbose)
            model.init_algo(data_info)
            return model, 'tigramite'
        if algo in cd_algorithms['causalnex']['algorithms']:
            return NoTears.NoTearsAlgorithm(algo, self.verbose), 'causalnex'
        if algo in cd_algorithms['cdt']['algorithms']:
            return SAMAlgorithm.SAMAlgorithm(algo, self.verbose), 'cdt'
        return None

    def set_cd_algorithms(self, algorithms, data_info):
        """
        Sets the causal discovery algorithms for CDHPO.

        Each entry of ``algorithms`` is updated in place with a ``'model'``
        key (a one-element list holding the wrapper) and a ``'library'`` key
        (a one-element list holding the library name). ``self.configs`` is
        then replaced by ``algorithms`` itself, discarding previously
        configured algorithms. Every algorithm is resolved first, so on error
        neither ``algorithms`` nor ``self.configs`` is modified.

        Parameters
        ----------
        algorithms : dict
            Dictionary of algorithms and their configurations.
        data_info : dict
            Information about the dataset.

        Returns
        -------
        None

        Raises
        ------
        RuntimeError
            If an algorithm is not found in the database.
        """
        resolved = {}
        for algo in algorithms:
            found = self._resolve_algorithm(algo, data_info)
            if found is None:
                raise RuntimeError(f"This algorithm ({algo}) does not exist in the database")
            resolved[algo] = found

        for algo, (model, library) in resolved.items():
            algorithms[algo]['model'] = [model]
            algorithms[algo]['library'] = [library]
        self.configs = algorithms

    def check_configs(self, data_info):
        """
        Verifies the configurations for the causal discovery algorithms.

        Calls ``check_parameters(config, data_info)`` on the first model of
        each configured algorithm.

        Parameters
        ----------
        data_info : dict
            Information about the dataset.

        Returns
        -------
        None

        Raises
        ------
        RuntimeError
            If any algorithm has incorrect parameters.
        """
        for algo, config in self.configs.items():
            if not config['model'][0].check_parameters(config, data_info):
                raise RuntimeError(f'Algorithm {algo} has wrong parameters or parameter values')

    def add_cd_algorithm(self, algorithm, parameters, data_info):
        """
        Adds a new causal discovery algorithm to the CDHPO configuration.

        ``parameters`` is updated in place with ``'model'`` and ``'library'``
        keys and then stored as ``self.configs[algorithm]``.

        Parameters
        ----------
        algorithm : str
            Name of the algorithm to add.
        parameters : dict
            Configuration parameters for the algorithm.
        data_info : dict
            Information about the dataset.

        Returns
        -------
        None

        Raises
        ------
        Warning
            If the algorithm has already been added (raised as an exception).
        RuntimeError
            If the algorithm is not found in the database.
        """
        if algorithm in self.configs:
            raise Warning('Causal discovery algorithm already added')

        found = self._resolve_algorithm(algorithm, data_info)
        if found is None:
            raise RuntimeError("This algorithm does not exist in the database")
        model, library = found
        parameters['model'] = [model]
        parameters['library'] = [library]
        self.configs[algorithm] = parameters

    @staticmethod
    def _as_value_list(values):
        """Return ``values`` as a list, wrapping scalars and strings in one."""
        # A bare string would otherwise be iterated character by character.
        if isinstance(values, (str, bytes)):
            return [values]
        try:
            return list(values)
        except TypeError:
            return [values]

    def add_cd_algorithm_parameters(self, algorithm, parameters):
        """
        Adds parameters to an existing causal discovery algorithm.

        For each key, the given value(s) are appended to the existing value
        list, skipping values already present. A scalar value is treated as a
        one-element list. All keys are validated before anything is appended,
        so an invalid key leaves the configuration unchanged.

        Parameters
        ----------
        algorithm : str
            Name of the algorithm.
        parameters : dict
            Additional parameters for the algorithm.

        Returns
        -------
        None

        Raises
        ------
        RuntimeError
            If the algorithm does not exist or the parameters are invalid.
        """
        if algorithm not in self.configs:
            self.logger.error('Tried to add parameters for a non-existent algorithm')
            raise RuntimeError('Tried to add parameters for a non-existent algorithm')
        config = self.configs[algorithm]

        for param in parameters:
            if param not in config:
                self.logger.error('Tried to add an invalid parameter')
                raise RuntimeError('Tried to add an invalid parameter')

        for param in parameters:
            existing = config[param]
            for val in self._as_value_list(parameters[param]):
                if val not in existing:
                    existing.append(val)
        self.logger.info(f'Parameters added to algorithm {algorithm}')