Skip to content

Commit 3f60ea8

Browse files
committed
added MTL scripts
1 parent cc0d694 commit 3f60ea8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+5557
-0
lines changed

mtl_scripts/Copy_of_compute_roc.m

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
wdir = '/cns_zfs/mlearn/public_datasets/openfmri/posner/gp_da/';
2+
3+
4+
fnames{1} = 'erfnoopt_ValidHitObj_InvalidHitObj';
5+
fnames{2} = 'erf_ValidHitObj_InvalidHitObj';
6+
7+
fnames{3} = 'erfnoopt_ValidCue_ValidObj';
8+
fnames{4} = 'erf_ValidCue_ValidObj';
9+
10+
fnames{5} = 'erfnoopt_HCHObj_MissObj';
11+
fnames{6} = 'erf_HCHObj_MissObj';
12+
13+
suffix = {'_ValidHitObj_InvalidHitObj','_ValidCue_ValidObj','_HCHObj_MissObj'};
14+
15+
fh = {};
16+
for c= 1:length(fnames)/2
17+
fh{c} = figure;
18+
19+
nsub = 18;
20+
AUCno = [];
21+
TP = {}; FP = {};
22+
for s= 1:nsub
23+
%load([wdir,'single_subject/erfnoopt_sub',num2str(s),'_ValidCue_ValidObj']);
24+
%load([wdir,'single_subject/erfnoopt_sub',num2str(s),'_ValidHitObj_InvalidHitObj']);
25+
%load([wdir,'single_subject/erfnoopt_sub',num2str(s),'_HCHObj_MissObj']);
26+
27+
load([wdir,'single_subject/erfnoopt_sub',num2str(s),suffix{c}]);
28+
29+
[A, tp, fp] = roc(y,yhat);
30+
TP{s} = tp;
31+
FP{s} = fp;
32+
%plot(fp,tp,'k--','Linewidth',1);
33+
AUCno = [AUCno A];
34+
end
35+
36+
fpi = 0:0.01:1; FPi = []; TPi = [];
37+
for s = 1:nsub
38+
fp = FP{s};
39+
tp = TP{s};
40+
%plot(fp,tp,'k--','Linewidth',1); hold on
41+
42+
[fp, id] = unique(fp);
43+
tp = tp(id);
44+
%tpi = spline(fp,tp,fpi);
45+
tpi = interp1(fp,tp,fpi);
46+
%plot(fpi,tpi,'r--','Linewidth',1);
47+
%clf
48+
49+
FPi = [FPi; fpi];
50+
TPi = [TPi; tpi];
51+
end
52+
P = prctile(TPi,[25 50 75],1);
53+
%fill([fpi'; flipdim(fpi',1)], [P(1,:)'; flipdim(P(3,:)',1)], [7 7 7]/8);%, 'EdgeColor', [7 7 7]/8);
54+
hold on;
55+
%plot(fpi,P(2,:),'k','Linewidth',2);
56+
plot(fpi,mean(TPi),'k','Linewidth',2);
57+
end
58+
59+
%for f = 1:length(fnames)
60+
f = 1;
61+
for c= 1:length(fnames)/2
62+
figure(fh{c});
63+
64+
fnames{f}
65+
load([wdir,fnames{f}])
66+
[A, tp, fp] = roc(y,yhat);
67+
plot(fp,tp,'b','Linewidth',2); hold on
68+
69+
fnames{f+1}
70+
load([wdir,fnames{f+1}])
71+
[A, tp, fp] = roc(y,yhat);
72+
plot(fp,tp,'r','Linewidth',2);
73+
74+
chance=(1:size(y))/length(y);
75+
plot(chance,chance,'k--');
76+
xlabel('False Positive Rate');
77+
ylabel('True Positive Rate');
78+
79+
f= f+2;
80+
end
81+
% end
82+

mtl_scripts/Copy_of_gp_da_run.m

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
function [Acc] = gp_da_run(c1,c2,cov,output_name,opt)
2+
3+
% %clsnum = [1,2];
4+
% c1 = 1; c2 =2;
5+
% cov = 'lin';
6+
% output_name = ['../gp_da/t2ml_',mat2str(c1),'_v_',num2str(mat2str(c2)),'_YY0'];
7+
% opt.maxEval = 500;
8+
% opt.optimiseTheta = false;
9+
10+
addpath('/home/kkvi0203/svmdata/PD_MSA_PSP/prt/prt_mcode/mtl_clean');
11+
12+
% defaults
13+
try opt.optimiseTheta; catch, opt.optimiseTheta = true; end
14+
try opt.computeWeights; catch, opt.computeWeights = false; end
15+
try opt.normalizeK; catch, opt.normalizeK = false; end
16+
try opt.maxEval; catch, opt.maxEval = 500; end
17+
18+
[Xa,Ya,ID,~,classes] = load_data;
19+
20+
% get rid of fMRI tasks we are not considering
21+
Ya = [sum(Ya(:,c1),2), sum(Ya(:,c2),2)]; %Ya = Ya(:,clsnum);
22+
id = sum(Ya,2) > 0;
23+
Ya = Ya(id,:);
24+
Xa = Xa(id,:);
25+
ID = ID(id,:);
26+
27+
[X,y,ID,Y,Ys] = process_tasks(Xa,Ya,ID);
28+
[N, T] = size(Y);
29+
Nclassifiers = max(ID(:,1));
30+
Nfolds = length(unique(ID(:,2)));
31+
32+
[~, y_kfid] = find(Ys);
33+
y_kxid = 1:N;
34+
35+
% starting hyperparamers
36+
Kf00 = eye(T);
37+
kf00 = Kf00(tril(ones(T)) ~= 0);
38+
39+
% options
40+
switch cov
41+
case 'se'
42+
error('Squared exponential not adjusted yet');
43+
opt.CovFunc = 'covfunc_mtr_se'; kx0 = [log(1); log(1000); log(0.1*ones(T,1))];
44+
kx0 = [log(1); log(1); log(0.1*ones(T,1))];
45+
lh00 = [kf00; kx0];
46+
otherwise % linear
47+
opt.CovFunc = 'covfunc_mtr_nonblock'; kx0 = log(0.1*ones(T,1));
48+
%opt.CovFunc = 'covfunc_mtr'; kx0 = log(ones(T,1));
49+
lh00 = [kf00; kx0];
50+
end
51+
52+
% ----------------Main cross-validation loop----------------------------
53+
trall = zeros(N,Nfolds);
54+
teall = zeros(N,Nfolds);
55+
yhattr = cell(Nfolds,1);
56+
yhatte = cell(Nfolds,1);
57+
s2te = cell(Nfolds,1);
58+
Alpha = zeros(Nfolds,T,N);
59+
Hyp = zeros(Nfolds,length(lh00));
60+
%matlabpool('open');
61+
%par
62+
for f = 1:Nfolds
63+
fprintf('Outer loop %d of %d ...\n',f,Nfolds)
64+
optf = opt;
65+
66+
sid = ID(:,2) == f;
67+
if sum(sid) == 0, error(['No tasks found for fmri run ', num2str(r)]); end
68+
te = find(sid);
69+
tr = find(~sid);
70+
teall(:,f) = sid;
71+
trall(:,f) = ~sid;
72+
73+
% training mean
74+
Ystr = Ys(tr,:);
75+
%Ytr = Y(tr,:);
76+
%mtr = mean(Y(tr,:)); Mtr = repmat(mtr,length(tr),1); Ytr = Ytr - Mtr;
77+
78+
%fprintf('Standardising features ...\n ')
79+
Xz = (X - repmat(mean(X(tr,:)),N,1)) ./ repmat(std(X(tr,:)),N,1);
80+
Xz = Xz(:,logical(sum(isfinite(Xz))));
81+
Phi = Xz*Xz';
82+
if strcmp(opt.CovFunc,'covfunc_mtr_se') || opt.normalizeK
83+
disp('Normalizing kernel ...');
84+
Phi = prt_normalise_kernel(Phi);
85+
end
86+
87+
if opt.optimiseTheta
88+
% set initial hyperparameter values
89+
%Kf0 = eye(T);
90+
Kf0 = pinv(Ystr)*(y(tr)*y(tr)' ./ Phi(tr,tr))*pinv(Ystr)'; optf.Psi_prior = Kf0;
91+
kf0 = Kf0(tril(ones(T)) ~= 0);
92+
lh0 = [kf0; kx0];
93+
94+
[hyp nlmls] = minimize(lh0, @gp_mtr, opt.maxEval, {Phi(tr,tr), Ystr}, {y(tr)}, opt);
95+
%[hyp nlmls] = minimize(lh0, @gp_mtr_map, opt.maxEval, {Phi(tr,tr), Ytr}, Ytr, optf);
96+
else
97+
Kf0 = eye(T);
98+
kf0 = Kf0(tril(ones(T)) ~= 0);
99+
hyp = [kf0; kx0];
100+
end
101+
102+
C = feval(opt.CovFunc,{Phi, Ys},hyp);
103+
K = C(tr,tr);
104+
Ks = C(te,tr);
105+
kss = C(te,te);
106+
107+
[yhatte{f}, s2f] = gp_pred_mtr(K,Ks,kss,{y(tr)});
108+
yhattr{f} = gp_pred_mtr(K,K,K,{y(tr)});
109+
110+
%yhatte{te} = yhattte{f}';% + mtr;
111+
s2te{f} = diag(s2f);
112+
Hyp(f,:) = hyp';
113+
114+
% weights
115+
if opt.computeWeights
116+
disp('Computing weights ...');
117+
mask = '/cns_zfs/mlearn/public_datasets/openfmri/posner/masks/SPM_mask_46x55x39.img';
118+
[~,~,alpha] = gp_mtr(hyp, {Phi(tr,tr), Ystr}, {y(tr)}, opt);
119+
nvox = size(X,2);
120+
121+
Wm = alpha'*Xz(tr,:);
122+
Wmn = Wm ./ norm(Wm); % for visualisation only
123+
prt_write_nii(Wmn,mask,[output_name,'_W_fold_',num2str(f,'%02.0f'),'_meantask.img']);
124+
125+
W = zeros(T,nvox);
126+
for t = 1:T
127+
xid = find(ID(:,1) == t & ID(:,2) ~= f);
128+
aid = find(ID(tr,1) == t);
129+
130+
W(t,:) = alpha(aid)'*Xz(xid,:);
131+
132+
wn = W(t,:) ./ norm(W(t,:));
133+
prt_write_nii(wn,mask,[output_name,'_W_fold_',num2str(f,'%02.0f'),'_task',num2str(t),'.img']);
134+
end
135+
end
136+
137+
fprintf('Outer loop %d of %d done.\n',f,Nfolds)
138+
end
139+
matlabpool('close')
140+
141+
% reconstruct predictions
142+
yhat = zeros(N,1);
143+
s2 = zeros(N,1);
144+
for f = 1:Nfolds
145+
te = find(teall(:,f));
146+
yhat(te) = yhatte{f};
147+
s2(te) = s2te{f};
148+
end
149+
150+
% Reconstruct chol(Kx)' and Kf
151+
lmaxi = (T*(T+1)/2);
152+
Noise = zeros(Nfolds,T);
153+
Kf = zeros(T,T,Nfolds);
154+
for f = 1:Nfolds
155+
Lf = zeros(T);
156+
lf = Hyp(f,1:lmaxi)';
157+
id = tril(true(T));
158+
Lf(id) = lf;
159+
Kf(:,:,f) = Lf*Lf';
160+
Noise(f,:) = exp(Hyp(f,end-T+1:end));
161+
end
162+
163+
% compute accuracy
164+
Acc = zeros(Nclassifiers,1);
165+
Acc05 = zeros(Nclassifiers,1);
166+
for c = 1:Nclassifiers
167+
%c
168+
clsid = find(ID(:,4) == c);
169+
trlab = y(clsid) ~= 0;
170+
171+
Yf = [y(clsid) y(clsid(end:-1:1))];
172+
173+
prlab = opt_score(Yf,Yf,yhat(clsid));
174+
prlab05 = yhat(clsid) > 0.5;
175+
176+
Acc(c) = sum(trlab == prlab) ./ length(trlab);
177+
Acc05(c) = sum(trlab == prlab05) ./ length(trlab);
178+
end
179+
fprintf('Mean accuracy (OS): %02.2f\n',mean(Acc))
180+
fprintf('Mean accuracy (LR): %02.2f\n',mean(Acc05))
181+
182+
save(output_name ,'y','yhat','Acc','Hyp','Noise','Kf','Alpha','ID')
183+
184+
% check gradients
185+
% % fun = @(lh)gp_mtr(lh,X,Y,opt);
186+
% % [~,g] = gp_mtr(lh0,X,Y,opt);
187+
% fun = @(lh)gp_mtr(lh,{Phi, Y},Y,opt);
188+
% [~,g] =gp_mtr(lh0,{Phi, Y},Y,opt);
189+
% gnum = computeNumericalGradient(fun,lh0);

0 commit comments

Comments
 (0)