Skip to content

Commit 460eb54

Browse files
committed
div(::TracedRNumber{Int}, ::TracedRNumber{Int}, ::RoundingMode)
1 parent aece0b1 commit 460eb54

1 file changed

Lines changed: 51 additions & 0 deletions

File tree

src/TracedRNumber.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,9 +296,45 @@ Base.flipsign(x::TracedRNumber, y::TracedRNumber) = ifelse(y < 0, -x, x)
296296

297297
function Base.div(
298298
x::TracedRNumber{<:Reactant.ReactantSInt}, y::TracedRNumber{<:Reactant.ReactantUInt}
299+
)
300+
return div(x, y, RoundDown)
301+
end
302+
function Base.div(
303+
x::TracedRNumber{<:Reactant.ReactantSInt}, y::TracedRNumber{<:Reactant.ReactantUInt},
304+
::typeof(RoundToZero),
299305
)
300306
return flipsign(signed(div(unsigned(abs(x)), y)), x)
301307
end
308+
function Base.div(
309+
x::TracedRNumber{<:Reactant.ReactantSInt}, y::TracedRNumber{<:Reactant.ReactantUInt},
310+
::typeof(RoundDown),
311+
)
312+
ax = unsigned(abs(x))
313+
q = signed(div(ax, y))
314+
has_rem = !iszero(rem(ax, y))
315+
result = flipsign(q, x)
316+
return ifelse(signbit(x) & has_rem, result - one(result), result)
317+
end
318+
function Base.div(
319+
x::TracedRNumber{<:Reactant.ReactantSInt}, y::TracedRNumber{<:Reactant.ReactantUInt},
320+
::typeof(RoundUp),
321+
)
322+
ax = unsigned(abs(x))
323+
q = signed(div(ax, y))
324+
has_rem = !iszero(rem(ax, y))
325+
result = flipsign(q, x)
326+
return ifelse(!signbit(x) & has_rem, result + one(result), result)
327+
end
328+
function Base.div(
329+
x::TracedRNumber{<:Reactant.ReactantSInt}, y::TracedRNumber{<:Reactant.ReactantUInt},
330+
::typeof(RoundFromZero),
331+
)
332+
ax = unsigned(abs(x))
333+
q = signed(div(ax, y))
334+
has_rem = !iszero(rem(ax, y))
335+
q_adj = q + ifelse(has_rem, one(q), zero(q))
336+
return flipsign(q_adj, x)
337+
end
302338
function Base.div(
303339
x::TracedRNumber{<:Reactant.ReactantUInt}, y::TracedRNumber{<:Reactant.ReactantSInt}
304340
)
@@ -344,13 +380,28 @@ function Base.div(
344380
)
345381
end
346382

383+
function Base.div(
384+
@nospecialize(lhs::TracedRNumber{T}),
385+
@nospecialize(rhs::TracedRNumber{T}),
386+
::typeof(RoundToZero),
387+
) where {T<:Integer}
388+
return @opcall divide(lhs, rhs)
389+
end
347390
function Base.div(
348391
@nospecialize(lhs::TracedRNumber{T}),
349392
@nospecialize(rhs::TracedRNumber{T}),
350393
::typeof(RoundDown),
351394
) where {T<:Integer}
352395
return @opcall divide(lhs, rhs)
353396
end
397+
function Base.div(
398+
@nospecialize(lhs::TracedRNumber{T}),
399+
@nospecialize(rhs::TracedRNumber{T}),
400+
::typeof(RoundUp),
401+
) where {T<:Integer}
402+
q = div(lhs, rhs) # truncation (RoundToZero)
403+
return q + (!iszero(rem(lhs, rhs)) & (signbit(lhs) == signbit(rhs)))
404+
end
354405
function Base.div(
355406
@nospecialize(lhs::TracedRNumber{T}),
356407
@nospecialize(rhs::TracedRNumber{T}),

0 commit comments

Comments
 (0)