2
2
3
3
using Base: FastMath
4
4
5
-
6
5
## helpers
7
6
8
7
"""
    within(lower, upper)

Return a one-argument predicate that is `true` when its argument `val`
satisfies `lower <= val <= upper` (both bounds inclusive).
"""
within(lower, upper) = val -> lower <= val <= upper
@@ -103,17 +102,98 @@ end
103
102
104
103
# Natural logarithm for double and single precision: each method forwards its
# argument to the external `__nv_log` / `__nv_logf` symbol via an `llvmcall` ccall.
@device_override Base.log(x::Float64) = ccall("extern __nv_log", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.log(x::Float32) = ccall("extern __nv_logf", llvmcall, Cfloat, (Cfloat,), x)
105
# Natural logarithm for half precision, written as inline PTX.
#
# The input is widened to f32, `lg2.approx.ftz.f32` takes its base-2 logarithm,
# and the result is multiplied by the constant 0x3f317218 (the f32 bit pattern
# of ≈0.6931472, i.e. ln 2) before being rounded back to f16 in `r`.
#
# Each following `set.eq` / `fma` pair is a special-case correction: `p` is set
# from comparing the input `h` with `spc` (per the PTX ISA, 1.0 when equal and
# 0.0 otherwise), and `r = p * ulp + r` therefore adds `ulp` only for that input.
@device_override function Base.log(x::Float16)
    return @asmcall("""{
        .reg.b32 f, C;
        .reg.b16 r, h;
        mov.b16 h, \$1;
        cvt.f32.f16 f, h;
        lg2.approx.ftz.f32 f, f;
        mov.b32 C, 0x3f317218U;
        mul.f32 f, f, C;
        cvt.rn.f16.f32 r, f;
        .reg.b16 spc, ulp, p;
        mov.b16 spc, 0X160DU;
        mov.b16 ulp, 0x9C00U;
        set.eq.f16.f16 p, h, spc;
        fma.rn.f16 r, p, ulp, r;
        mov.b16 spc, 0X3BFEU;
        mov.b16 ulp, 0x8010U;
        set.eq.f16.f16 p, h, spc;
        fma.rn.f16 r, p, ulp, r;
        mov.b16 spc, 0X3C0BU;
        mov.b16 ulp, 0x8080U;
        set.eq.f16.f16 p, h, spc;
        fma.rn.f16 r, p, ulp, r;
        mov.b16 spc, 0X6051U;
        mov.b16 ulp, 0x1C00U;
        set.eq.f16.f16 p, h, spc;
        fma.rn.f16 r, p, ulp, r;
        mov.b16 \$0, r;
        }""", "=h,h", Float16, Tuple{Float16}, x)
end
135
+
106
136
# Fast-math natural logarithm for Float32: forwards to the external
# `__nv_fast_logf` symbol.
@device_override FastMath.log_fast(x::Float32) = ccall("extern __nv_fast_logf", llvmcall, Cfloat, (Cfloat,), x)
107
137
108
138
# Base-10 logarithm for double and single precision: forwards to the external
# `__nv_log10` / `__nv_log10f` symbols.
@device_override Base.log10(x::Float64) = ccall("extern __nv_log10", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.log10(x::Float32) = ccall("extern __nv_log10f", llvmcall, Cfloat, (Cfloat,), x)
140
# Base-10 logarithm for half precision, written as inline PTX.
#
# The input is widened to f32, `lg2.approx.ftz.f32` takes its base-2 logarithm,
# and the result is multiplied by the constant 0x3E9A209B (the f32 bit pattern
# of ≈0.30103, i.e. log10(2)) before being rounded back to f16 in `r`.
#
# Each following `set.eq` / `fma` pair is a special-case correction: `p` is set
# from comparing the input `h` with `spc` (per the PTX ISA, 1.0 when equal and
# 0.0 otherwise), and `r = p * ulp + r` therefore adds `ulp` only for that input.
@device_override function Base.log10(x::Float16)
    return @asmcall("""{
        .reg.b16 h, r;
        .reg.b32 f, C;
        mov.b16 h, \$1;
        cvt.f32.f16 f, h;
        lg2.approx.ftz.f32 f, f;
        mov.b32 C, 0x3E9A209BU;
        mul.f32 f, f, C;
        cvt.rn.f16.f32 r, f;
        .reg.b16 spc, ulp, p;
        mov.b16 spc, 0x338FU;
        mov.b16 ulp, 0x1000U;
        set.eq.f16.f16 p, h, spc;
        fma.rn.f16 r, p, ulp, r;
        mov.b16 spc, 0x33F8U;
        mov.b16 ulp, 0x9000U;
        set.eq.f16.f16 p, h, spc;
        fma.rn.f16 r, p, ulp, r;
        mov.b16 spc, 0x57E1U;
        mov.b16 ulp, 0x9800U;
        set.eq.f16.f16 p, h, spc;
        fma.rn.f16 r, p, ulp, r;
        mov.b16 spc, 0x719DU;
        mov.b16 ulp, 0x9C00U;
        set.eq.f16.f16 p, h, spc;
        fma.rn.f16 r, p, ulp, r;
        mov.b16 \$0, r;
        }""", "=h,h", Float16, Tuple{Float16}, x)
end
110
170
# Fast-math base-10 logarithm for Float32: forwards to the external
# `__nv_fast_log10f` symbol.
@device_override FastMath.log10_fast(x::Float32) = ccall("extern __nv_fast_log10f", llvmcall, Cfloat, (Cfloat,), x)
111
171
112
172
# `log1p` for double and single precision: forwards to the external
# `__nv_log1p` / `__nv_log1pf` symbols.
@device_override Base.log1p(x::Float64) = ccall("extern __nv_log1p", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.log1p(x::Float32) = ccall("extern __nv_log1pf", llvmcall, Cfloat, (Cfloat,), x)
114
174
115
175
# Base-2 logarithm for double and single precision: forwards to the external
# `__nv_log2` / `__nv_log2f` symbols.
@device_override Base.log2(x::Float64) = ccall("extern __nv_log2", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.log2(x::Float32) = ccall("extern __nv_log2f", llvmcall, Cfloat, (Cfloat,), x)
177
# Base-2 logarithm for half precision, written as inline PTX.
#
# The input is widened to f32, passed through `lg2.approx.ftz.f32`, and rounded
# back to f16 in `r`.
#
# The two `set.eq` / `fma` pairs are special-case corrections. Unlike a check on
# the input, they compare the rounded *result* `r` with `spc`: `p` is 1.0 when
# equal and 0.0 otherwise (per the PTX ISA), so `r = p * ulp + r` adds `ulp`
# only when the result had that particular value.
@device_override function Base.log2(x::Float16)
    return @asmcall("""{
        .reg.b16 h, r;
        .reg.b32 f;
        mov.b16 h, \$1;
        cvt.f32.f16 f, h;
        lg2.approx.ftz.f32 f, f;
        cvt.rn.f16.f32 r, f;
        .reg.b16 spc, ulp, p;
        mov.b16 spc, 0xA2E2U;
        mov.b16 ulp, 0x8080U;
        set.eq.f16.f16 p, r, spc;
        fma.rn.f16 r, p, ulp, r;
        mov.b16 spc, 0xBF46U;
        mov.b16 ulp, 0x9400U;
        set.eq.f16.f16 p, r, spc;
        fma.rn.f16 r, p, ulp, r;
        mov.b16 \$0, r;
        }""", "=h,h", Float16, Tuple{Float16}, x)
end
117
197
# Fast-math base-2 logarithm for Float32: forwards to the external
# `__nv_fast_log2f` symbol.
@device_override FastMath.log2_fast(x::Float32) = ccall("extern __nv_fast_log2f", llvmcall, Cfloat, (Cfloat,), x)
118
198
119
199
# Device-only `logb` for Float64: forwards to the external `__nv_logb` symbol.
@device_function logb(x::Float64) = ccall("extern __nv_logb", llvmcall, Cdouble, (Cdouble,), x)
@@ -127,16 +207,95 @@ end
127
207
128
208
# Natural exponential for double and single precision: forwards to the external
# `__nv_exp` / `__nv_expf` symbols.
@device_override Base.exp(x::Float64) = ccall("extern __nv_exp", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.exp(x::Float32) = ccall("extern __nv_expf", llvmcall, Cfloat, (Cfloat,), x)
210
# Natural exponential for half precision, written as inline PTX.
#
# The input is widened to f32 and scaled with `fma.rn.f32 f, f, C, nZ`, where
# C = 0x3fb8aa3b (the f32 bit pattern of ≈1.442695, i.e. log2(e)) and
# nZ = 0x80000000 (negative zero). `ex2.approx.ftz.f32` then raises 2 to that
# power and the result is rounded back to f16 in `r`.
#
# Each following `set.eq` / `fma` pair is a special-case correction: `p` is set
# from comparing the input `h` with `spc` (per the PTX ISA, 1.0 when equal and
# 0.0 otherwise), and `r = p * ulp + r` therefore adds `ulp` only for that input.
@device_override function Base.exp(x::Float16)
    return @asmcall("""{
        .reg.b32 f, C, nZ;
        .reg.b16 h, r;
        mov.b16 h, \$1;
        cvt.f32.f16 f, h;
        mov.b32 C, 0x3fb8aa3bU;
        mov.b32 nZ, 0x80000000U;
        fma.rn.f32 f, f, C, nZ;
        ex2.approx.ftz.f32 f, f;
        cvt.rn.f16.f32 r, f;
        .reg.b16 spc, ulp, p;
        mov.b16 spc, 0X1F79U;
        mov.b16 ulp, 0x9400U;
        set.eq.f16.f16 p, h, spc;
        fma.rn.f16 r, p, ulp, r;
        mov.b16 spc, 0X25CFU;
        mov.b16 ulp, 0x9400U;
        set.eq.f16.f16 p, h, spc;
        fma.rn.f16 r, p, ulp, r;
        mov.b16 spc, 0XC13BU;
        mov.b16 ulp, 0x0400U;
        set.eq.f16.f16 p, h, spc;
        fma.rn.f16 r, p, ulp, r;
        mov.b16 spc, 0XC1EFU;
        mov.b16 ulp, 0x0200U;
        set.eq.f16.f16 p, h, spc;
        fma.rn.f16 r, p, ulp, r;
        mov.b16 \$0, r;
        }""", "=h,h", Float16, Tuple{Float16}, x)
end
130
242
# Fast-math natural exponential for Float32: forwards to the external
# `__nv_fast_expf` symbol.
@device_override FastMath.exp_fast(x::Float32) = ccall("extern __nv_fast_expf", llvmcall, Cfloat, (Cfloat,), x)
131
243
132
244
# Base-2 exponential for double and single precision: forwards to the external
# `__nv_exp2` / `__nv_exp2f` symbols.
@device_override Base.exp2(x::Float64) = ccall("extern __nv_exp2", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.exp2(x::Float32) = ccall("extern __nv_exp2f", llvmcall, Cfloat, (Cfloat,), x)
246
# Base-2 exponential for half precision, written as inline PTX.
#
# The input is widened to f32 and passed through `ex2.approx.ftz.f32`. The
# result is then nudged with `fma.rn.f32 f, f, ULP, f`, i.e. f + f * ULP, where
# ULP = 0x33800000 (the f32 bit pattern of 2^-24), before rounding back to f16.
@device_override function Base.exp2(x::Float16)
    return @asmcall("""{
        .reg.b32 f, ULP;
        .reg.b16 r;
        mov.b16 r, \$1;
        cvt.f32.f16 f, r;
        ex2.approx.ftz.f32 f, f;
        mov.b32 ULP, 0x33800000U;
        fma.rn.f32 f, f, ULP, f;
        cvt.rn.f16.f32 r, f;
        mov.b16 \$0, r;
        }""", "=h,h", Float16, Tuple{Float16}, x)
end
134
259
# Fast-math base-2 exponential for Float32/Float64: simply delegates to `exp2`.
@device_override FastMath.exp2_fast(x::Union{Float32, Float64}) = exp2(x)
135
- # TODO : enable once PTX > 7.0 is supported
136
- # @device_override Base.exp2(x::Float16) = @asmcall("ex2.approx.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
137
260
138
261
# Base-10 exponential for double and single precision: forwards to the external
# `__nv_exp10` / `__nv_exp10f` symbols.
@device_override Base.exp10(x::Float64) = ccall("extern __nv_exp10", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.exp10(x::Float32) = ccall("extern __nv_exp10f", llvmcall, Cfloat, (Cfloat,), x)
263
# Base-10 exponential for half precision, written as inline PTX.
#
# The input is widened to f32 and scaled with `fma.rn.f32 f, f, C, nZ`, where
# C = 0x40549A78 (the f32 bit pattern of ≈3.321928, i.e. log2(10)) and
# nZ = 0x80000000 (negative zero). `ex2.approx.ftz.f32` then raises 2 to that
# power and the result is rounded back to f16 in `r`.
#
# Each following `set.eq` / `fma` pair is a special-case correction: `p` is set
# from comparing the input `h` with `spc` (per the PTX ISA, 1.0 when equal and
# 0.0 otherwise), and `r = p * ulp + r` therefore adds `ulp` only for that input.
@device_override function Base.exp10(x::Float16)
    return @asmcall("""{
        .reg.b16 h, r;
        .reg.b32 f, C, nZ;
        mov.b16 h, \$1;
        cvt.f32.f16 f, h;
        mov.b32 C, 0x40549A78U;
        mov.b32 nZ, 0x80000000U;
        fma.rn.f32 f, f, C, nZ;
        ex2.approx.ftz.f32 f, f;
        cvt.rn.f16.f32 r, f;
        .reg.b16 spc, ulp, p;
        mov.b16 spc, 0x34DEU;
        mov.b16 ulp, 0x9800U;
        set.eq.f16.f16 p, h, spc;
        fma.rn.f16 r, p, ulp, r;
        mov.b16 spc, 0x9766U;
        mov.b16 ulp, 0x9000U;
        set.eq.f16.f16 p, h, spc;
        fma.rn.f16 r, p, ulp, r;
        mov.b16 spc, 0x9972U;
        mov.b16 ulp, 0x1000U;
        set.eq.f16.f16 p, h, spc;
        fma.rn.f16 r, p, ulp, r;
        mov.b16 spc, 0xA5C4U;
        mov.b16 ulp, 0x1000U;
        set.eq.f16.f16 p, h, spc;
        fma.rn.f16 r, p, ulp, r;
        mov.b16 spc, 0xBF0AU;
        mov.b16 ulp, 0x8100U;
        set.eq.f16.f16 p, h, spc;
        fma.rn.f16 r, p, ulp, r;
        mov.b16 \$0, r;
        }""", "=h,h", Float16, Tuple{Float16}, x)
end
140
299
# Fast-math base-10 exponential for Float32: forwards to the external
# `__nv_fast_exp10f` symbol.
@device_override FastMath.exp10_fast(x::Float32) = ccall("extern __nv_fast_exp10f", llvmcall, Cfloat, (Cfloat,), x)
141
300
142
301
# `expm1` for Float64: forwards to the external `__nv_expm1` symbol.
@device_override Base.expm1(x::Float64) = ccall("extern __nv_expm1", llvmcall, Cdouble, (Cdouble,), x)
204
363
205
364
# NaN test for double and single precision: the external `__nv_isnand` /
# `__nv_isnanf` symbols return an Int32, which is converted to Bool by `!= 0`.
@device_override Base.isnan(x::Float64) = (ccall("extern __nv_isnand", llvmcall, Int32, (Cdouble,), x)) != 0
@device_override Base.isnan(x::Float32) = (ccall("extern __nv_isnanf", llvmcall, Int32, (Cfloat,), x)) != 0
366
# NaN test for half precision.
# When `compute_capability()` is at least 8.0 the external `__nv_hisnan` symbol
# is called directly on the Float16 value (its Int32 result is turned into a
# Bool with `!= 0`); otherwise the value is widened and the Float32 `isnan`
# method is used.
@device_override function Base.isnan(x::Float16)
    if compute_capability() >= sv"8.0"
        return (ccall("extern __nv_hisnan", llvmcall, Int32, (Float16,), x)) != 0
    else
        return isnan(Float32(x))
    end
end
207
373
208
374
# Device-only `nearbyint` for double and single precision: forwards to the
# external `__nv_nearbyint` / `__nv_nearbyintf` symbols.
@device_function nearbyint(x::Float64) = ccall("extern __nv_nearbyint", llvmcall, Cdouble, (Cdouble,), x)
@device_function nearbyint(x::Float32) = ccall("extern __nv_nearbyintf", llvmcall, Cfloat, (Cfloat,), x)
@@ -223,14 +389,20 @@ end
223
389
# Absolute value for Int32, Float64 and Float32: forwards to the external
# `__nv_abs`, `__nv_fabs` and `__nv_fabsf` symbols respectively.
@device_override Base.abs(x::Int32) = ccall("extern __nv_abs", llvmcall, Int32, (Int32,), x)
@device_override Base.abs(f::Float64) = ccall("extern __nv_fabs", llvmcall, Cdouble, (Cdouble,), f)
@device_override Base.abs(f::Float32) = ccall("extern __nv_fabsf", llvmcall, Cfloat, (Cfloat,), f)
226
- # TODO : enable once PTX > 7.0 is supported
227
- # @device_override Base.abs(x::Float16) = @asmcall("abs.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
392
# Absolute value for half precision: widen to Float32, take `abs` there, and
# convert the result back to Float16.
@device_override Base.abs(f::Float16) = Float16(abs(Float32(f)))
228
393
# Absolute value for Int64: forwards to the external `__nv_llabs` symbol.
@device_override Base.abs(x::Int64) = ccall("extern __nv_llabs", llvmcall, Int64, (Int64,), x)
229
394
230
395
## roots and powers
231
396
232
397
# Square root for double and single precision: forwards to the external
# `__nv_sqrt` / `__nv_sqrtf` symbols.
@device_override Base.sqrt(x::Float64) = ccall("extern __nv_sqrt", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.sqrt(x::Float32) = ccall("extern __nv_sqrtf", llvmcall, Cfloat, (Cfloat,), x)
399
# Square root for half precision.
# When `compute_capability()` is at least 8.0 the external `__nv_hsqrt` symbol
# is called directly on the Float16 value; otherwise the value is widened, the
# Float32 `sqrt` method is used, and the result is converted back to Float16.
# Both branches return explicitly so the two paths read the same way.
@device_override function Base.sqrt(x::Float16)
    if compute_capability() >= sv"8.0"
        return ccall("extern __nv_hsqrt", llvmcall, Float16, (Float16,), x)
    else
        return Float16(sqrt(Float32(x)))
    end
end
234
406
# Fast-math square root for Float32/Float64: simply delegates to `sqrt`.
@device_override FastMath.sqrt_fast(x::Union{Float32, Float64}) = sqrt(x)
235
407
236
408
# Device-only `rsqrt` for Float64: forwards to the external `__nv_rsqrt` symbol.
@device_function rsqrt(x::Float64) = ccall("extern __nv_rsqrt", llvmcall, Cdouble, (Cdouble,), x)
0 commit comments