forked from lttam/SobolevTransport
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcompute_OT_GroundGraphMetric.m
76 lines (58 loc) · 1.6 KB
/
compute_OT_GroundGraphMetric.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
%
% compute optimal transport distance matrix (with ground graph metric)
%
% Choose:
% (1) typeGG = 'RandLLE' (G_Log) or typeGG = 'RandSLE' (G_Sqrt)
%
% The graph file <dsName>_<maxKC>_<typeGG>_Graph.mat is loaded straight into
% the workspace. The rest of this script reads GG, nGG, N, WW, XX_ID and YY
% from it.
% NOTE(review): the exact contents of that .mat file are not visible here - confirm.
%
% CLEAR (not CLEAR ALL): only workspace variables need resetting. CLEAR ALL
% also unloads functions and MEX files such as mexEMD, which is documented
% to decrease performance.
clear
clc
typeGG = 'RandLLE'; % log-linear #edges (G_Log)
% typeGG = 'RandSLE'; % sqrt-linear #edges (G_Sqrt)
dsName = 'twitter';
maxKC = 100;
load([dsName '_' num2str(maxKC) '_' typeGG '_Graph.mat']);
% Ground metric: GM(i,j) is the shortest-path distance in graph GG between
% nodes i and j, for nodes 1..nGG (Inf when no path exists).
% A single DISTANCES call replaces nGG-1 separate per-source
% SHORTESTPATHTREE searches. Only the strict upper triangle (paths from the
% lower-indexed node) is kept and mirrored. This matches the original
% row-by-row construction, so GM is symmetric with a zero diagonal even if
% GG were directed.
disp('compute the ground graph metric');
tic
SPD = distances(GG, 1:nGG, 1:nGG);
GM = triu(SPD, 1) + triu(SPD, 1)';
runTime_GroundGM = toc;
% histogram
% Row ii of XX_ID_vec holds sample ii's weights WW{ii}, normalized to sum
% to 1 and placed at the node (column) indices XX_ID{ii}.
% ACCUMARRAY sums weights that share a node index. Plain indexed assignment
% would keep only the last duplicate and leave the row summing to less than
% 1. With distinct indices the result is identical.
% NOTE(review): a sample whose weights sum to 0 produces a NaN row (0/0).
XX_ID_vec = zeros(N, nGG);
tic
for ii = 1:N
    tmpWW = WW{ii}/sum(WW{ii}); % normalization
    tmpXX_ID = XX_ID{ii};
    XX_ID_vec(ii, :) = accumarray(tmpXX_ID(:), tmpWW(:), [nGG 1])';
end
runTime_Hist = toc;
tic
% compute the OT
% DD_OT(a,b) is the transport cost returned by mexEMD between histograms a
% and b (rows of XX_ID_vec), using GM as the ground cost.
% Each pair is restricted to the nodes where the summed histograms are
% positive, which keeps the cost matrix passed to the solver small.
% Only pairs a<b are solved and then mirrored. The diagonal stays zero.
DD_OT = zeros(N, N);
for rowA = 1:(N-1)
    if mod(rowA, 20) == 0
        disp(['...' num2str(rowA)]); % progress every 20 rows
    end
    histA = XX_ID_vec(rowA, :);
    for rowB = (rowA+1):N
        histB = XX_ID_vec(rowB, :);
        support = find(histA + histB > 0);
        pairCost = mexEMD(histA(support)', histB(support)', GM(support, support));
        DD_OT(rowA, rowB) = pairCost;
        DD_OT(rowB, rowA) = pairCost;
    end
end
runTime_Dist = toc;
% Total elapsed (tic/toc) time over the three timed phases.
runTime_Dist_ALL = runTime_Dist + runTime_GroundGM + runTime_Hist;
% Write the distance matrix, timings and YY (presumably sample labels from
% the graph file - confirm) to <dsName>_OT_<maxKC>_<typeGG>.mat.
outName = [dsName, '_OT_', num2str(maxKC), '_', typeGG, '.mat'];
savedVars = {'DD_OT', ...
             'runTime_Dist', 'runTime_GroundGM', 'runTime_Hist', ...
             'runTime_Dist_ALL', ...
             'YY'};
save(outName, savedVars{:});
disp('FINISH !!!');