Skip to content

Commit 4f6396f

Browse files
taoranngyzhou2000
andauthored
about FatraGNN, including model,datasets and example (#218)
* about FatraGNN, including model,datasets and example * update * Revert "update" This reverts commit fa9101b. * update * Modifications as required * update github action --------- Co-authored-by: Guangyu Zhou <[email protected]> Co-authored-by: gyzhou2000 <[email protected]>
1 parent ce3b426 commit 4f6396f

File tree

10 files changed

+1014
-1
lines changed

10 files changed

+1014
-1
lines changed

.github/workflows/test_push.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ jobs:
3838
run: |
3939
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
4040
41+
- name: Install Tensorflow
42+
run: |
43+
pip install tensorflow==2.11.0
44+
4145
- name: Install llvmlite
4246
run: |
4347
pip install llvmlite

.github/workflows/test_pypi_package.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ jobs:
3030
run: |
3131
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
3232
33+
- name: Install Tensorflow
34+
run: |
35+
pip install tensorflow==2.11.0
36+
3337
- name: Install llvmlite
3438
run: |
3539
pip install llvmlite

examples/fatragnn/config.yaml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
bail:
2+
epochs: 400
3+
g_epochs: 5
4+
a_epochs: 4
5+
cla_epochs: 10
6+
dic_epochs: 8
7+
dtb_epochs: 5
8+
d_lr: 0.001
9+
c_lr: 0.005
10+
e_lr: 0.005
11+
g_lr: 0.05
12+
drope_rate: 0.1
13+
credit:
14+
epochs: 600
15+
g_epochs: 5
16+
a_epochs: 2
17+
cla_epochs: 12
18+
dic_epochs: 5
19+
dtb_epochs: 5
20+
d_lr: 0.001
21+
c_lr: 0.01
22+
e_lr: 0.01
23+
g_lr: 0.05
24+
drope_rate: 0.1

examples/fatragnn/fatragnn_trainer.py

Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,293 @@
1+
import os
2+
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
3+
os.environ['TL_BACKEND'] = 'torch'
4+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
5+
# 0:Output all; 1:Filter out INFO; 2:Filter out INFO and WARNING; 3:Filter out INFO, WARNING, and ERROR
6+
import tensorlayerx as tlx
7+
from gammagl.models import FatraGNNModel
8+
import argparse
9+
import numpy as np
10+
from tensorlayerx.model import TrainOneStep, WithLoss
11+
from sklearn.metrics import roc_auc_score
12+
import scipy.sparse as sp
13+
import yaml
14+
from gammagl.datasets import Bail
15+
from gammagl.datasets import Credit
16+
17+
18+
def fair_metric(pred, labels, sens):
19+
idx_s0 = sens == 0
20+
idx_s1 = sens == 1
21+
idx_s0_y1 = np.bitwise_and(idx_s0, labels == 1)
22+
idx_s1_y1 = np.bitwise_and(idx_s1, labels == 1)
23+
parity = abs(sum(pred[idx_s0]) / sum(idx_s0) -
24+
sum(pred[idx_s1]) / sum(idx_s1))
25+
equality = abs(sum(pred[idx_s0_y1]) / sum(idx_s0_y1) -
26+
sum(pred[idx_s1_y1]) / sum(idx_s1_y1))
27+
return parity.item(), equality.item()
28+
29+
30+
def evaluate_ged3(net, x, edge_index, y, test_mask, sens):
31+
net.set_eval()
32+
flag = 0
33+
output = net(x, edge_index, flag)
34+
pred_test = tlx.cast(tlx.squeeze(output[test_mask], axis=-1) > 0, y.dtype)
35+
36+
acc_nums_test = (pred_test == y[test_mask])
37+
accs = np.sum(tlx.convert_to_numpy(acc_nums_test))/np.sum(tlx.convert_to_numpy(test_mask))
38+
39+
auc_rocs = roc_auc_score(tlx.convert_to_numpy(y[test_mask]), tlx.convert_to_numpy(output[test_mask]))
40+
paritys, equalitys = fair_metric(tlx.convert_to_numpy(pred_test), tlx.convert_to_numpy(y[test_mask]), tlx.convert_to_numpy(sens[test_mask]))
41+
42+
return accs, auc_rocs, paritys, equalitys
43+
44+
45+
class DicLoss(WithLoss):
46+
def __init__(self, net, loss_fn):
47+
super(DicLoss, self).__init__(backbone=net, loss_fn=loss_fn)
48+
49+
def forward(self, data, label):
50+
output = self.backbone_network(data['x'], data['edge_index'], data['flag'])
51+
loss = tlx.losses.binary_cross_entropy(tlx.squeeze(output, axis=-1), tlx.cast(data['sens'], dtype=tlx.float32))
52+
return loss
53+
54+
55+
class EncClaLoss(WithLoss):
56+
def __init__(self, net, loss_fn):
57+
super(EncClaLoss, self).__init__(backbone=net, loss_fn=loss_fn)
58+
59+
def forward(self, data, label):
60+
output = self.backbone_network(data['x'], data['edge_index'], data['flag'])
61+
y_train = tlx.cast(tlx.expand_dims(label[data['train_mask']], axis=1), dtype=tlx.float32)
62+
loss = tlx.losses.binary_cross_entropy(output[data['train_mask']], y_train)
63+
return loss
64+
65+
66+
class EncLoss(WithLoss):
67+
def __init__(self, net, loss_fn):
68+
super(EncLoss, self).__init__(backbone=net, loss_fn=loss_fn)
69+
70+
def forward(self, data, label):
71+
output = self.backbone_network(data['x'], data['edge_index'], data['flag'])
72+
loss = tlx.losses.mean_squared_error(output, 0.5 * tlx.ones_like(output))
73+
return loss
74+
75+
76+
class EdtLoss(WithLoss):
77+
def __init__(self, net, loss_fn):
78+
super(EdtLoss, self).__init__(backbone=net, loss_fn=loss_fn)
79+
80+
def forward(self, data, label):
81+
output = self.backbone_network(data['x'], data['edge_index'], data['flag'])
82+
loss = -tlx.abs(tlx.reduce_sum(output[data['train_mask']][data['t_idx_s0_y1']])) / tlx.reduce_sum(tlx.cast(data['t_idx_s0_y1'], dtype=tlx.float32)) - tlx.reduce_sum(output[data['train_mask']][data['t_idx_s1_y1']]) / tlx.reduce_sum(tlx.cast(data['t_idx_s1_y1'], dtype=tlx.float32))
83+
84+
return loss
85+
86+
87+
class AliLoss(WithLoss):
88+
def __init__(self, net, loss_fn):
89+
super(AliLoss, self).__init__(backbone=net, loss_fn=loss_fn)
90+
91+
def forward(self, data, label):
92+
output = self.backbone_network(data['x'], data['edge_index'], data['flag'])
93+
h1 = output['h1']
94+
h2 = output['h2']
95+
idx_s0_y0 = data['idx_s0_y0']
96+
idx_s1_y0 = data['idx_s1_y0']
97+
idx_s0_y1 = data['idx_s0_y1']
98+
idx_s1_y1 = data['idx_s1_y1']
99+
node_num = data['x'].shape[0]
100+
loss_align = - node_num / (tlx.reduce_sum(tlx.cast(idx_s0_y0, dtype=tlx.float32))) * tlx.reduce_mean(tlx.matmul(h1[idx_s0_y0], tlx.transpose(h2[idx_s0_y0]))) \
101+
- node_num / (tlx.reduce_sum(tlx.cast(idx_s0_y1, dtype=tlx.float32))) * tlx.reduce_mean(tlx.matmul(h1[idx_s0_y1], tlx.transpose(h2[idx_s0_y1]))) \
102+
- node_num / (tlx.reduce_sum(tlx.cast(idx_s1_y0, dtype=tlx.float32))) * tlx.reduce_mean(tlx.matmul(h1[idx_s1_y0], tlx.transpose(h2[idx_s1_y0]))) \
103+
- node_num / (tlx.reduce_sum(tlx.cast(idx_s1_y1, dtype=tlx.float32))) * tlx.reduce_mean(tlx.matmul(h1[idx_s1_y1], tlx.transpose(h2[idx_s1_y1])))
104+
105+
loss = loss_align * 0.01
106+
return loss
107+
108+
109+
def main(args):
110+
111+
# load datasets
112+
if str.lower(args.dataset) not in ['bail', 'credit', 'pokec']:
113+
raise ValueError('Unknown dataset: {}'.format(args.dataset))
114+
115+
if args.dataset == 'bail':
116+
dataset = Bail(args.dataset_path, args.dataset)
117+
118+
elif args.dataset == 'credit':
119+
dataset = Credit(args.dataset_path, args.dataset)
120+
121+
graphs = dataset.data
122+
data = {
123+
'x':graphs[0].x,
124+
'y': graphs[0].y,
125+
'edge_index': {'edge_index': graphs[0].edge_index},
126+
'sens': graphs[0].sens,
127+
'train_mask': graphs[0].train_mask,
128+
}
129+
data_test = []
130+
for i in range(1, len(graphs)):
131+
data_tem = {
132+
'x':graphs[i].x,
133+
'y': graphs[i].y,
134+
'edge_index': graphs[i].edge_index,
135+
'sens': graphs[i].sens,
136+
'test_mask': graphs[i].train_mask | graphs[i].val_mask | graphs[i].test_mask,
137+
}
138+
data_test.append(data_tem)
139+
dataset = None
140+
graphs = None
141+
args.num_features, args.num_classes = data['x'].shape[1], len(np.unique(tlx.convert_to_numpy(data['y']))) - 1
142+
args.test_set_num = len(data_test)
143+
144+
t_idx_s0 = data['sens'][data['train_mask']] == 0
145+
t_idx_s1 = data['sens'][data['train_mask']] == 1
146+
t_idx_s0_y1 = tlx.logical_and(t_idx_s0, data['y'][data['train_mask']] == 1)
147+
t_idx_s1_y1 = tlx.logical_and(t_idx_s1, data['y'][data['train_mask']] == 1)
148+
149+
idx_s0 = data['sens'] == 0
150+
idx_s1 = data['sens'] == 1
151+
idx_s0_y1 = tlx.logical_and(idx_s0, data['y'] == 1)
152+
idx_s1_y1 = tlx.logical_and(idx_s1, data['y'] == 1)
153+
idx_s0_y0 = tlx.logical_and(idx_s0, data['y'] == 0)
154+
idx_s1_y0 = tlx.logical_and(idx_s1, data['y'] == 0)
155+
156+
data['idx_s0_y0'] = idx_s0_y0
157+
data['idx_s1_y0'] = idx_s1_y0
158+
data['idx_s0_y1'] = idx_s0_y1
159+
data['idx_s1_y1'] = idx_s1_y1
160+
data['t_idx_s0_y1'] = t_idx_s0_y1
161+
data['t_idx_s1_y1'] = t_idx_s1_y1
162+
163+
edge_index_np = tlx.convert_to_numpy(data['edge_index']['edge_index'])
164+
adj = sp.coo_matrix((np.ones(data['edge_index']['edge_index'].shape[1]), (edge_index_np[0, :], edge_index_np[1, :])),
165+
shape=(data['x'].shape[0], data['x'].shape[0]),
166+
dtype=np.float32)
167+
A2 = adj.dot(adj)
168+
A2 = A2.toarray()
169+
A2_edge = tlx.convert_to_tensor(np.vstack((A2.nonzero()[0], A2.nonzero()[1])))
170+
171+
net = FatraGNNModel(args)
172+
173+
dic_loss_func = DicLoss(net, tlx.losses.binary_cross_entropy)
174+
enc_cla_loss_func = EncClaLoss(net, tlx.losses.binary_cross_entropy)
175+
enc_loss_func = EncLoss(net, tlx.losses.binary_cross_entropy)
176+
edt_loss_func = EdtLoss(net, tlx.losses.binary_cross_entropy)
177+
ali_loss_func = AliLoss(net, tlx.losses.binary_cross_entropy)
178+
179+
dic_opt = tlx.optimizers.Adam(lr=args.d_lr, weight_decay=args.d_wd)
180+
dic_train_one_step = TrainOneStep(dic_loss_func, dic_opt, net.discriminator.trainable_weights)
181+
182+
enc_cla_opt = tlx.optimizers.Adam(lr=args.c_lr, weight_decay=args.c_wd)
183+
enc_cla_train_one_step = TrainOneStep(enc_cla_loss_func, enc_cla_opt, net.encoder.trainable_weights+net.classifier.trainable_weights)
184+
185+
enc_opt = tlx.optimizers.Adam(lr=args.e_lr, weight_decay=args.e_wd)
186+
enc_train_one_step = TrainOneStep(enc_loss_func, enc_opt, net.encoder.trainable_weights)
187+
188+
edt_opt = tlx.optimizers.Adam(lr=args.g_lr, weight_decay=args.g_wd)
189+
edt_train_one_step = TrainOneStep(edt_loss_func, edt_opt, net.graphEdit.trainable_weights)
190+
191+
ali_opt = tlx.optimizers.Adam(lr=args.e_lr, weight_decay=args.e_wd)
192+
ali_train_one_step = TrainOneStep(ali_loss_func, ali_opt, net.encoder.trainable_weights)
193+
194+
tlx.set_seed(args.seed)
195+
net.set_train()
196+
for epoch in range(0, args.epochs):
197+
print(f"======={epoch}=======")
198+
# train discriminator to recognize the sensitive group
199+
data['flag'] = 1
200+
for epoch_d in range(0, args.dic_epochs):
201+
dic_loss = dic_train_one_step(data=data, label=data['y'])
202+
203+
# train classifier and encoder
204+
data['flag'] = 2
205+
for epoch_c in range(0, args.cla_epochs):
206+
enc_cla_loss = enc_cla_train_one_step(data=data, label=data['y'])
207+
208+
# train encoder to fool discriminator
209+
data['flag'] = 3
210+
for epoch_g in range(0, args.g_epochs):
211+
enc_loss = enc_train_one_step(data=data, label=data['y'])
212+
213+
# train generator
214+
data['flag'] = 4
215+
if epoch > args.start:
216+
if epoch % 10 == 0:
217+
if epoch % 20 == 0:
218+
data['edge_index']['edge_index2'] = net.graphEdit.modify_structure1(data['edge_index']['edge_index'], A2_edge, data['sens'], data['x'].shape[0], args.drope_rate)
219+
else:
220+
data['edge_index']['edge_index2'] = net.graphEdit.modify_structure2(data['edge_index']['edge_index'], A2_edge, data['sens'], data['x'].shape[0], args.drope_rate)
221+
else:
222+
data['edge_index']['edge_index2'] = data['edge_index']['edge_index']
223+
224+
for epoch_g in range(0, args.dtb_epochs):
225+
edt_loss = edt_train_one_step(data=data, label=data['y'])
226+
227+
# shift align
228+
data['flag'] = 5
229+
if epoch > args.start:
230+
for epoch_a in range(0, args.a_epochs):
231+
aliloss = ali_train_one_step(data=data, label=data['y'])
232+
233+
acc = np.zeros([args.test_set_num])
234+
auc_roc = np.zeros([args.test_set_num])
235+
parity = np.zeros([args.test_set_num])
236+
equality = np.zeros([args.test_set_num])
237+
net.set_eval()
238+
for i in range(args.test_set_num):
239+
data_tem = data_test[i]
240+
acc[i],auc_roc[i], parity[i], equality[i] = evaluate_ged3(net, data_tem['x'], data_tem['edge_index'], data_tem['y'], data_tem['test_mask'], data_tem['sens'])
241+
return acc, auc_roc, parity, equality
242+
243+
if __name__ == '__main__':
244+
parser = argparse.ArgumentParser()
245+
parser.add_argument('--dataset', type=str, default='bail')
246+
parser.add_argument('--start', type=int, default=50)
247+
parser.add_argument('--epochs', type=int, default=400)
248+
parser.add_argument('--dic_epochs', type=int, default=5)
249+
parser.add_argument('--dtb_epochs', type=int, default=5)
250+
parser.add_argument('--cla_epochs', type=int, default=12)
251+
parser.add_argument('--a_epochs', type=int, default=2)
252+
parser.add_argument('--g_epochs', type=int, default=5)
253+
parser.add_argument('--g_lr', type=float, default=0.05)
254+
parser.add_argument('--g_wd', type=float, default=0.01)
255+
parser.add_argument('--d_lr', type=float, default=0.001)
256+
parser.add_argument('--d_wd', type=float, default=0)
257+
parser.add_argument('--c_lr', type=float, default=0.001)
258+
parser.add_argument('--c_wd', type=float, default=0.01)
259+
parser.add_argument('--e_lr', type=float, default=0.005)
260+
parser.add_argument('--e_wd', type=float, default=0)
261+
parser.add_argument('--hidden', type=int, default=128)
262+
parser.add_argument('--seed', type=int, default=3)
263+
parser.add_argument('--top_k', type=int, default=10)
264+
parser.add_argument('--gpu', type=int, default=1)
265+
parser.add_argument('--drope_rate', type=float, default=0.1)
266+
parser.add_argument("--dataset_path", type=str, default=r'', help="path to save dataset")
267+
268+
args = parser.parse_args()
269+
270+
if args.gpu >= 0:
271+
tlx.set_device("GPU", args.gpu)
272+
else:
273+
tlx.set_device("CPU")
274+
args.device = f'cuda:{args.gpu}'
275+
276+
277+
fileNamePath = os.path.split(os.path.realpath(__file__))[0]
278+
yamlPath = os.path.join(fileNamePath, 'config.yaml')
279+
with open(yamlPath, 'r', encoding='utf-8') as f:
280+
cont = f.read()
281+
config_dict = yaml.safe_load(cont)[args.dataset]
282+
for key, value in config_dict.items():
283+
args.__setattr__(key, value)
284+
285+
print(args)
286+
acc, auc_roc, parity, equality = main(args)
287+
288+
for i in range(args.test_set_num):
289+
print("===========test{}============".format(i+1))
290+
print('Acc: ', acc.T[i])
291+
print('auc_roc: ', auc_roc.T[i])
292+
print('parity: ', parity.T[i])
293+
print('equality: ', equality.T[i])

0 commit comments

Comments
 (0)