@@ -4522,4 +4522,74 @@ end
45224522 return diff .- log (reduced_exp_diff; location)
45234523end
45244524
4525+ """
4526+ resize_linear!(dst::AbstractArray{T,N}, src::AbstractArray{T,N}, output_shape::NTuple{N,Int}) -> AbstractArray
4527+
4528+ The interpolation follows JAX's `jax.image.resize` with `ResizeMethod.LINEAR`: for each
4529+ spatial dimension a weight matrix is computed via a triangle kernel and applied via
4530+ matrix multiplication. The weight matrix depends only on the input/output sizes (which
4531+ are static during JIT compilation), so it becomes an XLA constant with no runtime
4532+ overhead.
4533+
4534+ # Examples
4535+ ```julia
4536+ x = rand(Float32, 4, 4)
4537+ y = resize(x, (8, 8)) # 2× upsample
4538+ z = resize(x, (2, 3)) # mixed downsample / upsample
4539+ ```
4540+ """
4541+ @noinline function resize_linear! (
4542+ dst:: AbstractArray{T,N} , src:: AbstractArray{T,N}
4543+ ) where {T,N}
4544+ output_shape = size (dst)
4545+ Tf = float (T)
4546+ x = Tf .(src)
4547+ for d in 1 : N
4548+ size (x, d) == output_shape[d] && continue
4549+ m = size (x, d)
4550+ n = output_shape[d]
4551+ # Weight matrix: (m, n). Computed at trace time from static sizes.
4552+ mat = _resize_weight_mat (m, n)
4553+ W = Tf .(constant (mat))
4554+ # Move dim d to position 1, flatten remaining dims, matmul, restore.
4555+ perm = [d; setdiff (1 : ndims (x), d)]
4556+ xt = permutedims (x, perm) # (m, d2, d3, ...)
4557+ rest = size (xt)[2 : end ]
4558+ xt = reshape (xt, m, prod (rest)) # (m, rest_flat)
4559+ W′ = transpose (W, collect (Int64, ndims (W): - 1 : 1 ))
4560+ xt = dot_general (W′, xt; contracting_dimensions= ([2 ], [1 ])) # (n, rest_flat)
4561+ xt = reshape (xt, n, rest... )
4562+ x = permutedims (xt, invperm (perm))
4563+ end
4564+ return copyto! (dst, x)
4565+ end
4566+
4567+ function _resize_weight_mat (input_size:: Int , output_size:: Int )
4568+ inv_scale = input_size / output_size
4569+ # Center of each output pixel in 0-indexed input-pixel coordinates
4570+ sample_f = ((0 : (output_size - 1 )) .+ 0.5 ) .* inv_scale .- 0.5 # (output_size,)
4571+ input_pos = 0.0 : (input_size - 1 ) # (input_size,)
4572+ # x[i, j] = |sample_f[j] - (i-1)|; broadcasting over both dims
4573+ x = Base. abs .(sample_f' .- input_pos) # (input_size, output_size)
4574+ # Triangle kernel: max(0, 1 - |x|)
4575+ weights = max .(0.0 , 1.0 .- x)
4576+ # Normalize each column
4577+ col_sums = sum (weights; dims= 1 )
4578+ weights = weights ./ col_sums
4579+ # Zero out columns where the sample is outside the valid input range
4580+ valid = (sample_f .>= - 0.5 ) .& (sample_f .<= input_size - 0.5 )
4581+ weights = weights .* valid'
4582+ return weights # (input_size, output_size), Float64
4583+ end
4584+
4585+ """
4586+ resize(x::AbstractArray{T,N}, output_shape::NTuple{N,Int}) -> AbstractArray
4587+
4588+ Resize array `x` to `output_shape` using linear interpolation. See [`resize_linear!`](@ref).
4589+ """
4590+ function resize (x:: AbstractArray{T,N} , output_shape:: NTuple{N,Int} ) where {T,N}
4591+ dst = similar (x, float (T), output_shape)
4592+ return resize_linear! (dst, float (T).(x))
4593+ end
4594+
45254595end # module Ops
0 commit comments