Skip to content

Commit be31e4b

Browse files
div(::TracedRNumber{Int}, ::TracedRNumber{Int}, ::RoundingMode) (#2787)
* div(::TracedRNumber{Int}, ::TracedRNumber{Int}, ::RoundingMode) * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 785135a commit be31e4b

1 file changed

Lines changed: 55 additions & 0 deletions

File tree

src/TracedRNumber.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,9 +296,49 @@ Base.flipsign(x::TracedRNumber, y::TracedRNumber) = ifelse(y < 0, -x, x)
296296

297297
function Base.div(
298298
x::TracedRNumber{<:Reactant.ReactantSInt}, y::TracedRNumber{<:Reactant.ReactantUInt}
299+
)
300+
return div(x, y, RoundDown)
301+
end
302+
function Base.div(
303+
x::TracedRNumber{<:Reactant.ReactantSInt},
304+
y::TracedRNumber{<:Reactant.ReactantUInt},
305+
::typeof(RoundToZero),
299306
)
300307
return flipsign(signed(div(unsigned(abs(x)), y)), x)
301308
end
309+
function Base.div(
310+
x::TracedRNumber{<:Reactant.ReactantSInt},
311+
y::TracedRNumber{<:Reactant.ReactantUInt},
312+
::typeof(RoundDown),
313+
)
314+
ax = unsigned(abs(x))
315+
q = signed(div(ax, y))
316+
has_rem = !iszero(rem(ax, y))
317+
result = flipsign(q, x)
318+
return ifelse(signbit(x) & has_rem, result - one(result), result)
319+
end
320+
function Base.div(
321+
x::TracedRNumber{<:Reactant.ReactantSInt},
322+
y::TracedRNumber{<:Reactant.ReactantUInt},
323+
::typeof(RoundUp),
324+
)
325+
ax = unsigned(abs(x))
326+
q = signed(div(ax, y))
327+
has_rem = !iszero(rem(ax, y))
328+
result = flipsign(q, x)
329+
return ifelse(!signbit(x) & has_rem, result + one(result), result)
330+
end
331+
function Base.div(
332+
x::TracedRNumber{<:Reactant.ReactantSInt},
333+
y::TracedRNumber{<:Reactant.ReactantUInt},
334+
::typeof(RoundFromZero),
335+
)
336+
ax = unsigned(abs(x))
337+
q = signed(div(ax, y))
338+
has_rem = !iszero(rem(ax, y))
339+
q_adj = q + ifelse(has_rem, one(q), zero(q))
340+
return flipsign(q_adj, x)
341+
end
302342
function Base.div(
303343
x::TracedRNumber{<:Reactant.ReactantUInt}, y::TracedRNumber{<:Reactant.ReactantSInt}
304344
)
@@ -344,13 +384,28 @@ function Base.div(
344384
)
345385
end
346386

387+
function Base.div(
388+
@nospecialize(lhs::TracedRNumber{T}),
389+
@nospecialize(rhs::TracedRNumber{T}),
390+
::typeof(RoundToZero),
391+
) where {T<:Integer}
392+
return @opcall divide(lhs, rhs)
393+
end
347394
function Base.div(
348395
@nospecialize(lhs::TracedRNumber{T}),
349396
@nospecialize(rhs::TracedRNumber{T}),
350397
::typeof(RoundDown),
351398
) where {T<:Integer}
352399
return @opcall divide(lhs, rhs)
353400
end
401+
function Base.div(
402+
@nospecialize(lhs::TracedRNumber{T}),
403+
@nospecialize(rhs::TracedRNumber{T}),
404+
::typeof(RoundUp),
405+
) where {T<:Integer}
406+
q = div(lhs, rhs) # truncation (RoundToZero)
407+
return q + (!iszero(rem(lhs, rhs)) & (signbit(lhs) == signbit(rhs)))
408+
end
354409
function Base.div(
355410
@nospecialize(lhs::TracedRNumber{T}),
356411
@nospecialize(rhs::TracedRNumber{T}),

0 commit comments

Comments
 (0)