Skip to content

Commit 31b131d

Browse files
committed
make iterative factorization test and add relative error convergence criterion
1 parent 4b32629 commit 31b131d

File tree

7 files changed

+272
-30
lines changed

7 files changed

+272
-30
lines changed

docs/src/MatrixTensorFactor.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ combined_norm
7171
dist_to_Ncone
7272
rel_error
7373
mean_rel_error
74-
residual
74+
relative_error
7575
slicewise_dot
7676
```
7777

src/MatrixTensorFactor.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using KernelDensity
1212
using Base: *
1313

1414
export Abstract3Tensor # Types
15-
export combined_norm, dist_to_Ncone, nnmtf, rel_error, mean_rel_error, residual, slicewise_dot # Functions
15+
export combined_norm, dist_to_Ncone, nnmtf, rel_error, mean_rel_error, relative_error, slicewise_dot # Functions
1616
export d_dx, d2_dx2, curvature, standard_curvature # Approximations
1717
export nnmtf_proxgrad_online
1818

src/matrixtensorfactorize.jl

Lines changed: 66 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,12 @@ const IMPLIMENTED_PROJECTIONS = Set{Symbol}((:nnscale, :simplex, :nonnegative))
3131
3232
- `:ncone`: vector-set distance between the -gradient of the objective and the normal cone
3333
- `:iterates`: A,B before and after one iteration are close in L2 norm
34-
- `:objective`: objective before and after one iteration is close
34+
- `:objective`: objective is small
35+
- `:relativeerror`: relative error is small (when `normalize=:nothing`) or
36+
mean relative error averaging fibres or slices when the normalization is `:fibres` or
37+
`:slices` respectively.
3538
"""
36-
const IMPLIMENTED_CRITERIA = Set{Symbol}((:ncone, :iterates, :objective))
39+
const IMPLIMENTED_CRITERIA = Set{Symbol}((:ncone, :iterates, :objective, :relativeerror))
3740

3841
"""
3942
IMPLIMENTED_STEPSIZES::Set{Symbol}
@@ -66,17 +69,21 @@ const IMPLIMENTED_OPTIONS = Dict(
6669
)
6770

6871
@doc raw"""
69-
nnmtf(Y::Abstract3Tensor, R::Integer; kwargs...)
72+
nnmtf(Y::AbstractArray, R::Integer; kwargs...)
7073
71-
Non-negatively matrix-tensor factorizes an order 3 tensor Y with a given "rank" R.
74+
Non-negatively matrix-tensor factorizes an order N tensor Y with a given "rank" R.
7275
73-
Factorizes ``Y \approx A B`` where ``\displaystyle Y[i,j,k] \approx \sum_{r=1}^R A[i,r]*B[r,j,k]``
76+
For an order ``N=3`` tensor, this factorizes ``Y \approx A B`` where
77+
``\displaystyle Y[i,j,k] \approx \sum_{r=1}^R A[i,r]*B[r,j,k]``
7478
and the factors ``A, B \geq 0`` are nonnegative.
7579
80+
For higher orders, this becomes
81+
``\displaystyle Y[i1,i2,...,iN] \approx \sum_{r=1}^R A[i1,r]*B[r,i2,...,iN].``
82+
7683
Note there may NOT be a unique optimal solution
7784
7885
# Arguments
79-
- `Y::Abstract3Tensor`: tensor to factorize
86+
- `Y::AbstractArray{T,N}`: tensor to factorize
8087
- `R::Integer`: rank to factorize Y (size(A)[2] and size(B)[1])
8188
8289
# Keywords
@@ -89,14 +96,16 @@ Note there may NOT be a unique optimal solution
8996
- `criterion::Symbol=:ncone`: how to determine if the algorithm has converged (must be in IMPLIMENTED_CRITERIA)
9097
- `stepsize::Symbol=:lipshitz`: used for the gradient descent step (must be in IMPLIMENTED_STEPSIZES)
9198
- `momentum::Bool=false`: use momentum updates
92-
- `delta::Real=0.9999`: safeguard for maximum amount of momentum (see eq 3.5 Xu & Yin 2013)
99+
- `delta::Real=0.9999`: safeguard for maximum amount of momentum (see eq (3.5) Xu & Yin 2013)
93100
- `R_max::Integer=size(Y)[1]`: maximum rank to try if R is not given
94101
- `projectionA::Symbol=projection`: projection to use on factor A (must be in IMPLIMENTED_PROJECTIONS)
95102
- `projectionB::Symbol=projection`: projection to use on factor B (must be in IMPLIMENTED_PROJECTIONS)
103+
- `A_init::Union{Nothing, AbstractMatrix}=nothing`: initial A for the iterative algorithm. Should be kept as `nothing` if `R` is not given.
104+
- `B_init::Union{Nothing, AbstractArray}=nothing`: initial B for the iterative algorithm. Should be kept as `nothing` if `R` is not given.
96105
97106
# Returns
98107
- `A::Matrix{Float64}`: the matrix A in the factorization Y ≈ A * B
99-
- `B::Array{Float64, 3}`: the tensor B in the factorization Y ≈ A * B
108+
- `B::Array{Float64, N}`: the tensor B in the factorization Y ≈ A * B
100109
- `rel_errors::Vector{Float64}`: relative errors at each iteration
101110
- `norm_grad::Vector{Float64}`: norm of the full gradient at each iteration
102111
- `dist_Ncone::Vector{Float64}`: distance of the -gradient to the normal cone at each iteration
@@ -224,6 +233,8 @@ function _nnmtf_proxgrad(
224233
rescale_Y::Bool = (projection == :nnscale ? true : false),
225234
projectionA::Symbol = projection,
226235
projectionB::Symbol = projection,
236+
A_init::Union{Nothing, AbstractMatrix}=nothing,
237+
B_init::Union{Nothing, AbstractArray}=nothing,
227238
)
228239
# Override scaling if no normalization is requested
229240
normalize == :nothing ? (rescale_AB = rescale_Y = false) : nothing
@@ -232,11 +243,25 @@ function _nnmtf_proxgrad(
232243
M, Ns... = size(Y)
233244

234245
# Initialize A, B
235-
init(x...) = abs.(randn(x...))
236-
A = init(M, R)
237-
B = init(R, Ns...)
246+
if A_init === nothing
247+
A = _init(M, R)
248+
else
249+
size(A_init) == (M, R) || throw(ArgumentError("A_init should have size $((M, R)), got $(size(A_init))"))
250+
A = A_init
251+
end
238252

239-
rescaleAB!(A, B; normalize)
253+
# Initialize B: use the caller-supplied `B_init` when given, otherwise fall back
# to the default random initialization. The check must be on `B_init` (not
# `A_init`); otherwise a provided `B_init` is ignored when `A_init` is absent,
# and `size(nothing)` is evaluated when only `A_init` is provided.
if B_init === nothing
    B = _init(R, Ns...)
else
    size(B_init) == (R, Ns...) || throw(ArgumentError("B_init should have size $((R, Ns...)), got $(size(B_init))"))
    B = B_init
end
259+
260+
# Only want to rescale the initialization if both A and B were not given
261+
# Otherwise, we should use the provided initialization
262+
if rescale_AB && A_init === nothing && B_init === nothing
263+
rescaleAB!(A, B; normalize)
264+
end
240265

241266
problem_size = R*(M + prod(Ns))
242267

@@ -254,7 +279,7 @@ function _nnmtf_proxgrad(
254279

255280
# Calculate initial relative error and gradient
256281
Yhat = A*B
257-
rel_errors[i] = residual(Yhat, Y; normalize)
282+
rel_errors[i] = relative_error(Yhat, Y; normalize)
258283
grad_A, grad_B = calc_gradient(A, B, Y)
259284
norm_grad[i] = combined_norm(grad_A, grad_B)
260285
dist_Ncone[i] = dist_to_Ncone(grad_A, grad_B, A, B)
@@ -318,15 +343,15 @@ function _nnmtf_proxgrad(
318343
# Calculate relative error and norm of gradient
319344
i += 1
320345
Yhat .= A*B
321-
rel_errors[i] = residual(Yhat, Y; normalize)
346+
rel_errors[i] = relative_error(Yhat, Y; normalize)
322347
# grad_A, grad_B = calc_gradient(A, B, Y)
323348
grad_A .= calc_gradientA(A, B, Y)
324349
grad_B .= calc_gradientB(A, B, Y)
325350
norm_grad[i] = combined_norm(grad_A, grad_B)
326351
# norm_grad[i] = combined_norm(grad_A, grad_B)
327352
dist_Ncone[i] = dist_to_Ncone(grad_A, grad_B, A, B)
328353

329-
if converged(; dist_Ncone, i, A, B, A_last, B_last, tol, problem_size, criterion, Y)
354+
if converged(; dist_Ncone, i, A, B, A_last, B_last, tol, problem_size, criterion, Y, Yhat, normalize)
330355
break
331356
end
332357

@@ -367,6 +392,11 @@ function _nnmtf_proxgrad(
367392
return A, B, rel_errors, norm_grad, dist_Ncone
368393
end
369394

395+
"""
396+
Default initialization
397+
"""
398+
function _init(dims...)
    # Draw standard-normal samples with the requested dimensions, then take
    # elementwise absolute values so every entry is nonnegative.
    samples = randn(dims...)
    return abs.(samples)
end
399+
370400
"""
371401
Convergence criteria function.
372402
@@ -376,16 +406,26 @@ independent of the dimentions of Y and rank R.
376406
Note the use of `;` in the function definition so that order of arguments does not matter,
377407
and keyword assignment can be ignored if the input variables are named exactly as below.
378408
"""
379-
function converged(; dist_Ncone, i, A, B, A_last, B_last, tol, problem_size, criterion, Y)
409+
function converged(; dist_Ncone, i, A, B, A_last, B_last, tol, problem_size, criterion, Y, Yhat, normalize)
    # Reject unknown criteria by raising, not by returning the error object:
    # callers use the result directly in an `if`, which requires a Bool.
    if !(criterion in IMPLIMENTED_CRITERIA)
        throw(UnimplimentedError("criterion is not an impliment criterion"))
    end

    # Value compared against `tol`; stays 0.0 if no branch below matches.
    criterion_value = 0.0

    if criterion == :ncone
        # Entry `i` of `dist_Ncone`, scaled by the root of the problem size.
        criterion_value = dist_Ncone[i]/sqrt(problem_size) #TODO remove root problem size dependence
    elseif criterion == :iterates
        # Combined norm of the change in (A, B) from the previous iterates.
        criterion_value = combined_norm(A - A_last, B - B_last)
    elseif criterion == :objective
        # Least-squares objective, using the already-computed Yhat.
        criterion_value = 0.5 * norm(Yhat - Y)^2
    elseif criterion == :relativeerror
        # Relative error between Yhat and Y, computed according to `normalize`.
        criterion_value = relative_error(Yhat, Y; normalize)
    end

    return criterion_value < tol
end
390430

391431
"""
@@ -405,7 +445,7 @@ function to_dims(normalize::Symbol)
405445
end
406446

407447
"""
408-
residual(Yhat, Y; normalize=:nothing)
448+
relative_error(Yhat, Y; normalize=:nothing)
409449
410450
Wrapper to use the relative error calculation according to the normalization used.
411451
@@ -415,7 +455,7 @@ Wrapper to use the relative error calculation according to the normalization use
415455
416456
See also [`rel_error`](@ref), [`mean_rel_error`](@ref).
417457
"""
418-
function residual(Yhat, Y; normalize=:nothing)
458+
function relative_error(Yhat, Y; normalize=:nothing)
419459
if normalize in (:fibres, :slices)
420460
return mean_rel_error(Yhat, Y; dims=to_dims(normalize))
421461
elseif normalize == :nothing
@@ -655,7 +695,8 @@ function nnmtf_proxgrad_online(
655695
dist_Ncone = zeros(maxiter)
656696

657697
# Calculate initial relative error and gradient
658-
rel_errors[i] = residual(A*B, Y; normalize)
698+
Yhat = A*B
699+
rel_errors[i] = relative_error(Yhat, Y; normalize)
659700
grad_A, grad_B = calc_gradient(A, B, Y)
660701
norm_grad[i] = combined_norm(grad_A, grad_B)
661702
dist_Ncone[i] = dist_to_Ncone(grad_A, grad_B, A, B)
@@ -719,12 +760,13 @@ function nnmtf_proxgrad_online(
719760

720761
# Calculate relative error and norm of gradient
721762
i += 1
722-
rel_errors[i] = residual(A*B, Y; normalize)
763+
Yhat .= A*B
764+
rel_errors[i] = relative_error(Yhat, Y; normalize)
723765
grad_A, grad_B = calc_gradient(A, B, Y)
724766
norm_grad[i] = combined_norm(grad_A, grad_B)
725767
dist_Ncone[i] = dist_to_Ncone(grad_A, grad_B, A, B)
726768

727-
if converged(; dist_Ncone, i, A, B, A_last, B_last, tol, problem_size, criterion, Y)
769+
if converged(; dist_Ncone, i, A, B, A_last, B_last, tol, problem_size, criterion, Y, Yhat, normalize)
728770
break
729771
end
730772

0 commit comments

Comments
 (0)