9393
# On-device default RNG: the counter-based Philox2x32 generator, which needs
# no large per-thread mutable state.
@device_override @inline Random.default_rng() = Philox2x32()
9595
# Untyped rand() defaults to Float32 on GPU (matches CUDA convention and
# avoids expensive FP64 arithmetic on consumer hardware).
@device_override @inline Random.rand(rng::AbstractRNG) = Random.rand(rng, Float32)
98+
9699"""
97100 Random.seed!(rng::Philox2x32, seed::Integer, [counter::Integer=0])
98101
@@ -123,14 +126,9 @@ else
123126 Random. seed! (Random. default_rng (), seed)
124127end
125128
126- """
127- Random.rand(rng::Philox2x32, UInt32)
128-
129- Generate a byte of random data using the on-device Tausworthe generator.
130- """
131- function Random. rand (rng:: Philox2x32{R} ,:: Type{UInt64} ) where {R}
132- ctr1, ctr2, key = rng. ctr1, rng. ctr2, rng. key
133-
129+ # R rounds of Philox2x32, unrolled at compile time
130+ @inline function philox2x_rounds (:: Val{R} , ctr1:: UInt32 , ctr2:: UInt32 ,
131+ key:: UInt32 ) where R
134132 if R > 0 ctr1, ctr2 = philox2x_round (ctr1, ctr2, key); end
135133 if R > 1 key = philox2x_bumpkey (key); ctr1, ctr2 = philox2x_round (ctr1, ctr2, key); end
136134 if R > 2 key = philox2x_bumpkey (key); ctr1, ctr2 = philox2x_round (ctr1, ctr2, key); end
@@ -147,6 +145,16 @@ function Random.rand(rng::Philox2x32{R},::Type{UInt64}) where {R}
147145 if R > 13 key = philox2x_bumpkey (key); ctr1, ctr2 = philox2x_round (ctr1, ctr2, key); end
148146 if R > 14 key = philox2x_bumpkey (key); ctr1, ctr2 = philox2x_round (ctr1, ctr2, key); end
149147 if R > 15 key = philox2x_bumpkey (key); ctr1, ctr2 = philox2x_round (ctr1, ctr2, key); end
148+ ctr1, ctr2
149+ end
150+
151+ """
152+ Random.rand(rng::Philox2x32, UInt64)
153+
154+ Generate 64 bits of random data using the on-device Philox2x32 generator.
155+ """
156+ function Random. rand (rng:: Philox2x32{R} , :: Type{UInt64} ) where {R}
157+ ctr1, ctr2 = philox2x_rounds (Val (R), rng. ctr1, rng. ctr2, rng. key)
150158
151159 # update the warp counter
152160 # NOTE: this performs the same update on every thread in the warp, but each warp writes
@@ -201,7 +209,7 @@ function emit_constant_array(name::Symbol, data::AbstractArray{T}) where {T}
201209 end
202210end
203211
204- for var in [:ki , :wi , :fi , : ke , :we , :fe ]
212+ for var in [:ke , :we , :fe ]
205213 val = getfield (Random, var)
206214 gpu_var = Symbol (" gpu_$var " )
207215 arr_typ = :(CuDeviceArray{$ (eltype (val)),$ (ndims (val)),AS. Constant})
@@ -211,39 +219,119 @@ for var in [:ki, :wi, :fi, :ke, :we, :fe]
211219 end
212220end
213221
214- # # randn
## Box-Muller helpers
#
# Vendored from GPUArrays.jl, which uses them in its host-side Philox4x32-10
# batched randn kernel. Keep constants in sync when upstream tunes them.

using Base: FastMath

# Map a UInt32 to a uniform Float32 in (0, 1), strictly positive: scale by
# 2^-32 and bias by 2^-33 (half a step) so that u == 0 still yields a
# nonzero value — Box-Muller needs log(u) != -Inf.
@inline u01(::Type{Float32}, u::UInt32) =
    fma(Float32(u), Float32(2)^(-32), Float32(2)^(-33))
233+
# Map a UInt64 to a uniform Float64 in (0, 1). Constructs the bit pattern of
# 1.m in [1, 2) directly and subtracts 1, avoiding the Float64(::UInt64)
# conversion + FMA, which is slow on consumer GPUs (FP64 throughput as low
# as 1:64). The low mantissa bit is forced to 1 so the result is strictly
# positive (Box-Muller needs log(u) != -Inf).
@inline u01(::Type{Float64}, u::UInt64) =
    reinterpret(Float64, ((u >> 12) | 0x1) | 0x3ff0000000000000) - 1.0
239+
# Polynomial sincospi for Float32: branchless and entirely in Float32
# (Base.sincospi widens internally). The bottom 3 bits of `u` select an
# octant (swap and/or negate the two results); the top 29 bits give the
# reduced argument, biased by +2^-32 so y != 0.

# Minimax polynomial coefficients (evaluated in y^2) for the reduced range.
const SP_F32 = (3.1415927f0, -5.167708f0, 2.5497673f0, -0.58907866f0)
const CP_F32 = (1.0f0, -4.934788f0, 4.057578f0, -1.3061346f0)

@inline function fast_sincospi(::Type{Float32}, u::UInt32)
    oct = (u % Int32) & Int32(7)
    # reduced argument from the top 29 bits, scaled by 2^-34 (so y <= 1/4)
    y = fma(Float32(u & ~UInt32(7)), Float32(2)^(-34), Float32(2)^(-32))
    sp = y * evalpoly(y * y, SP_F32)
    cp = evalpoly(y * y, CP_F32)
    swap = !iszero(oct & Int32(1))     # exchange sin/cos
    sin_neg = !iszero(oct & Int32(2))  # negate the sine result
    cos_neg = !iszero(oct & Int32(4))  # negate the cosine result
    s_raw = ifelse(swap, cp, sp)
    c_raw = ifelse(swap, sp, cp)
    (ifelse(sin_neg, -s_raw, s_raw), ifelse(cos_neg, -c_raw, c_raw))
end
215259
216- @device_override function Random. randn (rng:: AbstractRNG )
217- while true
218- r = Random. rand (rng, Random. UInt52Raw ()) % UInt64
219- @inbounds begin
220- r &= 0x000fffffffffffff
221- rabs = Int64 (r>> 1 ) # One bit for the sign
222- idx = rabs & 0xFF
223- x = ifelse (r % Bool, - rabs, rabs)* gpu_wi ()[idx+ 1 ]
224- rabs < gpu_ki ()[idx+ 1 ] && return x # 99.3% of the time we return here 1st try
225- result = randn_unlikely (rng, idx, rabs, x)
226- result != = nothing && return result
227- end
228- end
# Polynomial log(Float32), fdlibm-based. Consumes the raw UInt32 directly;
# the u01 mapping is folded into the first FMA so there is no intermediate
# uniform float.

const SQRT_HALF_I32 = reinterpret(Int32, Float32(sqrt(0.5)))
# fdlibm Lg1..Lg4 coefficients, split into odd/even polynomials in w = z^2.
const LOG_ODD_F32 = (reinterpret(Float32, Int32(0x3f2aaaaa)),
                     reinterpret(Float32, Int32(0x3e91e9ee)))
const LOG_EVEN_F32 = (reinterpret(Float32, Int32(0x3eccce13)),
                      reinterpret(Float32, Int32(0x3e789e26)))

@inline function fast_log(::Type{Float32}, u::UInt32)
    # x = u01(Float32, u), strictly positive
    x = fma(Float32(u), Float32(2)^(-32), Float32(2)^(-33))
    ix = reinterpret(Int32, x) - SQRT_HALF_I32
    k = ix >> Int32(23)                            # binary exponent of x/sqrt(1/2)
    f_std = reinterpret(Float32, (ix & Int32(0x007fffff)) + SQRT_HALF_I32) - 1.0f0
    # near x == 1 (k == 0), compute f = x - 1 from the complement bits to
    # avoid catastrophic cancellation
    f_comp = -fma(Float32(~u), Float32(2)^(-32), Float32(2)^(-33))
    f = ifelse(k == Int32(0), f_comp, f_std)
    s = f / (2.0f0 + f)
    z = s * s; w = z * z
    R = z * evalpoly(w, LOG_ODD_F32) + w * evalpoly(w, LOG_EVEN_F32)
    hfsq = 0.5f0 * f * f
    # k*ln2_hi + (log1p-style correction) + k*ln2_lo, in fdlibm's ordering
    Float32(k) * reinterpret(Float32, Int32(0x3f317180)) -
        ((hfsq - (s * (hfsq + R) +
                  Float32(k) * reinterpret(Float32, Int32(0x3717f7d1)))) - f)
end
230284
231- # this unlikely branch is put in a separate function for better efficiency
232- @noinline function randn_unlikely (rng, idx, rabs, x)
233- @inbounds if idx == 0
234- while true
235- xx = - Random. ziggurat_nor_inv_r* log (Random. rand (rng))
236- yy = - log (Random. rand (rng))
237- yy+ yy > xx* xx &&
238- return (rabs >> 8 ) % Bool ? - Random. ziggurat_nor_r- xx : Random. ziggurat_nor_r+ xx
239- end
240- elseif (gpu_fi ()[idx] - gpu_fi ()[idx+ 1 ])* Random. rand (rng) + gpu_fi ()[idx+ 1 ] < exp (- 0.5 * x* x)
241- return x # return from the triangular area
242- else
243- return # retry
244- end
# Box-Muller: a pair of uniform random bit-patterns -> a pair of standard
# normals. All arithmetic stays in Float32; Float16 converts only at the end.
@inline function boxmuller(::Type{T}, u1::UInt32,
                           u2::UInt32) where T <: Union{Float16,Float32}
    r = sqrt(-2f0 * fast_log(Float32, u2))  # radius; fast_log folds in u01
    s, c = fast_sincospi(Float32, u1)       # angle, uniform over the circle
    (T(r * s), T(r * c))
end
292+
# Float64 Box-Muller on uniforms already in (0, 1); uses the stock sincospi
# since this path is paying for FP64 anyway.
@inline function boxmuller(::Type{Float64}, u1::Float64, u2::Float64)
    r = sqrt(-2.0 * FastMath.log_fast(u1))
    s, c = sincospi(2 * u2)
    (r * s, r * c)
end
298+
299+
## randn — Box-Muller transform
#
# Uses Box-Muller instead of Ziggurat: rejection sampling would warp-diverge,
# and the Ziggurat tables aren't device-accessible.

# Specialization for Philox2x32: one Philox call produces exactly the pair of
# UInt32s Box-Muller needs, halving the Philox work vs the generic path.
# NOTE(review): the second normal of the pair is discarded.
@device_override @inline function Random.randn(rng::Philox2x32{R},
                                               ::Type{T}) where {R, T <: Union{Float16,Float32}}
    ctr1, ctr2 = philox2x_rounds(Val(R), rng.ctr1, rng.ctr2, rng.key)
    rng.ctr1 += 1i32    # consume one counter value
    n, _ = boxmuller(T, ctr1, ctr2)
    n
end
246314
# Float64 fundamentally needs 64 bits of entropy per uniform, so two Philox
# calls per normal. The u01 bit-trick avoids the expensive Float64(::UInt64)
# conversion.
@device_override @inline function Random.randn(rng::Philox2x32{R},
                                               ::Type{Float64}) where R
    u1 = u01(Float64, Random.rand(rng, UInt64))
    u2 = u01(Float64, Random.rand(rng, UInt64))
    n, _ = boxmuller(Float64, u1, u2)
    n
end
324+
# Generic fallback for user-defined AbstractFloat types: classic Box-Muller
# from two uniforms in T, keeping only the first normal of the pair.
@device_override @inline function Random.randn(rng::AbstractRNG, ::Type{T}) where T <: AbstractFloat
    U1 = max(Random.rand(rng, T), floatmin(T))  # clamp away from 0 so log(U1) is finite
    U2 = Random.rand(rng, T)
    sqrt(T(-2) * FastMath.log_fast(U1)) * first(sincospi(T(2) * U2))
end
331+
# Untyped randn() defaults to Float32 on GPU, matching the rand() default.
@device_override @inline Random.randn(rng::AbstractRNG) = Random.randn(rng, Float32)
334+
247335# # randexp
248336
249337@device_override function Random. randexp (rng:: AbstractRNG )
0 commit comments