function [Acc] = gp_da_run(c1, c2, cov, output_name, opt)
% GP_DA_RUN  Multi-task GP classification with k-fold cross-validation.
%
% Contrasts two groups of task columns of the label matrix using a
% multi-task Gaussian process with either a linear or (unfinished)
% squared-exponential covariance, running one outer CV loop per fold
% encoded in ID(:,2).
%
% Inputs:
%   c1, c2      - column indices into Ya selecting the two classes to contrast
%   cov         - covariance type: 'se' (squared exponential; currently
%                 errors out as not adjusted yet), otherwise linear
%   output_name - stem for the saved .mat results and weight-image files
%   opt         - options struct; missing fields receive defaults below
%                 (optimiseTheta=true, computeWeights=false,
%                  normalizeK=false, maxEval=500)
%
% Output:
%   Acc         - per-classifier accuracy using optimal-score thresholding
%
% Example settings:
%   c1 = 1; c2 = 2;
%   cov = 'lin';
%   output_name = ['../gp_da/t2ml_',mat2str(c1),'_v_',num2str(mat2str(c2)),'_YY0'];
%   opt.maxEval = 500;
%   opt.optimiseTheta = false;

addpath('/home/kkvi0203/svmdata/PD_MSA_PSP/prt/prt_mcode/mtl_clean');

% defaults: the try/catch idiom probes whether each field already exists
try opt.optimiseTheta;  catch, opt.optimiseTheta  = true;  end
try opt.computeWeights; catch, opt.computeWeights = false; end
try opt.normalizeK;     catch, opt.normalizeK     = false; end
try opt.maxEval;        catch, opt.maxEval        = 500;   end

[Xa, Ya, ID, ~, classes] = load_data;

% get rid of fMRI tasks we are not considering: collapse the selected
% columns into a two-column label matrix and drop unlabeled rows
Ya = [sum(Ya(:,c1),2), sum(Ya(:,c2),2)]; % Ya = Ya(:,clsnum);
id = sum(Ya,2) > 0;
Ya = Ya(id,:);
Xa = Xa(id,:);
ID = ID(id,:);

[X, y, ID, Y, Ys] = process_tasks(Xa, Ya, ID);
[N, T] = size(Y);
Nclassifiers = max(ID(:,1));
Nfolds = length(unique(ID(:,2)));

[~, y_kfid] = find(Ys); % NOTE(review): y_kfid/y_kxid are computed but unused below
y_kxid = 1:N;

% starting hyperparameters: identity task covariance, vectorised as the
% lower triangle of its Cholesky-style factor
Kf00 = eye(T);
kf00 = Kf00(tril(ones(T)) ~= 0);

% covariance function and initial log-hyperparameters
switch cov
    case 'se'
        error('Squared exponential not adjusted yet');
        % NOTE(review): the assignments below are unreachable while the
        % error above is in place
        opt.CovFunc = 'covfunc_mtr_se'; kx0 = [log(1); log(1000); log(0.1*ones(T,1))];
        kx0 = [log(1); log(1); log(0.1*ones(T,1))];
        lh00 = [kf00; kx0];
    otherwise % linear
        opt.CovFunc = 'covfunc_mtr_nonblock'; kx0 = log(0.1*ones(T,1));
        % opt.CovFunc = 'covfunc_mtr'; kx0 = log(ones(T,1));
        lh00 = [kf00; kx0];
end

% ---------------- Main cross-validation loop ----------------------------
trall  = zeros(N,Nfolds);
teall  = zeros(N,Nfolds);
yhattr = cell(Nfolds,1);
yhatte = cell(Nfolds,1);
s2te   = cell(Nfolds,1);
Alpha  = zeros(Nfolds,T,N); % NOTE(review): preallocated and saved but never filled in
Hyp    = zeros(Nfolds,length(lh00));
% matlabpool('open');
% par
for f = 1:Nfolds
    fprintf('Outer loop %d of %d ...\n',f,Nfolds)
    optf = opt; % NOTE(review): optf gets Psi_prior below but plain opt is
                % what is passed to minimize - confirm which was intended

    sid = ID(:,2) == f;
    % FIX: message previously used undefined variable 'r'; report fold f
    if sum(sid) == 0, error(['No tasks found for fmri run ', num2str(f)]); end
    te = find(sid);
    tr = find(~sid);
    teall(:,f) = sid;
    trall(:,f) = ~sid;

    % training mean
    Ystr = Ys(tr,:);
    % Ytr = Y(tr,:);
    % mtr = mean(Y(tr,:)); Mtr = repmat(mtr,length(tr),1); Ytr = Ytr - Mtr;

    % fprintf('Standardising features ...\n')
    % z-score all rows using training-fold statistics, then drop columns
    % that became non-finite (zero training variance)
    Xz = (X - repmat(mean(X(tr,:)),N,1)) ./ repmat(std(X(tr,:)),N,1);
    Xz = Xz(:,logical(sum(isfinite(Xz))));
    Phi = Xz*Xz';
    if strcmp(opt.CovFunc,'covfunc_mtr_se') || opt.normalizeK
        disp('Normalizing kernel ...');
        Phi = prt_normalise_kernel(Phi);
    end

    if opt.optimiseTheta
        % set initial hyperparameter values from a data-driven task
        % covariance estimate
        % Kf0 = eye(T);
        Kf0 = pinv(Ystr)*(y(tr)*y(tr)' ./ Phi(tr,tr))*pinv(Ystr)'; optf.Psi_prior = Kf0;
        kf0 = Kf0(tril(ones(T)) ~= 0);
        lh0 = [kf0; kx0];

        [hyp nlmls] = minimize(lh0, @gp_mtr, opt.maxEval, {Phi(tr,tr), Ystr}, {y(tr)}, opt);
        % [hyp nlmls] = minimize(lh0, @gp_mtr_map, opt.maxEval, {Phi(tr,tr), Ytr}, Ytr, optf);
    else
        Kf0 = eye(T);
        kf0 = Kf0(tril(ones(T)) ~= 0);
        hyp = [kf0; kx0];
    end

    % full covariance, then train/test partitions of it
    C = feval(opt.CovFunc,{Phi, Ys},hyp);
    K   = C(tr,tr);
    Ks  = C(te,tr);
    kss = C(te,te);

    [yhatte{f}, s2f] = gp_pred_mtr(K,Ks,kss,{y(tr)});
    yhattr{f} = gp_pred_mtr(K,K,K,{y(tr)});

    % yhatte{te} = yhattte{f}';% + mtr;
    s2te{f} = diag(s2f);
    Hyp(f,:) = hyp';

    % weights
    if opt.computeWeights
        disp('Computing weights ...');
        mask = '/cns_zfs/mlearn/public_datasets/openfmri/posner/masks/SPM_mask_46x55x39.img';
        [~,~,alpha] = gp_mtr(hyp, {Phi(tr,tr), Ystr}, {y(tr)}, opt);
        nvox = size(X,2);

        % mean-task weight map (normalised for visualisation only)
        Wm  = alpha'*Xz(tr,:);
        Wmn = Wm ./ norm(Wm);
        prt_write_nii(Wmn,mask,[output_name,'_W_fold_',num2str(f,'%02.0f'),'_meantask.img']);

        % per-task weight maps
        W = zeros(T,nvox);
        for t = 1:T
            xid = find(ID(:,1) == t & ID(:,2) ~= f);
            aid = find(ID(tr,1) == t);

            W(t,:) = alpha(aid)'*Xz(xid,:);

            wn = W(t,:) ./ norm(W(t,:));
            prt_write_nii(wn,mask,[output_name,'_W_fold_',num2str(f,'%02.0f'),'_task',num2str(t),'.img']);
        end
    end

    fprintf('Outer loop %d of %d done.\n',f,Nfolds)
end
% FIX: matlabpool('open') above is commented out, so an unconditional
% close would error; only close when a pool is actually open
if matlabpool('size') > 0, matlabpool('close'); end

% reconstruct predictions in original row order from the per-fold cells
yhat = zeros(N,1);
s2   = zeros(N,1);
for f = 1:Nfolds
    te = find(teall(:,f));
    yhat(te) = yhatte{f};
    s2(te)   = s2te{f};
end

% Reconstruct chol(Kx)' and Kf: the first T*(T+1)/2 hyperparameters are
% the lower-triangular factor of the task covariance, the last T are
% log noise variances
lmaxi = (T*(T+1)/2);
Noise = zeros(Nfolds,T);
Kf    = zeros(T,T,Nfolds);
for f = 1:Nfolds
    Lf = zeros(T);
    lf = Hyp(f,1:lmaxi)';
    id = tril(true(T));
    Lf(id) = lf;
    Kf(:,:,f) = Lf*Lf';
    Noise(f,:) = exp(Hyp(f,end-T+1:end));
end

% compute accuracy per classifier (ID column 4)
Acc   = zeros(Nclassifiers,1);
Acc05 = zeros(Nclassifiers,1);
for c = 1:Nclassifiers
    % c
    clsid = find(ID(:,4) == c);
    trlab = y(clsid) ~= 0;

    % two-column target: labels and their reverse, as opt_score expects
    Yf = [y(clsid) y(clsid(end:-1:1))];

    prlab   = opt_score(Yf,Yf,yhat(clsid));
    prlab05 = yhat(clsid) > 0.5;

    Acc(c)   = sum(trlab == prlab)   ./ length(trlab);
    Acc05(c) = sum(trlab == prlab05) ./ length(trlab);
end
fprintf('Mean accuracy (OS): %02.2f\n',mean(Acc))
fprintf('Mean accuracy (LR): %02.2f\n',mean(Acc05))

save(output_name,'y','yhat','Acc','Hyp','Noise','Kf','Alpha','ID')

% check gradients
% % fun = @(lh)gp_mtr(lh,X,Y,opt);
% % [~,g] = gp_mtr(lh0,X,Y,opt);
% fun = @(lh)gp_mtr(lh,{Phi, Y},Y,opt);
% [~,g] =gp_mtr(lh0,{Phi, Y},Y,opt);
% gnum = computeNumericalGradient(fun,lh0);