-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathexample_main_admm.m
128 lines (108 loc) · 3.59 KB
/
example_main_admm.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
%
% Clustering experiment on COIL20_PCA.mat using dvbgmm_admm (method tag
% mtd = 'admm'). For each cluster count K, random K-class subsets of the
% data are clustered repeatedly. Accuracy (AC) and NMI are logged to
% _res_admm_1.txt, and the result arrays are saved to _admm_result_1.mat.
%
% create time: 2015/4/30
% last update: 2015/4/30
%
close all;
clear; clc;
% fullfile inserts the platform's file separator. The original '.\...'
% literals only resolve as intended on Windows. Paths stay relative to the
% current folder, as before.
addpath(fullfile('.', 'VBfunctions'));
addpath(fullfile('.', 'boundedline'));
load('COIL20_PCA.mat')
% Variables expected from the .mat file (used below):
%   fea - one sample per row (rows are selected by sample index)
%   gnd - class label of each sample, assumed to be a column vector
% NOTE(review): an earlier comment listed the sizes as fea 9298 x 256 and
% gnd 9298 x 1. Those do not look like COIL20 dimensions; confirm.
%gnd = gnd - 1;
% Stack 10 copies of the data set.
% Presumably this enlarges the per-node sample counts; TODO confirm intent.
fea = repmat(fea,10,1);
gnd = repmat(gnd,10,1);
load Network2.mat   % must define Network (used by splitdata_func / dvbgmm_admm)
num_avg = 10;  % random class subsets drawn per K (original note: 30)
repeat = 10;   % dvbgmm_admm runs per subset; the best AC/NMI is kept (note: 10)
rand('state',sum(1000*clock));  % time-based seed: results differ between runs
mtd = 'admm';
% The third dimension (and the vector length) of 9 holds K = 2..10 at index K-1.
AC_result = zeros(num_avg,repeat,9);
NMI_result = zeros(num_avg,repeat,9);
AVG_AC = zeros(9,1);
AVG_NMI = zeros(9,1);
[fid, fopen_msg] = fopen(['_res_' mtd '_1.txt'],'wt');
if fid < 0
    % fopen returns -1 on failure. Every later fprintf(fid, ...) would fail.
    error('Cannot open log file _res_%s_1.txt: %s', mtd, fopen_msg);
end
% Main experiment loop.
% For each cluster count K, draw num_avg random K-class subsets of the data.
% For each subset, run dvbgmm_admm `repeat` times, each with its own RNG
% seed, and track the best accuracy (AC) and NMI seen so far. Per-run bests
% go to AC_result / NMI_result(t, tt, K-1). The mean of the per-subset bests
% goes to AVG_AC / AVG_NMI(K-1).
num_class = 20;  % subsets are drawn from class labels 1..num_class
max_try = 10;    % dvbgmm_admm attempts per run before the subset is redrawn
rho = 16;        % passed unchanged to every dvbgmm_admm call
for K = 6:10
    tic;
    avgAC = 0;
    avgNMI = 0;
    for t = 1:num_avg
        restart = 1;
        % Redraw the class subset until all `repeat` runs finish with
        % flag ~= 1 within max_try attempts each.
        % NOTE(review): the number of redraws is not capped. If dvbgmm_admm
        % keeps returning flag == 1, this loop never terminates.
        while restart
            restart = 0;
            clusts = randperm(num_class, K);
            fprintf('Digit number:');
            fprintf(fid,'Digit number:');
            % fprintf recycles the format for each element: "a b c ".
            fprintf('%d ', clusts - 1);
            fprintf(fid, '%d ', clusts - 1);
            fprintf('\n\n');
            fprintf(fid,'\n\n');
            oldAC = 0;   % best AC over the repeats for this subset
            oldNMI = 0;  % best NMI over the repeats for this subset
            % Keep the samples of the chosen classes in their original order.
            % Relabel class clusts(i) as i. This matches the per-class find
            % followed by a sort, without growing arrays in a loop.
            % Assumes gnd is a column vector.
            [in_subset, cls_pos] = ismember(gnd, clusts);
            re_idx = find(in_subset);
            re_gnd = cls_pos(in_subset);
            re_fea = fea(re_idx,:);
            [ re_fea, re_gnd, NodeSample, GroundTruth ] = splitdata_func( Network, re_fea, re_gnd ,K);
            % Count samples AFTER splitdata_func. It returns the re_gnd that
            % the accuracy below is measured against.
            nsample = numel(re_gnd);
            seed_off = floor(10000*rand(repeat,1));
            for tt = 1:repeat
                rand('state',sum(1000*clock)+seed_off(tt));
                flag = 1;
                cnt = 0;
                % Retry while dvbgmm_admm reports flag == 1. Give up on this
                % subset after max_try attempts.
                while flag == 1
                    cnt = cnt + 1;
                    if cnt > max_try
                        fprintf('cnt = %d\n',cnt);
                        restart = 1;
                        break;
                    end
                    [MixModel,flag] = dvbgmm_admm(Network, NodeSample,K,GroundTruth,rho);
                end
                if restart == 1
                    break;  % abandon this subset; the while loop redraws it
                end
                label = MixModel.Label;
                new_label = label_map(label, re_gnd);
                % Fraction of samples whose mapped label equals the ground
                % truth. Compare as columns so a row/column mismatch from
                % label_map cannot broadcast into a matrix.
                AC = nnz(new_label(:) == re_gnd(:)) / nsample;
                NMI = MutualInfo(re_gnd,new_label);
                oldAC = max(oldAC, AC);
                oldNMI = max(oldNMI, NMI);
                fprintf([mtd,' %d AC: %f, MI: %f\n'],tt, oldAC, oldNMI);
                fprintf(fid,[mtd,' %d AC: %f, MI: %f\n'],tt, oldAC, oldNMI);
                AC_result(t,tt,K-1) = oldAC;
                NMI_result(t,tt,K-1) = oldNMI;
            end
        end
        fprintf([mtd, ': K = %d t = %d AC = %f, NMI = %f\n\n'], K,t, oldAC,oldNMI);
        fprintf(fid, [mtd, ': K = %d t = %d AC = %f, NMI = %f\n\n'], K,t, oldAC,oldNMI);
        avgAC = avgAC + oldAC;
        avgNMI = avgNMI + oldNMI;
    end
    AVG_AC(K-1) = avgAC/num_avg;
    AVG_NMI(K-1) = avgNMI/num_avg;
    fprintf([mtd, ' avgAC = %f, avgMI = %f\n\n'],AVG_AC(K-1),AVG_NMI(K-1));
    fprintf(fid,[mtd, ' avgAC = %f, avgMI = %f\n\n'],AVG_AC(K-1),AVG_NMI(K-1));
    toc;
end
% Summary: print the average AC / NMI for each K, then close the log and
% save all result arrays.
% AVG_AC / AVG_NMI have slots for K = 2..10. The experiment loop above runs
% only K = 6:10, so the other slots keep their zero initialization. Skip them
% instead of reporting them as genuine 0.000000 scores.
for K = 2:10
    if AVG_AC(K-1) == 0 && AVG_NMI(K-1) == 0
        continue;
    end
    fprintf([mtd, ' K = %d avgAC = %f, avgMI = %f\n'], K, AVG_AC(K-1), AVG_NMI(K-1));
    fprintf(fid, [mtd, ' K = %d avgAC = %f, avgMI = %f\n'], K, AVG_AC(K-1), AVG_NMI(K-1));
end
% Close the log first so it is flushed even if save() fails.
fclose(fid);
save(['_', mtd, '_result_1.mat'], 'AC_result', 'NMI_result', 'AVG_AC', 'AVG_NMI')