Source code for ETIA.CausalLearning.configurations.configurations

import json
import os
from typing import Optional

from ...data.Dataset import Dataset  # Corrected import
from .parameters import CDHPO_Parameters
from .default_parameters import default_conf
from .class_causal_configurator import CausalDiscoveryConfigurator
from ..utils import get_logger


[docs] class Configurations: """ Configurations class for setting up the causal discovery experiment. Parameters ---------- dataset : Dataset The dataset object. n_lags : int, optional Number of lags (for time series). time_lagged : bool, optional Indicates if the data is time-lagged. time_series : bool, optional Indicates if the dataset is time series data. conf_file : str, optional JSON configuration file containing parameters for the causal discovery experiment. n_jobs : int, optional Number of jobs to use for parallel processing. verbose : bool, optional Whether to print debug information. Attributes ---------- cdhpo_params : CDHPOParameters Parameters for the CDHPO algorithm. results_folder : str Folder path for storing results. """ def __init__( self, dataset: Dataset = None, n_lags: int = 0, time_lagged: bool = False, time_series: bool = False, conf_file: Optional[str] = None, n_jobs: Optional[int] = -1, verbose=False ): self.n_jobs = n_jobs self.dataset = dataset self.n_lags = n_lags self.time_lagged = time_lagged self.time_series = time_series self.verbose = verbose self.logger = get_logger(name=__name__, verbose=self.verbose) self.cdhpo_params = CDHPO_Parameters() self.results_folder = './' # If conf_file is provided, process it. Otherwise, set default configurations. if conf_file is not None and conf_file.endswith('.json'): self.conf_file = conf_file self.process_conf_file() # Process the JSON file else: self.set_default_configuration() # Set default configuration based on dataset self.logger.info('Configurations object has been initialized')
[docs] def set_default_configuration(self): """ Set default configurations based on the dataset when no JSON configuration file is provided. """ default_conf['Dataset'] = { 'dataset_name': self.dataset.dataset_name or 'Preloaded Dataset', 'time_lagged': self.time_lagged, 'n_lags': self.n_lags, 'time_series': self.time_series, } self.dataset_info = self.dataset.get_info() # Configure CDHPO parameters using default configurations oct_json_params = default_conf['OCT'] self.cdhpo_params.init_main_params( alpha=oct_json_params['alpha'], n_permutations=oct_json_params['n_permutations'], causal_sufficiency=default_conf['causal_sufficiency'], n_jobs=self.n_jobs, variables_type=oct_json_params['variables_type'], ) self.cdhpo_params.set_regressor( name=oct_json_params['Regressor_parameters']['name'], parameters=oct_json_params['Regressor_parameters']['parameters'], ) self.cdhpo_params.set_oos_protocol( name=oct_json_params['out_of_sample_protocol']['name'], parameters=oct_json_params['out_of_sample_protocol']['parameters'], ) # Configure causal discovery algorithms if 'CausalDiscoveryAlgorithms' not in oct_json_params: configurator = CausalDiscoveryConfigurator() oct_json_params['CausalDiscoveryAlgorithms'] = configurator.create_causal_configs( data_type=self.dataset_info['data_type'], causal_sufficiency=default_conf['causal_sufficiency'], assume_faithfulness=default_conf['assume_faithfulness'], time_lagged=self.time_lagged, time_series=self.time_series, ) else: if ('include_algs' in oct_json_params['CausalDiscoveryAlgorithms'] or 'exclude_algs' in oct_json_params[ 'CausalDiscoveryAlgorithms']): configurator = CausalDiscoveryConfigurator() if ('exclude_algs' in oct_json_params['CausalDiscoveryAlgorithms'] and 'include_algs' in oct_json_params['CausalDiscoveryAlgorithms']): exlcude_algs = oct_json_params['CausalDiscoveryAlgorithms']['exclude_algs'] include_algs = oct_json_params['CausalDiscoveryAlgorithms']['include_algs'] elif ('exclude_algs' in oct_json_params['CausalDiscoveryAlgorithms']): exlcude_algs = oct_json_params['CausalDiscoveryAlgorithms']['exclude_algs'] include_algs = [] else: exlcude_algs = [] include_algs = oct_json_params['CausalDiscoveryAlgorithms']['include_algs'] oct_json_params['CausalDiscoveryAlgorithms'] = configurator.create_causal_configs( data_type=self.dataset_info['data_type'], causal_sufficiency=default_conf['causal_sufficiency'], assume_faithfulness=default_conf['assume_faithfulness'], time_lagged=self.time_lagged, time_series=self.time_series, include_algs=include_algs, exclude_algs=exlcude_algs ) self.cdhpo_params.set_cd_algorithms(oct_json_params['CausalDiscoveryAlgorithms'], self.dataset_info)
[docs] def process_conf_file(self): """ Process the JSON file containing all vital information, such as algorithms, algorithm parameters, run mode, etc. """ with open(self.conf_file) as f: conf = json.load(f) dataset_conf = conf.get('Dataset', {}) self.time_lagged = dataset_conf.get('time_lagged', self.time_lagged) self.n_lags = dataset_conf.get('n_lags', self.n_lags) self.time_series = dataset_conf.get('time_series', self.time_series) # Initialize the dataset from the configuration file if 'dataset_name' in dataset_conf: self.dataset = Dataset( filename=dataset_conf['dataset_name'], data_time_info={'n_lags': self.n_lags, 'time_lagged': self.time_lagged}, time_series=self.time_series, ) else: raise ValueError("Configuration file must include 'dataset_name' in 'Dataset' section.") self.results_folder = conf.get('Results_folder_path', self.results_folder) # Process CDHPO parameters oct_json_params = conf['OCT'] if('n_jobs' in oct_json_params): self.n_jobs = oct_json_params['n_jobs'] else: self.n_jobs = os.cpu_count() self.cdhpo_params.init_main_params( alpha=oct_json_params['alpha'], n_permutations=oct_json_params['n_permutations'], causal_sufficiency=conf['causal_sufficiency'], variables_type=oct_json_params['variables_type'], n_jobs=self.n_jobs, verbose=self.verbose, ) self.cdhpo_params.set_regressor( name=oct_json_params['Regressor_parameters']['name'], parameters=oct_json_params['Regressor_parameters']['parameters'], ) self.cdhpo_params.set_oos_protocol( name=oct_json_params['out_of_sample_protocol']['name'], parameters=oct_json_params['out_of_sample_protocol']['parameters'], ) self.dataset_info = self.dataset.get_info() # Configure causal discovery algorithms if 'CausalDiscoveryAlgorithms' not in oct_json_params: configurator = CausalDiscoveryConfigurator() oct_json_params['CausalDiscoveryAlgorithms'] = configurator.create_causal_configs( data_type=self.dataset_info['data_type'], causal_sufficiency=conf['causal_sufficiency'], assume_faithfulness=conf['assume_faithfulness'], time_lagged=self.time_lagged, time_series=self.time_series, ) else: if('include_algs' in oct_json_params['CausalDiscoveryAlgorithms'] or 'exclude_algs' in oct_json_params['CausalDiscoveryAlgorithms']): configurator = CausalDiscoveryConfigurator() if('exclude_algs' in oct_json_params['CausalDiscoveryAlgorithms'] and 'include_algs' in oct_json_params['CausalDiscoveryAlgorithms']): exlcude_algs = oct_json_params['CausalDiscoveryAlgorithms']['exclude_algs'] include_algs = oct_json_params['CausalDiscoveryAlgorithms']['include_algs'] elif('exclude_algs' in oct_json_params['CausalDiscoveryAlgorithms']): exlcude_algs = oct_json_params['CausalDiscoveryAlgorithms']['exclude_algs'] include_algs = [] else: exlcude_algs = [] include_algs = oct_json_params['CausalDiscoveryAlgorithms']['include_algs'] oct_json_params['CausalDiscoveryAlgorithms'] = configurator.create_causal_configs( data_type=self.dataset_info['data_type'], causal_sufficiency=conf['causal_sufficiency'], assume_faithfulness=conf['assume_faithfulness'], time_lagged=self.time_lagged, time_series=self.time_series, include_algs=include_algs, exclude_algs=exlcude_algs ) self.cdhpo_params.set_cd_algorithms(oct_json_params['CausalDiscoveryAlgorithms'], self.dataset_info)
[docs] def add_configurations_from_file(self, filename: str) -> None: """ Add additional configurations to the experiment from a JSON file. Parameters ---------- filename : str The filename of the JSON file containing configurations. """ with open(filename) as f: conf = json.load(f) oct_json_params = conf['OCT'] causal_algorithms = oct_json_params['CausalDiscoveryAlgorithms'] for algo, params in causal_algorithms.items(): if algo not in self.cdhpo_params.configs: self.cdhpo_params.add_cd_algorithm(algo, params, self.dataset_info) else: self.cdhpo_params.add_cd_algorithm_parameters(algo, params) self.logger.info(f'Additional configurations added from {filename}')