Skip to content

Commit

Permalink
Merge pull request #204 from Kevinjkf/main
Browse files Browse the repository at this point in the history
initial upload for CALM algorithm
  • Loading branch information
kunwuz authored Nov 28, 2024
2 parents f6aa500 + fedcbc3 commit 08655cd
Show file tree
Hide file tree
Showing 4 changed files with 288 additions and 0 deletions.
212 changes: 212 additions & 0 deletions causallearn/search/ScoreBased/CALM.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
import numpy as np
import torch
import torch.nn as nn
from causallearn.utils.MarkovNetwork.iamb import iamb_markov_network
from causallearn.utils.CALMUtils import *
from causallearn.graph.GeneralGraph import GeneralGraph
from causallearn.graph.GraphNode import GraphNode
from typing import Any, Dict
from scipy.special import expit as sigmoid

torch.set_default_dtype(torch.double)

def calm(
X: np.ndarray,
lambda1: float = 0.005,
alpha: float = 0.01,
tau: float = 0.5,
rho_init: float = 1e-5,
rho_mult: float = 3,
htol: float = 1e-8,
subproblem_iter: int = 40000,
standardize: bool = False,
device: str = 'cpu'
) -> Dict[str, Any]:
"""
Perform the CALM (Continuous and Acyclicity-constrained L0-penalized likelihood with estimated Moral graph) algorithm.
Parameters
----------
X : numpy.ndarray
Input dataset of shape (n, d), where n is the number of samples,
and d is the number of variables.
lambda1 : float, optional
Coefficient for the approximated L0 penalty, which encourages sparsity in the learned graph. Default is 0.005.
alpha : float, optional
Significance level for conditional independence tests. Default is 0.01.
tau : float, optional
Temperature parameter for the Gumbel-Sigmoid. Default is 0.5.
rho_init : float, optional
Initial value of the penalty parameter for the acyclicity constraint. Default is 1e-5.
rho_mult : float, optional
Multiplication factor for rho in each iteration. Default is 3.
htol : float, optional
Tolerance level for acyclicity constraint. Default is 1e-8.
subproblem_iter : int, optional
Number of iterations for subproblem optimization. Default is 40000.
standardize : bool, optional
Whether to standardize the input data (mean=0, variance=1). Default is False.
device : str, optional
The device to use for computation ('cpu' or 'cuda'). Default is 'cpu'.
Returns
-------
Record : dict
A dictionary containing:
- Record['G']: learned causal graph, a DAG, where: Record['G'].graph[j,i]=1 and Record['G'].graph[i,j]=-1 indicates i --> j.
- Record['B_weighted']: weighted adjacency matrix of the learned causal graph.
"""

d = X.shape[1]
if standardize:
mean_X = np.mean(X, axis=0, keepdims=True)
std_X = np.std(X, axis=0, keepdims=True)
X = (X - mean_X) / std_X
else:
X = X - np.mean(X, axis=0, keepdims=True)

# Compute the data covariance matrix
cov_emp = np.cov(X.T, bias=True)

# Learn the moral graph using the IAMB Markov network
moral_mask, _ = iamb_markov_network(X, alpha=alpha)

# Initialize and run the CalmModel
device = torch.device(device)
cov_emp = torch.from_numpy(cov_emp).to(device)
moral_mask = torch.from_numpy(moral_mask).float().to(device)

model = CalmModel(d, moral_mask, tau=tau, lambda1=lambda1).to(device)

# Optimization loop
rho = rho_init
for _ in range(100):
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for _ in range(subproblem_iter):
optimizer.zero_grad()
loss = model.compute_loss(cov_emp, rho)
loss.backward(retain_graph=True)
optimizer.step()

with torch.no_grad():
B_logit_copy = model.B_logit.detach().clone()
B_logit_copy[model.moral_mask == 0] = float('-inf')
h_sigmoid = model.compute_h(torch.sigmoid(B_logit_copy / model.tau))

rho *= rho_mult
if h_sigmoid.item() <= htol or rho > 1e+16:
break

# Extract the final binary and weighted adjacency matrices
params_est = model.get_params()
B_bin, B_weighted = params_est['B_bin'], params_est['B']

node_names = [("X%d" % (i + 1)) for i in range(d)]
nodes = [GraphNode(name) for name in node_names]
G = GeneralGraph(nodes)

# Add edges to the GeneralGraph based on B_bin
for i in range(d):
for j in range(d):
if B_bin[i, j] == 1:
G.add_directed_edge(nodes[i], nodes[j])

Record = {
"G": G, # GeneralGraph object representing the learned causal graph, a DAG
"B_weighted": B_weighted # Weighted adjacency matrix of the learned graph
}

return Record

class CalmModel(nn.Module):
"""
The CALM model
Parameters
----------
d : int
Number of variables/nodes in the graph.
moral_mask : torch.Tensor
Binary mask representing the moral graph structure, used to restrict possible edges.
tau : float, optional
Temperature parameter for the Gumbel-Sigmoid sampling, controlling the sparsity approximation. Default is 0.5.
lambda1 : float, optional
Coefficient for the approximated L0 penalty (sparsity term). Default is 0.005.
"""
def __init__(self, d, moral_mask, tau=0.5, lambda1=0.005):
super(CalmModel, self).__init__()
self.d = d
self.moral_mask = moral_mask
self.tau = tau
self.lambda1 = lambda1
self._init_params()

def _init_params(self):
"""Initialize parameters"""
self.B_param = nn.Parameter(
torch.FloatTensor(self.d, self.d).uniform_(-0.001, 0.001).to(self.moral_mask.device)
)
self.B_logit = nn.Parameter(
torch.zeros(self.d, self.d).to(self.moral_mask.device)
)

def sample_mask(self):
"""
Samples a binary mask B_mask based on the Gumbel-Sigmoid approximation.
Applies the moral graph mask to restrict possible edges.
"""
B_mask = gumbel_sigmoid(self.B_logit, tau=self.tau)
B_mask = B_mask * self.moral_mask
return B_mask

@torch.no_grad()
def get_params(self):
"""
Returns the estimated adjacency matrix B_bin (binary) and B (weighted), thresholding at 0.5.
"""
threshold = 0.5
B_param = self.B_param.cpu().detach().numpy()
B_logit = self.B_logit.cpu().detach().numpy()
B_logit[self.moral_mask.cpu().numpy() == 0] = float('-inf')
B_bin = sigmoid(B_logit / self.tau)
B_bin[B_bin < threshold] = 0
B_bin[B_bin >= threshold] = 1
B = B_bin * B_param
params = {'B': B, 'B_bin': B_bin}
return params

def compute_likelihood(self, B, cov_emp):
"""
Computes the likelihood-based objective function for non-equal noise variance (NV) assumption.
"""
I = torch.eye(self.d, device=self.B_param.device)
residuals = torch.diagonal((I - B).T @ cov_emp @ (I - B))
likelihood = 0.5 * torch.sum(torch.log(residuals)) - torch.linalg.slogdet(I - B)[1]
return likelihood

def compute_sparsity(self, B_mask):
"""
Computes the sparsity penalty (approximated L0 penalty) by summing the binary entries in B_mask.
"""
return B_mask.sum()

def compute_h(self, B_mask):
"""
Computes the DAG constraint term, adapted from the DAG constraint formulation
in Yu et al. (2019).
"""
return torch.trace(matrix_poly(B_mask, self.d, self.B_param.device)) - self.d

def compute_loss(self, cov_emp, rho):
"""
Combines likelihood, approximated L0 penalty (sparsity), and DAG constraint terms into the final loss function.
"""
B_mask = self.sample_mask()
B = B_mask * self.B_param
likelihood = self.compute_likelihood(B, cov_emp)
sparsity = self.lambda1 * self.compute_sparsity(B_mask)
h = self.compute_h(B_mask)
loss = likelihood + sparsity + 0.5 * rho * h**2
return loss


16 changes: 16 additions & 0 deletions causallearn/utils/CALMUtils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch

def sample_logistic(shape, out=None):
U = out.resize_(shape).uniform_() if out is not None else torch.rand(shape)
return torch.log(U) - torch.log(1-U)


def gumbel_sigmoid(logits, tau=1):
dims = logits.dim()
logistic_noise = sample_logistic(logits.size(), out=logits.data.new())
y = logits + logistic_noise
return torch.sigmoid(y / tau)

def matrix_poly(matrix, d, device):
x = torch.eye(d, device=device, dtype=matrix.dtype)+ torch.div(matrix, d)
return torch.matrix_power(x, d)
Empty file.
60 changes: 60 additions & 0 deletions causallearn/utils/MarkovNetwork/iamb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import causallearn.utils.cit as cit
import numpy as np

def iamb_markov_network(X, alpha=0.05):
n, d = X.shape
markov_network_raw = np.zeros((d, d))
total_num_ci = 0
cond_indep_test = cit.CIT(X, 'fisherz')
# Estimate the markov blanket for each variable
for i in range(d):
markov_blanket, num_ci = iamb(cond_indep_test, d, i, alpha)
total_num_ci += num_ci
if len(markov_blanket) > 0:
markov_network_raw[i, markov_blanket] = 1
markov_network_raw[markov_blanket, i] = 1

# AND rule: (i, j) is an edge in the Markov network
# if and only if i and j are in Markov blanket of each other
# TODO: Check if whether we should use AND rule or OR rule
markov_network = np.logical_and(markov_network_raw, markov_network_raw.T).astype(float)
return markov_network, total_num_ci


def iamb(cond_indep_test, d, target, alpha):
# Modified from: https://github.com/wt-hu/pyCausalFS/blob/master/pyCausalFS/CBD/MBs/IAMB.py
markov_blanket = []
num_ci = 0
# Forward circulate phase
circulate_flag = True
while circulate_flag:
# if not change, forward phase of IAMB is finished.
circulate_flag = False
min_pval = float('inf')
y = None
variables = [i for i in range(d) if i != target and i not in markov_blanket]
for x in variables:
num_ci += 1
pval = cond_indep_test(target, x, markov_blanket)
# Choose maxsize of f(X:T|markov_blanket)
if pval <= alpha:
if pval < min_pval:
min_pval = pval
y = x

# if not condition independence the node,appended to markov_blanket
if y is not None:
markov_blanket.append(y)
circulate_flag = True

# Backward circulate phase
markov_blanket_temp = markov_blanket.copy()
for x in markov_blanket_temp:
# Exclude variable which need test p-value
condition_Variables=[i for i in markov_blanket if i != x]
num_ci += 1
pval = cond_indep_test(target, x, condition_Variables)
if pval > alpha:
markov_blanket.remove(x)

return list(set(markov_blanket)), num_ci

0 comments on commit 08655cd

Please sign in to comment.