Skip to content

Commit 48fcd03

Browse files
author
amarquand
committed
Major revision
- fixed bug in computation of posterior (m was not computed properly) - uses sparse matrices where appropriate (for speed) - revised blr code for efficiency and numerical stability - added blr_multi for cases where there are multiple independent samples from the posterior
1 parent 08cd25c commit 48fcd03

File tree

3 files changed

+185
-30
lines changed

3 files changed

+185
-30
lines changed

blr.m

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,44 +35,58 @@
3535
end
3636

3737
if Nalpha == D
38-
Sigma = diag(1./alpha); % weight prior covariance
39-
iSigma = diag(alpha); % weight prior precision
38+
Sigma = spdiags(1./alpha,0,D,D);
39+
invSigma = spdiags(alpha,0,D,D);
4040
else
41-
Sigma = 1./alpha*eye(D); % weight prior covariance
42-
iSigma = alpha*eye(D); % weight prior precision
41+
Sigma = 1./alpha*speye(D); % weight prior covariance
42+
invSigma = alpha*speye(D); % weight prior precision
4343
end
4444

45+
% useful quantities
4546
XX = X'*X;
46-
A = beta*XX + iSigma; % posterior precision
47-
Q = A\X';
48-
m = beta*Q*t; % posterior mean
47+
A = beta*XX + invSigma; % posterior precision
48+
S = inv(A); % posterior covariance
49+
Q = S*X';
50+
m = beta*Q*t; % posterior mean
51+
52+
% compute like this to avoid numerical overflow
53+
logdetA = 2*sum(log(diag(chol(A))));
54+
logdetSigma = sum(log(diag(A))); % assumes Sigma is diagonal
4955

5056
if nargin == 3
51-
nlZ = -0.5*( N*log(beta) - N*log(2*pi) - log(det(Sigma)) ...
52-
- beta*(t-X*m)'*(t-X*m) - m'*iSigma*m - log(det(A)) );
57+
nlZ = -0.5*( N*log(beta) - N*log(2*pi) - logdetSigma ...
58+
- beta*(t-X*m)'*(t-X*m) - m'*invSigma*m - ...
59+
logdetA );
5360

5461
if nargout > 1 % derivatives?
5562
dnlZ = zeros(size(hyp));
5663
b = (eye(D) - beta*Q*X)*Q*t;
5764

58-
% noise precision
59-
dnlZ(1) = -( N/(2*beta) - 0.5*(t'*t) + t'*X*m + beta*t'*X*b - 0.5*m'*XX*m ...
60-
- beta*b'*XX*m - b'*iSigma*m -0.5*trace(Q*X) )*beta;
65+
% repeatedly computed quantities for derivatives
66+
Xt = X'*t;
67+
XXm = XX*m;
68+
SXt = S*Xt;
6169

70+
% noise precision
71+
dnlZ(1) = -( N/(2*beta) - 0.5*(t'*t) + t'*X*m + beta*t'*X*b - 0.5*m'*XXm ...
72+
- beta*b'*XXm - b'*invSigma*m -0.5*trace(Q*X) )*beta;
73+
6274
% variance parameters
6375
for i = 1:Nalpha
6476
if Nalpha == D % use ARD?
65-
dSigma = zeros(D);
66-
dSigma(i,i) = -alpha(i)^-2; % if alpha is the precision
77+
dSigma = sparse(i,i,-alpha(i)^-2,D,D);
78+
dinvSigma = sparse(i,i,1,D,D);
6779
else
68-
dSigma = -alpha(i)^-2*eye(D);
80+
dSigma = -alpha(i)^-2*speye(D);
81+
dinvSigma = speye(D);
6982
end
7083

71-
F = -iSigma*dSigma*iSigma;
72-
c = -beta*F*X'*t;
84+
F = dinvSigma;
85+
c = -beta*S*F*SXt;
7386

74-
dnlZ(i+1) = -( -0.5*trace(iSigma*dSigma) + beta*t'*X*c - beta*c'*XX*m ...
75-
- c'*iSigma*m - 0.5*m'*F*m - 0.5*trace(A\F) )*alpha(i);
87+
dnlZ(i+1) = -( -0.5*sum(sum(invSigma.*dSigma')) + ...
88+
beta*Xt'*c - beta*c'*XXm - c'*invSigma*m - ...
89+
0.5*m'*F*m - 0.5*trace(A\F) )*alpha(i);
7690
end
7791
post.m = m;
7892
post.A = A;

blr_multi.m

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
function [out1, out2, post] = blr_multi(hyp, X, T, xs)

% Bayesian linear regression (multiple independent targets)
%
% Fits a Bayesian linear regression model, where the inputs are:
%    hyp : vector of hyperparameters. hyp = [log(beta); log(alpha)]
%    X   : N x D data matrix
%    T   : N x Nrep matrix of targets (one column per independent sample)
%    xs  : Nte x D matrix of test cases
%
% The hyperparameter beta is the noise precision and alpha is the precision
% over lengthscale parameters. This can be either a scalar variable (a
% common lengthscale for all input variables), or a vector of length D (a
% different lengthscale for each input variable, derived using an automatic
% relevance determination formulation).
%
% The main difference between this version and the vanilla version of blr
% is that this version precomputes lots of quantities that are used
% repeatedly for computing i.i.d. samples with the same posterior covariance
% (i.e. when T is a matrix). For such cases this is more efficient than
% computing each column separately.
%
% Two modes are supported:
%    [nlZ, dnlZ, post] = blr_multi(hyp, x, y);     % report evidence and derivatives
%    [mu, s2, post]    = blr_multi(hyp, x, y, xs); % predictive mean and variance
%
% Written by A. Marquand

if nargin < 3 || nargin > 4
    disp('Usage: [nlZ dnlZ] = blr(hyp, x, y);')
    disp('   or: [mu s2  ] = blr(hyp, x, y, xs);')
    return
end

[N, D] = size(X);
Nrep   = size(T,2);
beta   = exp(hyp(1));      % noise precision
alpha  = exp(hyp(2:end));  % weight precisions
Nalpha = length(alpha);
if Nalpha ~= 1 && Nalpha ~= D
    error('hyperparameter vector has invalid length');
end

if Nalpha == D
    Sigma    = diag(1./alpha);    % weight prior covariance
    invSigma = diag(alpha);       % weight prior precision
else
    Sigma    = 1./alpha*eye(D);   % weight prior covariance
    invSigma = alpha*eye(D);      % weight prior precision
end
Sigma    = sparse(Sigma);
invSigma = sparse(invSigma);

% invariant quantities that do not need to be recomputed for each target
XX   = X'*X;
A    = beta*XX + invSigma;    % posterior precision
S    = inv(A);                % posterior covariance. Store for speed
Q    = S*X';
trQX = trace(Q*X);
R    = (eye(D) - beta*Q*X)*Q; % maps t -> b, used in the beta derivative

% compute log-determinants like this to avoid numerical overflow.
% NOTE: the evidence needs log|Sigma| (the prior covariance), not log|A|;
% using diag(A) here was a bug, since log|A| is already held in logdetA.
logdetA     = 2*sum(log(diag(chol(A))));
logdetSigma = sum(log(diag(Sigma)));  % assumes Sigma is diagonal

% save posterior precision (shared by all targets)
post.A = A;

for r = 1:Nrep
    t = T(:,r);       % targets for this sample
    m = beta*Q*t;     % posterior mean
    % save posterior means
    if r == 1, post.M = zeros(length(m), Nrep); end
    post.M(:,r) = m;

    % frequently needed quantities dependent on t and m
    Xt  = X'*t;
    XXm = XX*m;
    SXt = S*Xt;

    if nargin == 3
        if r == 1, NLZ = zeros(Nrep,1); end

        % negative log marginal likelihood for this target column
        NLZ(r) = -0.5*( N*log(beta) - N*log(2*pi) - logdetSigma ...
                 - beta*(t-X*m)'*(t-X*m) - m'*invSigma*m - logdetA );

        if nargout > 1 % derivatives?
            if r == 1
                DNLZ = zeros(length(hyp), Nrep);
            end
            b = R*t;

            % noise precision (chain rule through beta = exp(hyp(1)))
            DNLZ(1,r) = -( N/(2*beta) - 0.5*(t'*t) + t'*X*m ...
                        + beta*t'*X*b - 0.5*m'*XXm - beta*b'*XXm ...
                        - b'*invSigma*m - 0.5*trQX )*beta;

            % variance parameters
            for i = 1:Nalpha
                if Nalpha == D % use ARD?
                    dSigma    = sparse(i,i,-alpha(i)^-2,D,D);
                    dinvSigma = sparse(i,i,1,D,D);
                else
                    dSigma    = -alpha(i)^-2*eye(D);
                    dinvSigma = eye(D);
                end

                % F = d(invSigma)/d(alpha_i) = -invSigma*dSigma*invSigma;
                % c = dm/d(alpha_i) = -beta*S*F*S*X'*t
                F = dinvSigma;
                c = -beta*S*F*SXt;

                DNLZ(i+1,r) = -( -0.5*sum(sum(invSigma.*dSigma')) + ...
                              beta*Xt'*c - beta*c'*XXm - c'*invSigma*m ...
                              - 0.5*m'*F*m - 0.5*sum(sum(S*F')) ...
                              )*alpha(i);
            end
        end
    else % prediction mode
        if r == 1
            Ys = zeros(size(xs,1),Nrep);
            S2 = zeros(size(xs,1),Nrep);
            % predictive variance; identical for every column since xs
            % and the posterior covariance S are fixed across targets
            s2 = 1/beta + sum((xs*S).*xs,2);
        end
        Ys(:,r) = xs*m;
        S2(:,r) = s2;
    end
end

% use this syntax instead of varargout to be able to compile this function
if nargin == 3
    out1 = sum(NLZ);
    if nargout > 1
        out2 = sum(DNLZ,2);
    else
        out2 = [];
    end
else
    out1 = Ys;
    out2 = S2;
end
end

sp_blr_cluster_job.m

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,35 +33,30 @@
3333

3434
if opt.type2ml
3535
try
36-
%[hyp,nlml] = minimize(hyp, @gp, opt.maxEval, opt.inf, opt.mean, opt.cov, opt.lik, X, y);
3736
[hyp,nlml] = minimize(zeros(D+1,1), @blr, opt.maxEval, X, y);
3837

39-
% check gradients
40-
fun = @(lh)blr(lh,X,y);
41-
[~,g] = blr(zeros(D+1,1),X,y);
42-
gnum = computeNumericalGradient(fun,zeros(D+1,1));
38+
% % check gradients
39+
% fun = @(lh)blr(lh,X,y);
40+
% [~,g] = blr(zeros(D+1,1),X,y);
41+
% gnum = computeNumericalGradient(fun,zeros(D+1,1));
4342
catch
4443
warning('Optimisation failed. Using default values');
4544
end
4645
end
4746
if nargin > 4
48-
%[yhat, s2] = gp(hyp,opt.inf,opt.mean,opt.cov,opt.lik, X, y, Xs, zeros(Ns,1));
4947
[yhat, s2] = blr(hyp, X, y, Xs);
5048

5149
Yhat(:,t) = yhat;
5250
S2(:,t) = s2;
5351
if nargout > 5
54-
%[yhattr, s2tr] = gp(hyp,opt.inf,opt.mean,cov,opt.lik, X, y, X, zeros(N,1));
5552
[yhattr, s2tr] = blr(hyp, X, y, X);
5653
Yhattr(:,t) = yhattr;
5754
S2tr(:,t) = s2tr;
5855
end
5956
else % just report marginal likelihood and derivatives
60-
%[nlml, dnlml] = gp(hyp,opt.inf,opt.mean,opt.cov,opt.lik, X, y);
61-
%DNLML(:,t) = unwrap(dnlml);
62-
[nlml,DNLML(:,t)] = blr(hyp, X, y);
57+
[nlml,DNLML(:,t)] = blr(hyp, X, y);
6358
end
6459

6560
NLML(t) = min(nlml);
66-
Hyp(t,:) = hyp';%unwrap(hyp)';
61+
Hyp(t,:) = hyp';
6762
end

0 commit comments

Comments
 (0)