function [out1, out2, post] = blr_multi(hyp, X, T, xs)

% Bayesian linear regression (multiple independent targets)
%
% Fits a Bayesian linear regression model, where the inputs are:
%    hyp : vector of hyperparameters. hyp = [log(beta); log(alpha)]
%    X   : N x D data matrix
%    T   : N x Nrep matrix of targets (one column per independent target)
%    xs  : Nte x D matrix of test cases
%
% The hyperparameter beta is the noise precision and alpha is the precision
% of the prior over the regression weights. Alpha can be either a scalar (a
% common precision for all input variables) or a vector of length D (a
% different precision for each input variable, derived using an automatic
% relevance determination formulation).
%
% The main difference between this version and the vanilla version of blr
% is that this version precomputes the quantities that are reused when
% fitting several independent target vectors sharing the same posterior
% covariance (i.e. when T is a matrix). For such cases this is more
% efficient than fitting each column separately.
%
% Two modes are supported:
%    [nlZ, dnlZ, post] = blr_multi(hyp, X, T);     % report evidence and derivatives
%    [mu, s2, post]    = blr_multi(hyp, X, T, xs); % predictive mean and variance
%
% Written by A. Marquand
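%
% Example (an illustrative sketch with synthetic data; the variable sizes
% and values below are arbitrary and not part of this function):
%
%    X   = randn(100, 5);                       % N = 100 cases, D = 5 inputs
%    T   = X*randn(5, 3) + 0.1*randn(100, 3);   % Nrep = 3 independent targets
%    hyp = zeros(6, 1);                         % [log(beta); log(alpha_1..alpha_5)]
%    [nlZ, dnlZ] = blr_multi(hyp, X, T);        % evidence mode
%    xs  = randn(20, 5);                        % Nte = 20 test cases
%    [mu, s2]    = blr_multi(hyp, X, T, xs);    % prediction mode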
28
+
29
+ if nargin < 3 || nargin > 4
30
+ disp(' Usage: [nlZ dnlZ] = blr(hyp, x, y);' )
31
+ disp(' or: [mu s2 ] = blr(hyp, x, y, xs);' )
32
+ return
33
+ end
34
+
35
+ [N ,D ] = size(X );
36
+ Nrep = size(T ,2 );
37
+ beta = exp(hyp(1 )); % noise precision
38
+ alpha = exp(hyp(2 : end )); % weight precisions
39
+ Nalpha = length(alpha );
40
+ if Nalpha ~= 1 && Nalpha ~= D
41
+ error(' hyperparameter vector has invalid length' );
42
+ end
43
+
44
+ if Nalpha == D
45
+ Sigma = diag(1 ./ alpha ); % weight prior covariance
46
+ invSigma = diag(alpha ); % weight prior precision
47
+ else
48
+ Sigma = 1 ./ alpha * eye(D ); % weight prior covariance
49
+ invSigma = alpha * eye(D ); % weight prior precision
50
+ end
51
+ Sigma = sparse(Sigma );
52
+ invSigma = sparse(invSigma );

% invariant quantities that do not need to be recomputed each time
XX   = X'*X;
A    = beta*XX + invSigma;      % posterior precision
S    = inv(A);                  % posterior covariance. Store for speed
Q    = S*X';
% Q  = A\X';
trQX = trace(Q*X);
R    = (eye(D) - beta*Q*X)*Q;
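% Note (added for clarity): with a Gaussian weight prior w ~ N(0, Sigma) and
% noise precision beta, the weight posterior is N(m, S) with precision
% A = invSigma + beta*X'*X, covariance S = inv(A), and mean m = beta*S*X'*t.
% A, S, Q and R depend only on X and the hyperparameters, so they are shared
% across all columns of T and computed only once here.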

% compute like this to avoid numerical overflow
logdetA     = 2*sum(log(diag(chol(A))));
logdetSigma = sum(log(diag(Sigma)));      % assumes Sigma is diagonal
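% (chol(A) returns an upper-triangular factor L with L'*L = A, so
% det(A) = prod(diag(L))^2 and log(det(A)) = 2*sum(log(diag(L))), which
% avoids the overflow that forming det(A) directly can cause.)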

% save posterior precision
post.A = A;

for r = 1:Nrep
    % if mod(r,5) == 0, fprintf('%d ',r); end
    t = T(:,r);                 % targets
    m = beta*Q*t;               % posterior mean
    % save posterior means
    if r == 1, post.M = zeros(length(m), Nrep); end
    post.M(:,r) = m;

    % frequently needed quantities dependent on t and m
    Xt  = X'*t;
    XXm = XX*m;
    SXt = S*Xt;

    if nargin == 3
        if r == 1, NLZ = zeros(Nrep, 1); end

        NLZ(r) = -0.5*( N*log(beta) - N*log(2*pi) - logdetSigma ...
                        - beta*(t - X*m)'*(t - X*m) - m'*invSigma*m - logdetA );
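        % (NLZ(r) is the negative log marginal likelihood, i.e. the negative
        % evidence, for target column r; since the columns are treated as
        % independent, the total evidence reported below is the sum over r.)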

        if nargout > 1 % derivatives?
            if r == 1
                DNLZ = zeros(length(hyp), Nrep);
            end
            b = R*t;

            % noise precision
            DNLZ(1,r) = -( N/(2*beta) - 0.5*(t'*t) + t'*X*m ...
                           + beta*t'*X*b - 0.5*m'*XX*m - beta*b'*XX*m ...
                           - b'*invSigma*m - 0.5*trQX )*beta;
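            % (The trailing *beta here and *alpha(i) below are the chain-rule
            % factors for differentiating with respect to log(beta) and
            % log(alpha), since the hyperparameters are stored as logs.)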

            % variance parameters
            for i = 1:Nalpha
                if Nalpha == D % use ARD?
                    dSigma    = sparse(i, i, -alpha(i)^-2, D, D);
                    dinvSigma = sparse(i, i, 1, D, D);
                else
                    dSigma    = -alpha(i)^-2*eye(D);
                    dinvSigma = eye(D);
                end

                % F = -invSigma*dSigma*invSigma;
                % c = -beta*F*Xt;
                F = dinvSigma;
                c = -beta*S*F*SXt;
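                % (F is the derivative of invSigma with respect to alpha(i),
                % and c is the corresponding change in the posterior mean:
                % dm/dalpha(i) = -beta*S*F*S*X'*t, using dS = -S*F*S.)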

                DNLZ(i+1,r) = -( -0.5*sum(sum(invSigma.*dSigma')) + ...
                                 beta*Xt'*c - beta*c'*XXm - c'*invSigma*m ...
                                 - 0.5*m'*F*m - 0.5*sum(sum(S*F')) ...
                               )*alpha(i);
            end
        end
    else % prediction mode
        if r == 1
            Ys = zeros(size(xs,1), Nrep);
            S2 = zeros(size(xs,1), Nrep);
            s2 = 1/beta + sum((xs*S).*xs, 2); % assumes that xs is constant
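            % (sum((xs*S).*xs, 2) is diag(xs*S*xs') computed row by row, so
            % s2 is the noise variance plus the per-test-point model
            % uncertainty; it depends only on xs and S, hence is computed
            % once and reused for every column of T.)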
        end
        Ys(:,r) = xs*m;
        S2(:,r) = s2;
        % S2(:,r) = 1/beta + diag(xs*(A\xs')); % sloooow
    end
end
% fprintf('\n');

% use this syntax instead of varargout to be able to compile this function
if nargin == 3
    out1 = sum(NLZ);
    if nargout > 1
        out2 = sum(DNLZ, 2);
    else
        out2 = [];
    end
else
    out1 = Ys;
    out2 = S2;
end
end
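
% The local function below is an optional self-test sketch added for
% illustration; it is not part of the original interface. It compares the
% analytic derivatives returned in evidence mode against central finite
% differences on synthetic data (all variables here are made up for the test).
function check_blr_multi_derivatives()
X   = randn(50, 3);                          % 50 cases, 3 inputs
T   = X*randn(3, 2) + 0.05*randn(50, 2);     % 2 independent targets
hyp = randn(4, 1);                           % [log(beta); log(alpha_1..alpha_3)]
[~, dnlZ] = blr_multi(hyp, X, T);            % analytic derivatives
h = 1e-6;
dnlZ_fd = zeros(size(hyp));
for k = 1:length(hyp)
    hp = hyp; hp(k) = hp(k) + h;
    hm = hyp; hm(k) = hm(k) - h;
    dnlZ_fd(k) = (blr_multi(hp, X, T) - blr_multi(hm, X, T)) / (2*h);
end
% if the analytic gradients are implemented correctly, the two columns
% should agree to within finite-difference error
disp([dnlZ, dnlZ_fd]);
end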