|
# Nearest-rounding remainder for Float16: both operands are widened to Float32,
# the remainder is taken at that precision, and the result is narrowed back.
@device_override function Base.rem(x::Float16, y::Float16, ::RoundingMode{:Nearest})
    widened = rem(Float32(x), Float32(y), RoundNearest)
    return Float16(widened)
end
396 | 396 |
|
# Fast-math Float32 division, forwarded to the `__nv_fast_fdividef` device routine.
@device_override function FastMath.div_fast(x::Float32, y::Float32)
    return ccall("extern __nv_fast_fdividef", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
end
# Fast-math Float64 division, expressed as a multiplication by the fast reciprocal
# of the divisor.
@device_override function FastMath.div_fast(x::Float64, y::Float64)
    reciprocal = FastMath.inv_fast(y)
    return x * reciprocal
end
398 | 399 |
|
# Float32 reciprocal, forwarded to the `__nv_frcp_rn` device routine.
@device_override function Base.inv(x::Float32)
    return ccall("extern __nv_frcp_rn", llvmcall, Cfloat, (Cfloat,), x)
end
400 | | -@device_override FastMath.inv_fast(x::Union{Float32, Float64}) = @fastmath one(x) / x |
# Fast-math Float32 reciprocal, forwarded to the `llvm.nvvm.rcp.approx.ftz.f`
# intrinsic.
@device_override function FastMath.inv_fast(x::Float32)
    return ccall("llvm.nvvm.rcp.approx.ftz.f", llvmcall, Float32, (Float32,), x)
end
# Fast-math Float64 reciprocal: a hardware approximation followed by one
# fma-based refinement step.
@device_override function FastMath.inv_fast(x::Float64)
    # Seed value from the approximate-reciprocal instruction, see
    # https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-rcp-approx-ftz-f64
    # Properties of that instruction:
    #   - the last 32 bits of the input mantissa are dropped before inverting;
    #   - every subnormal input is treated as 0.0;
    #   - a reciprocal that would be subnormal underflows to 0.0;
    #   - the 32 least significant bits of the result are zero-filled.
    seed = ccall("llvm.nvvm.rcp.approx.ftz.d", llvmcall, Float64, (Float64,), x)

    # Recover the missing 32 mantissa bits with a single cubic iteration:
    #   err    = 1 - x * seed
    #   result = seed * (1 + err + err^2)
    err = fma(seed, -x, 1.0)
    correction = fma(err, err, err)
    return fma(correction, seed, seed)
end
401 | 416 |
|
402 | 417 | ## distributions |
403 | 418 |
|
|
0 commit comments