-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbaseline_mdi.py
59 lines (44 loc) · 1.75 KB
/
baseline_mdi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import time
import argparse
import sys
import os
import os.path as osp
import numpy as np
import torch
import pandas as pd
from uci.uci_subparser import add_uci_subparser
from training.baseline import baseline_mdi
def _parse_args():
    """Build the command-line parser, parse ``sys.argv`` and return ``(parser, args)``.

    The top-level options select the baseline method and the corruption
    settings. Dataset options come from the subcommand registered by
    ``add_uci_subparser``. The parser is returned too, so the caller can
    report usage errors via ``parser.error``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--method', type=str, default='mean')
    parser.add_argument('--level', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--log_dir', type=str, default='0')
    parser.add_argument('--masking_distribution', type=str, default='uniform')
    parser.add_argument('--corrupt', type=str, default="mcar")
    parser.add_argument('--mar_rate_obs', type=float, default=0.1)
    parser.add_argument('--mar_rate_missing', type=float, default=0.15)
    parser.add_argument('--mnar_known_mask', type=float, default=0.1)
    subparsers = parser.add_subparsers()
    add_uci_subparser(subparsers)
    return parser, parser.parse_args()


def _seed_everything(seed):
    """Seed the global NumPy and PyTorch random number generators with ``seed``."""
    np.random.seed(seed)
    torch.manual_seed(seed)


def main():
    """Run ``baseline_mdi`` for five trial seeds and print the mean/std of its results.

    Parses the command line and loads the dataset chosen by the subcommand.
    Only the ``uci`` domain is handled; anything else is rejected as a usage
    error. It then calls ``baseline_mdi`` once per trial seed 0..4 and prints
    the mean and standard deviation of the returned values, which are
    labelled as MAE in the output.
    """
    parser, args = _parse_args()
    # Seed before loading data. Whether load_data draws random numbers
    # (e.g. for the corruption mask) is not visible here; TODO confirm.
    _seed_everything(args.seed)

    # The subcommand is optional in argparse, so ``domain`` may be absent.
    # Previously a missing or non-'uci' domain left ``data`` unbound and the
    # script crashed later with an AttributeError/UnboundLocalError.
    domain = getattr(args, 'domain', None)
    if domain != 'uci':
        parser.error(
            "a dataset subcommand is required; only the 'uci' domain is "
            "supported (got {!r})".format(domain))
    from uci.uci_data import load_data
    data = load_data(args)

    log_path = './{}/test/{}/{}_{}/'.format(args.domain, args.data, args.method, args.log_dir)
    # The i-th HPO setting of the corresponding method has the best performance.
    best_levels = {'mean': 0, 'knn': 3, 'svd': 2, 'mice': 2, 'spectral': 1}
    # NOTE(review): this always overrides the --level CLI value; methods
    # without an entry get None. Kept as in the original implementation.
    args.level = best_levels.get(args.method)

    res = []
    for seed in range(5):
        # Each trial reseeds so that runs are reproducible independently.
        _seed_everything(seed)
        mae = baseline_mdi(data, args, log_path)
        res.append(mae)
    print(f'Mean MAE (x 10):{np.mean(res)}, std: {np.std(res)}')
# Only launch the experiment when run as a script, not when imported.
if __name__ == '__main__':
    main()