Skip to content

Commit 90aa1da

Browse files
committed
reactant ops resize_linear!
1 parent 4a53e68 commit 90aa1da

1 file changed

Lines changed: 70 additions & 0 deletions

File tree

src/Ops.jl

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4522,4 +4522,74 @@ end
45224522
return diff .- log(reduced_exp_diff; location)
45234523
end
45244524

4525+
"""
4526+
resize_linear!(dst::AbstractArray{T,N}, src::AbstractArray{T,N}, output_shape::NTuple{N,Int}) -> AbstractArray
4527+
4528+
The interpolation follows JAX's `jax.image.resize` with `ResizeMethod.LINEAR`: for each
4529+
spatial dimension a weight matrix is computed via a triangle kernel and applied via
4530+
matrix multiplication. The weight matrix depends only on the input/output sizes (which
4531+
are static during JIT compilation), so it becomes an XLA constant with no runtime
4532+
overhead.
4533+
4534+
# Examples
4535+
```julia
4536+
x = rand(Float32, 4, 4)
4537+
y = resize(x, (8, 8)) # 2× upsample
4538+
z = resize(x, (2, 3)) # mixed downsample / upsample
4539+
```
4540+
"""
4541+
@noinline function resize_linear!(
4542+
dst::AbstractArray{T,N}, src::AbstractArray{T,N}
4543+
) where {T,N}
4544+
output_shape = size(dst)
4545+
Tf = float(T)
4546+
x = Tf.(src)
4547+
for d in 1:N
4548+
size(x, d) == output_shape[d] && continue
4549+
m = size(x, d)
4550+
n = output_shape[d]
4551+
# Weight matrix: (m, n). Computed at trace time from static sizes.
4552+
mat = _resize_weight_mat(m, n)
4553+
W = Tf.(constant(mat))
4554+
# Move dim d to position 1, flatten remaining dims, matmul, restore.
4555+
perm = [d; setdiff(1:ndims(x), d)]
4556+
xt = permutedims(x, perm) # (m, d2, d3, ...)
4557+
rest = size(xt)[2:end]
4558+
xt = reshape(xt, m, prod(rest)) # (m, rest_flat)
4559+
W′ = transpose(W, collect(Int64, ndims(W):-1:1))
4560+
xt = dot_general(W′, xt; contracting_dimensions=([2], [1])) # (n, rest_flat)
4561+
xt = reshape(xt, n, rest...)
4562+
x = permutedims(xt, invperm(perm))
4563+
end
4564+
return copyto!(dst, x)
4565+
end
4566+
4567+
function _resize_weight_mat(input_size::Int, output_size::Int)
4568+
inv_scale = input_size / output_size
4569+
# Center of each output pixel in 0-indexed input-pixel coordinates
4570+
sample_f = ((0:(output_size - 1)) .+ 0.5) .* inv_scale .- 0.5 # (output_size,)
4571+
input_pos = 0.0:(input_size - 1) # (input_size,)
4572+
# x[i, j] = |sample_f[j] - (i-1)|; broadcasting over both dims
4573+
x = Base.abs.(sample_f' .- input_pos) # (input_size, output_size)
4574+
# Triangle kernel: max(0, 1 - |x|)
4575+
weights = max.(0.0, 1.0 .- x)
4576+
# Normalize each column
4577+
col_sums = sum(weights; dims=1)
4578+
weights = weights ./ col_sums
4579+
# Zero out columns where the sample is outside the valid input range
4580+
valid = (sample_f .>= -0.5) .& (sample_f .<= input_size - 0.5)
4581+
weights = weights .* valid'
4582+
return weights # (input_size, output_size), Float64
4583+
end
4584+
4585+
"""
4586+
resize(x::AbstractArray{T,N}, output_shape::NTuple{N,Int}) -> AbstractArray
4587+
4588+
Resize array `x` to `output_shape` using linear interpolation. See [`resize_linear!`](@ref).
4589+
"""
4590+
function resize(x::AbstractArray{T,N}, output_shape::NTuple{N,Int}) where {T,N}
4591+
dst = similar(x, float(T), output_shape)
4592+
return resize_linear!(dst, float(T).(x))
4593+
end
4594+
45254595
end # module Ops

0 commit comments

Comments
 (0)