|
2 | 2 |
|
| | # NOTE(review): presumably declares these names as public API without exporting them -- confirm against the `@public` macro definition. |
3 | 3 | @public fma, rsqrt, saturate, byte_perm, assume |
4 | 4 |
|
5 | | -using Base: FastMath |
| 5 | +using Base: FastMath, @assume_effects |
6 | 6 |
|
7 | 7 |
|
8 | 8 | ## helpers |
@@ -248,17 +248,42 @@ end |
| | # Floating-point powers with a floating-point exponent: each method forwards to the external `__nv_*` function named in its ccall. |
248 | 248 | @device_override Base.:(^)(x::Float64, y::Float64) = ccall("extern __nv_pow", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) |
249 | 249 | @device_override Base.:(^)(x::Float32, y::Float32) = ccall("extern __nv_powf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) |
| | # Fast-math variant for Float32: forwards to `__nv_fast_powf` instead of `__nv_powf`. |
250 | 250 | @device_override FastMath.pow_fast(x::Float32, y::Float32) = ccall("extern __nv_fast_powf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) |
| 251 | +# pow_fast: Base methods call llvm.powi which NVPTX cannot lower (#3065) |
| | +# Float64 base with an integer exponent: exponents -1, 0, 1, 2 and 3 are computed directly |
| | +# (inv, one, or repeated multiplication); every other exponent is forwarded to the regular `x ^ y`. |
| | +# NOTE(review): the `__nv_powi` remark on the last line holds when `y` is an Int32 (the Int32 `^` method calls it); |
| | +# which `^` method other Integer types end up in is not fully visible here -- confirm. |
| 252 | +@device_override @assume_effects :foldable @inline function FastMath.pow_fast(x::Float64, y::Integer) |
| 253 | + y == -1 && return inv(x) |
| 254 | + y == 0 && return one(x) |
| 255 | + y == 1 && return x |
| 256 | + y == 2 && return x*x |
| 257 | + y == 3 && return x*x*x |
| 258 | + x ^ y # no fast variant for Float64; uses __nv_powi |
| 259 | +end |
| | +# Float32 base with an integer exponent: exponents -1, 0, 1, 2 and 3 are computed directly; |
| | +# any other exponent is converted to Float32 and handed to the Float32/Float32 `pow_fast` method. |
| 260 | +@device_override @assume_effects :foldable @inline function FastMath.pow_fast(x::Float32, y::Integer) |
| 261 | + y == -1 && return inv(x) |
| 262 | + y == 0 && return one(x) |
| 263 | + y == 1 && return x |
| 264 | + y == 2 && return x*x |
| 265 | + y == 3 && return x*x*x |
| 266 | + FastMath.pow_fast(x, Float32(y)) # uses __nv_fast_powf |
| 267 | +end |
| | +# Float16 base with an integer exponent: exponents -1, 0, 1, 2 and 3 are computed directly in Float16; |
| | +# otherwise both operands are widened to Float32, raised with `pow_fast`, and the result is converted back to Float16. |
| 268 | +@device_override @assume_effects :foldable @inline function FastMath.pow_fast(x::Float16, y::Integer) |
| 269 | + y == -1 && return inv(x) |
| 270 | + y == 0 && return one(x) |
| 271 | + y == 1 && return x |
| 272 | + y == 2 && return x*x |
| 273 | + y == 3 && return x*x*x |
| 274 | + Float16(FastMath.pow_fast(Float32(x), Float32(y))) |
| 275 | +end |
| | # Powers with an Int32 exponent: forwarded to the external `__nv_powi` (Float64) and `__nv_powif` (Float32) functions. |
251 | 276 | @device_override Base.:(^)(x::Float64, y::Int32) = ccall("extern __nv_powi", llvmcall, Cdouble, (Cdouble, Int32), x, y) |
252 | 277 | @device_override Base.:(^)(x::Float32, y::Int32) = ccall("extern __nv_powif", llvmcall, Cfloat, (Cfloat, Int32), x, y) |
| | # Float32 base with an Int64 exponent: exponents -1, 0, 1, 2 and 3 are computed directly; |
| | # any other exponent is converted to Float32 and raised with the Float32/Float32 `^` method. |
| | # The changed header row only adds `@assume_effects :foldable`. |
253 | | -@device_override @inline function Base.:(^)(x::Float32, y::Int64) |
| 278 | +@device_override @assume_effects :foldable @inline function Base.:(^)(x::Float32, y::Int64) |
254 | 279 | y == -1 && return inv(x) |
255 | 280 | y == 0 && return one(x) |
256 | 281 | y == 1 && return x |
257 | 282 | y == 2 && return x*x |
258 | 283 | y == 3 && return x*x*x |
259 | 284 | x ^ Float32(y) |
260 | 285 | end |
261 | | -@device_override @inline function Base.:(^)(x::Float64, y::Int64) |
| 286 | +@device_override @assume_effects :foldable @inline function Base.:(^)(x::Float64, y::Int64) |
262 | 287 | y == -1 && return inv(x) |
263 | 288 | y == 0 && return one(x) |
264 | 289 | y == 1 && return x |
@@ -435,10 +460,14 @@ end |
| | # hypot for Float64/Float32: forwarded to the external `__nv_hypot` / `__nv_hypotf` functions. |
435 | 460 | @device_override Base.hypot(x::Float64, y::Float64) = ccall("extern __nv_hypot", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) |
436 | 461 | @device_override Base.hypot(x::Float32, y::Float32) = ccall("extern __nv_hypotf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) |
437 | 462 |
|
| | # fma: the Float64 and Float32 methods switch from the external `__nv_fma` / `__nv_fmaf` functions |
| | # to the `llvm.fma.f64` / `llvm.fma.f32` intrinsics, matching the existing Float16 method (`llvm.fma.f16`). |
438 | | -@device_override Base.fma(x::Float64, y::Float64, z::Float64) = ccall("extern __nv_fma", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z) |
439 | | -@device_override Base.fma(x::Float32, y::Float32, z::Float32) = ccall("extern __nv_fmaf", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) |
| 463 | +@device_override Base.fma(x::Float64, y::Float64, z::Float64) = ccall("llvm.fma.f64", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z) |
| 464 | +@device_override Base.fma(x::Float32, y::Float32, z::Float32) = ccall("llvm.fma.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) |
440 | 465 | @device_override Base.fma(x::Float16, y::Float16, z::Float16) = ccall("llvm.fma.f16", llvmcall, Float16, (Float16, Float16, Float16), x, y, z) |
441 | 466 |
|
| | +# muladd for Float64/Float32/Float16: forwarded to the `llvm.fmuladd.*` intrinsic of the matching width. |
| 467 | +@device_override Base.muladd(x::Float64, y::Float64, z::Float64) = ccall("llvm.fmuladd.f64", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z) |
| 468 | +@device_override Base.muladd(x::Float32, y::Float32, z::Float32) = ccall("llvm.fmuladd.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) |
| 469 | +@device_override Base.muladd(x::Float16, y::Float16, z::Float16) = ccall("llvm.fmuladd.f16", llvmcall, Float16, (Float16, Float16, Float16), x, y, z) |
| 470 | + |
| | # sad(x, y, z): forwarded to the external `__nv_sad` (Int32) and `__nv_usad` (UInt32) functions. |
442 | 471 | @device_function sad(x::Int32, y::Int32, z::Int32) = ccall("extern __nv_sad", llvmcall, Int32, (Int32, Int32, Int32), x, y, z) |
| | # The unsigned call is declared with UInt32 argument and return types so the full 32-bit range passes through unchanged; |
| | # declaring it with Int32 made ccall insert checked conversions that throw InexactError for values above typemax(Int32). |
443 | 472 | @device_function sad(x::UInt32, y::UInt32, z::UInt32) = ccall("extern __nv_usad", llvmcall, UInt32, (UInt32, UInt32, UInt32), x, y, z) |
444 | 473 |
|
|
0 commit comments