Source code for ETIA.CausalLearning.algorithms.cdt_algorithms.SAMAlgorithm

import pandas as pd
from cdt.causality.graph import SAM
from ...utils.logger import get_logger
from ...CausalModel import DAGWrapper


class SAMAlgorithm:
    """
    Wrapper around the SAM (Structural Agnostic Model) causal discovery algorithm.

    Methods
    -------
    check_parameters()
        Validates the parameters currently stored on the instance.
    prepare_data(data)
        Returns the data as a pandas DataFrame, calling ``get_dataset`` if needed.
    set_parameters(parameters)
        Stores the algorithm's parameters from a dictionary, using defaults
        for missing keys, and validates them.
    run(data, parameters, prepare_data=True)
        Runs SAM on the provided data and returns the learned structure.
    """

    def __init__(self, algorithm='sam', verbose=False):
        """
        Initializes the SAMAlgorithm class.

        Parameters
        ----------
        algorithm : str, optional
            The name of the algorithm. Default is 'sam'.
        verbose : bool, optional
            If True, enables verbose logging. Default is False.
        """
        self.algorithm = algorithm
        self.verbose = verbose
        self.logger = get_logger(name=__name__, verbose=self.verbose)

    @staticmethod
    def _is_int(value):
        """Return True for real integers; ``bool`` is excluded although it subclasses ``int``."""
        return isinstance(value, int) and not isinstance(value, bool)

    @staticmethod
    def _in_open_unit_interval(value):
        """Return True when ``0 < value < 1``; values that cannot be compared count as invalid."""
        try:
            return bool(0 < value < 1)
        except (TypeError, ValueError):
            # TypeError: non-numeric value (str, None, ...).
            # ValueError: ambiguous truth value (e.g. an array-like).
            return False

    @staticmethod
    def _is_non_negative(value):
        """Return True when ``value >= 0``; non-comparable values and NaN count as invalid."""
        try:
            return bool(value >= 0)
        except (TypeError, ValueError):
            return False

    def check_parameters(self):
        """
        Validates the parameters required for running the SAM algorithm.

        The parameters are read from the instance attributes populated by
        ``set_parameters``. The first failing check is reported through
        ``self.logger.error`` and the method returns immediately.

        Returns
        -------
        bool
            True if all parameters are valid, False otherwise. This method
            does not raise for invalid values; ``set_parameters`` turns a
            False result into a ValueError.
        """
        if not (self._in_open_unit_interval(self.lr)
                and self._in_open_unit_interval(self.dlr)):
            self.logger.error('Learning rates (lr and dlr) must be between 0 and 1')
            return False

        if not isinstance(self.mixed_data, bool):
            self.logger.error('mixed_data must be a boolean value')
            return False

        if not (self._is_non_negative(self.lambda1)
                and self._is_non_negative(self.lambda2)):
            self.logger.error('lambda1 and lambda2 must be non-negative')
            return False

        if not all(self._is_int(units) and units > 0
                   for units in (self.nh, self.dnh)):
            self.logger.error('nh and dnh must be positive integers')
            return False

        if not all(self._is_int(epochs) and epochs > 0
                   for epochs in (self.train_epochs, self.test_epochs)):
            self.logger.error('train_epochs and test_epochs must be positive integers')
            return False

        # NOTE(review): cdt's SAM appears to document batch_size=-1 as
        # "full batch"; that value is rejected here -- confirm this is intended.
        if not self._is_int(self.batch_size) or self.batch_size < 0:
            self.logger.error('batch_size must be a non-negative integer')
            return False

        if self.losstype not in ['fgan', 'gan', 'mse']:
            self.logger.error("losstype must be one of 'fgan', 'gan', or 'mse'")
            return False

        return True

    def prepare_data(self, data):
        """
        Prepares the data for the SAM algorithm.

        Parameters
        ----------
        data : pandas.DataFrame or object
            A DataFrame is returned unchanged. Any other object must
            implement ``get_dataset()``, whose result is returned.

        Returns
        -------
        data : pandas.DataFrame
            The input DataFrame, or the result of ``data.get_dataset()``.
        """
        if isinstance(data, pd.DataFrame):
            return data
        return data.get_dataset()

    def set_parameters(self, parameters):
        """
        Sets the parameters for the SAM algorithm, using default values where necessary.

        Parameters
        ----------
        parameters : dict or None
            Parameter overrides. Recognised keys and their defaults:
            ``lr`` (0.01), ``dlr`` (0.001), ``mixed_data`` (False),
            ``lambda1`` (10), ``lambda2`` (0.001), ``nh`` (20), ``dnh`` (200),
            ``train_epochs`` (3000), ``test_epochs`` (1000),
            ``batch_size`` (100), ``losstype`` ('fgan').
            ``None`` is treated as an empty dictionary.

        Raises
        ------
        ValueError
            If ``check_parameters`` rejects any of the resulting values.
        """
        options = {} if parameters is None else parameters

        self.lr = options.get('lr', 0.01)
        self.dlr = options.get('dlr', 0.001)
        self.mixed_data = options.get('mixed_data', False)
        self.lambda1 = options.get('lambda1', 10)
        self.lambda2 = options.get('lambda2', 0.001)
        self.nh = options.get('nh', 20)
        self.dnh = options.get('dnh', 200)
        self.train_epochs = options.get('train_epochs', 3000)
        self.test_epochs = options.get('test_epochs', 1000)
        self.batch_size = options.get('batch_size', 100)
        self.losstype = options.get('losstype', 'fgan')

        if not self.check_parameters():
            raise ValueError("Invalid parameters for SAM algorithm")

    def run(self, data, parameters, prepare_data=True):
        """
        Runs the SAM algorithm on the provided data and parameters.

        Parameters
        ----------
        data : pandas.DataFrame or object
            The dataset, either as a DataFrame or as an object implementing
            ``get_dataset``.
        parameters : dict
            The parameters to configure the SAM algorithm
            (see ``set_parameters``).
        prepare_data : bool, optional
            If True, ``data`` is passed through ``self.prepare_data`` first.
            Default is True.

        Returns
        -------
        tuple
            A tuple containing:
            - mec_graph : DAGWrapper
                Wrapper whose ``dag`` attribute holds the output of
                ``SAM.predict``.
            - library_results : dict
                ``{'mec': mec_graph}``.

        Raises
        ------
        ValueError
            If any of the parameters are invalid.
        """
        if prepare_data:
            data = self.prepare_data(data)

        self.set_parameters(parameters)

        # Configure SAM with the validated parameters stored on the instance.
        sam = SAM(
            lr=self.lr,
            dlr=self.dlr,
            mixed_data=self.mixed_data,
            lambda1=self.lambda1,
            lambda2=self.lambda2,
            nh=self.nh,
            dnh=self.dnh,
            train_epochs=self.train_epochs,
            test_epochs=self.test_epochs,
            batch_size=self.batch_size,
            losstype=self.losstype,
        )
        output = sam.predict(data)

        # Wrap the learned structure in a DAGWrapper.
        mec_graph = DAGWrapper()
        mec_graph.dag = output

        library_results = {'mec': mec_graph}
        return mec_graph, library_results