-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmatmul_test.m
157 lines (139 loc) · 4.88 KB
/
matmul_test.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
% matmul_test.m Simulation of custom variants of Model 1 in [1].
%
% Requirements:
% CPFloat (https://github.com/north-numerical-computing/cpfloat/).
%
% References:
% [1] T. Mary and M. Mikaitis.
% Error Analysis of Matrix Multiplication with Narrow Range
% Floating-Point Arithmetic. hal-04671474. Aug. 2024.
% Main experiment script.
%
% For a range of inner dimensions n, random matrices A (m x n) and
% B (n x q) are scaled by power-of-two diagonal matrices, split into p words
% of the input format, multiplied with accumulation in the accumulation
% format, and scaled back. The resulting error is compared with a bound,
% both with the exponent range limits enabled and with them disabled
% ("nrl"). The results are written to one .dat file per subnormal setting.
%
% NOTE(review): input_format, accum_format and p are read below but are not
% assigned anywhere in the visible code; presumably the caller defines them
% in the workspace before running this script -- confirm.

% Set up the input format.
options_input.format = input_format;
[~, options] = cpfloat([], options_input);

% Grab various parameters of the format.
t = options.params(1);
emin = options.params(2);
emax = options.params(3);

% Set up some useful quantities used in the paper.
u = 2^-t;
fmin = 2^emin;

% Special treatment for the maximum value of fp8-e4m3, as per the
% OFP8 format specification.
if (strcmp(input_format, 'fp8-e4m3'))
    fmax = 2^emax*(2-4*u);
else
    fmax = 2^emax*(2-2*u);
end

% Set up the accumulation format.
options_accum.format = accum_format;
[~, options] = cpfloat([], options_accum);

% Grab various parameters of the format.
T = options.params(1);
Emin = options.params(2);
Emax = options.params(3);

% Set up some useful quantities used in the paper.
U = 2^(-T);
Fmin = 2^Emin;
Fmax = 2^Emax*(2-2*U);

% Matrix dimensions: A is m x n, B is n x q
m = 10;
q = 10;
nlist = floor(logspace(1,6,40));

% Matrix elements: uniformly distributed logarithms in [-l, l].
l = 10;

for subnormals_on = 0:1
    options_input.subnormal = subnormals_on;
    options_accum.subnormal = subnormals_on;

    % Preallocate the result vectors so that they do not grow inside the
    % loop and cannot carry stale entries from an earlier run of the script.
    err       = zeros(1, length(nlist));
    err_nrl   = zeros(1, length(nlist));
    bound     = zeros(1, length(nlist));
    bound_nrl = zeros(1, length(nlist));

    i = 0;
    for n = nlist
        i = i+1;

        % Matrix elements: uniformly distributed logarithms in [-l, l].
        A = (10.^(rand(m,n)*2*l-l));
        B = (10.^(rand(n,q)*2*l-l));

        % Random sign + or - with equal probability
        sA = randi(2,m,n)*2-3;
        sB = randi(2,n,q)*2-3;
        A = A.*sA;
        B = B.*sB;

        % Compute a reference result in binary64.
        Ctrue = A*B;

        % Compute diagonal scaling matrices L and M such that
        % the elements of L*A and B*M are at most theta.
        % L is stored as a column (one factor per row of A) and M as a row
        % (one factor per column of B); they are applied with elementwise
        % products that broadcast across A and B.
        theta = min(fmax, sqrt(Fmax/n));
        L = previous_pow2(theta./max(abs(A),[],2));
        M = previous_pow2(theta./max(abs(B)));
        Linv = 1./L;
        Minv = 1./M;

        % Round L*A and B*M to the input format.
        % Word j holds the residual left by the previous words, rescaled by
        % 1/u^(j-1) before rounding.
        temp1 = L.*A;
        temp2 = B.*M;
        LA{1} = cpfloat(temp1, options_input);
        BM{1} = cpfloat(temp2, options_input);
        for j = 2:p
            temp1 = temp1 - LA{j-1}*u^(j-2);
            temp2 = temp2 - BM{j-1}*u^(j-2);
            LA{j} = cpfloat(temp1/u^(j-1), options_input);
            BM{j} = cpfloat(temp2/u^(j-1), options_input);
        end

        % Compute LA*BM in the accumulation format.
        LABM = matmul(LA, BM, options_accum, p, u);

        % Scale LABM back to obtain C.
        C = Linv.*LABM.*Minv;

        % Compute the error
        err(i) = norm(C-Ctrue,'inf')/norm(A,'inf')/norm(B,'inf');

        % Compute the product as if we had no range limitations.
        options_input.explim = 0;
        options_accum.explim = 0;
        temp1 = L.*A;
        temp2 = B.*M;
        LA_nrl{1} = cpfloat(temp1, options_input);
        BM_nrl{1} = cpfloat(temp2, options_input);
        for j = 2:p
            temp1 = temp1 - LA_nrl{j-1}*u^(j-2);
            temp2 = temp2 - BM_nrl{j-1}*u^(j-2);
            LA_nrl{j} = cpfloat(temp1/u^(j-1), options_input);
            BM_nrl{j} = cpfloat(temp2/u^(j-1), options_input);
        end
        LABM_nrl = matmul(LA_nrl, BM_nrl, options_accum, p, u);
        C_nrl = Linv.*LABM_nrl.*Minv;
        err_nrl(i) =...
            norm(C_nrl-Ctrue,'inf')/norm(A,'inf')/norm(B,'inf');

        % Re-enable the exponent range limits for the next iteration.
        options_input.explim = 1;
        options_accum.explim = 1;

        % Bound (3.26)
        bound(i) = 2*u + n*U + 4*n^2*fmin/theta + 4*n^2*Fmin/theta^2;
        bound_nrl(i) = 2*u + n*U;

        % Same bound but without dependency on n, which is quite
        % pessimistic.
        %bound(i) = 2*u + U + 4*fmin/theta + 4*Fmin/theta^2;
        %bound_nrl(i) = 2*u + U;
    end

    % Output various results to .dat files.
    filename = strcat('./data/matmul_test_', input_format,...
        '_', accum_format, '_subnormals',...
        num2str(subnormals_on), '_words_', num2str(p), '.dat');
    [fileID, fopen_msg] = fopen(filename, 'w');
    if fileID == -1
        % Fail with a clear message (e.g. when ./data does not exist)
        % instead of an invalid-identifier error from fprintf.
        error('matmul_test:fopenFailed', ...
            'Could not open %s for writing: %s', filename, fopen_msg);
    end
    fprintf(fileID, ...
        ['n error bound error-nrl bound-nrl \n']);
    for j=1:length(nlist)
        fprintf(fileID,'%d %e %e %e %e \n', ...
            nlist(j), err(j), bound(j), err_nrl(j), bound_nrl(j));
    end
    % Close the file so the data is flushed and the handle is released.
    fclose(fileID);
end
function y = previous_pow2(x)
% PREVIOUS_POW2 Round each element of x down to a power of two.
%   y = previous_pow2(x) returns an array of the same size as x in which
%   each entry is 2 raised to floor(log2(.)) of the corresponding entry of
%   x. For positive finite entries this is the largest power of two not
%   exceeding that entry, so an exact power of two maps to itself.
    exponents = floor(log2(x));
    y = 2.^exponents;
end
function C = matmul(A, B, options_accum, p, u)
% MATMUL Multi-word matrix product accumulated through cpfloat.
%   C = matmul(A, B, options_accum, p, u) takes cell arrays A and B and
%   combines the word pairs (A{j}, B{k}), 1 <= j,k <= p, for which
%   j+k-2 < p. Each pair contributes its product scaled by u^(j+k-2).
%
%   The product of a pair is formed one rank-1 update at a time. Each outer
%   product is rounded with cpfloat using options_accum, scaled, and added
%   to C, and the updated C is rounded with cpfloat once more.
%
%   C has size(A{1},1) rows and size(B{1},2) columns.
    n_rows = size(A{1}, 1);
    n_cols = size(B{1}, 2);
    C = zeros(n_rows, n_cols);

    for j = 1:p
        Aj = A{j};
        inner_dim = size(Aj, 2);
        for k = 1:p
            order = j + k - 2;
            % Skip word pairs whose combined order is p or higher.
            if ~(order < p)
                continue;
            end
            Bk = B{k};
            for i = 1:inner_dim
                % Rounded outer product of column i of Aj and row i of Bk.
                rank_one = cpfloat(Aj(:,i)*Bk(i,:), options_accum);
                % Scale, accumulate, and round the running sum.
                C = cpfloat(C + u^order*rank_one, options_accum);
            end
        end
    end
end