Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion include/matx/kernels/fltflt.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ struct alignas(8) fltflt {
float lo;

// The default constructor does not initialize the components, so the value is indeterminate.
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ fltflt() = default;
__MATX_INLINE__ fltflt() = default;
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ constexpr explicit fltflt(double x)
: hi(static_cast<float>(x)), lo(static_cast<float>(x - static_cast<double>(hi))) {}
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ constexpr explicit fltflt(float x) : hi(x), lo(0.0f) {}
Expand Down Expand Up @@ -519,6 +519,60 @@ static __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ fltflt fltflt_sqrt(fltflt a
return fltflt_add(prod, yn);
}

// fltflt_sqrt_fast() approximates fltflt_sqrt() at roughly 1/5 the cost (~7 FLOPs vs
// ~35+). Rather than forming the residual a - yn^2 with a full fltflt subtraction, it
// uses one FMA: the multiply inside fmaf is exact and only a single rounding occurs, so
// a.hi - yn*yn is computed exactly, and folding in a.lo restores the input's low-order
// contribution. Accuracy is comparable to fltflt_sqrt() for most inputs, though not
// identical everywhere: e.g. for 1e9*pi + sqrt(2), fltflt_sqrt() agrees with the fp64
// baseline in every mantissa bit while this fast path agrees in the first 45 bits.
// This function may eventually become the default sqrt() implementation.
static __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ fltflt fltflt_sqrt_fast(fltflt a) {
  // Guard the zero input explicitly: rsqrt(0) would produce infinity.
  float inv_root = 0.0f;
  if (a.hi != 0.0f) {
    inv_root = detail::fltflt_rsqrt(a.hi);
  }
  // First sqrt approximation: a.hi * rsqrt(a.hi).
  const float approx = detail::fmul_rn(a.hi, inv_root);
  // Exact a.hi - approx^2 via a single FMA, then fold in the low word.
  const float sq_err = detail::fmaf_rn(-approx, approx, a.hi);
  const float resid = detail::fadd_rn(sq_err, a.lo);
  // One Newton-style correction step: delta = (inv_root / 2) * residual.
  const float half_inv = detail::fmul_rn(inv_root, 0.5f);
  const float delta = detail::fmul_rn(half_inv, resid);
  return fltflt_fast_two_sum(approx, delta);
}

// fltflt_norm3d() evaluates sqrt(dx^2 + dy^2 + dz^2) while deferring normalization as
// long as possible. The naive fltflt_mul + fltflt_fma + fltflt_fma + fltflt_sqrt_fast
// chain performs 5 normalizations (~50 ops); here we take the three exact hi-squares,
// fold every error term into one float, normalize once, and finish with
// fltflt_sqrt_fast (~39 ops). All three inputs are assumed to be normalized fltflt values.
static __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ fltflt fltflt_norm3d(fltflt dx, fltflt dy, fltflt dz) {
  // Exact squares of the hi components; each .lo carries the full rounding error.
  const fltflt sqx = fltflt_two_prod_fma(dx.hi, dx.hi);
  const fltflt sqy = fltflt_two_prod_fma(dy.hi, dy.hi);
  const fltflt sqz = fltflt_two_prod_fma(dz.hi, dz.hi);

  // Add the three .hi words via two_sum so no rounding error is lost.
  const fltflt partial = fltflt_two_sum(sqx.hi, sqy.hi);
  const fltflt total = fltflt_two_sum(partial.hi, sqz.hi);

  // Fold all eight low-order contributions into a single float:
  //   - two_sum rounding errors: partial.lo, total.lo
  //   - two_prod_fma error words: sqx.lo, sqy.lo, sqz.lo
  //   - squaring cross terms: 2*dx.hi*dx.lo, 2*dy.hi*dy.lo, 2*dz.hi*dz.lo
  // Every term is O(eps) relative to total.hi, so the sum stays within 8*eps*|total.hi|.
  // Possible overlap between this tail and total.hi can cost a little precision, but
  // the result should still be good for ~44 bits going into the sqrt.
  float tail = detail::fadd_rn(total.lo, partial.lo);
  tail = detail::fadd_rn(tail, sqx.lo);
  tail = detail::fadd_rn(tail, sqy.lo);
  tail = detail::fadd_rn(tail, sqz.lo);
  tail = detail::fmaf_rn(detail::fadd_rn(dx.hi, dx.hi), dx.lo, tail);
  tail = detail::fmaf_rn(detail::fadd_rn(dy.hi, dy.hi), dy.lo, tail);
  tail = detail::fmaf_rn(detail::fadd_rn(dz.hi, dz.hi), dz.lo, tail);

  // The single normalization before taking the square root.
  const fltflt sum_sq = fltflt_fast_two_sum(total.hi, tail);

  return fltflt_sqrt_fast(sum_sq);
}

// Scalar sqrt() overload so MatX unary operator dispatch can resolve sqrt of fltflt
// expressions. Delegates to the full-precision fltflt_sqrt() (not the _fast variant).
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ fltflt sqrt(fltflt a) { return fltflt_sqrt(a); }

Expand Down
68 changes: 39 additions & 29 deletions include/matx/kernels/sar_bp.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@

namespace matx {

static constexpr double SPEED_OF_LIGHT = 2.997291625155841e+08;
// SI-defined speed of light in m/s. The speed of light through the atmosphere will be roughly 0.03% slower
// than this, but it is assumed that any corrections for atmospheric propagation will be done elsewhere.
static constexpr double SPEED_OF_LIGHT = 299792458.0;

#ifdef __CUDACC__

Expand All @@ -75,8 +77,7 @@ __device__ inline fltflt ComputeRangeToPixelFloatFloat(fltflt apx, fltflt apy, f
const fltflt dx = px - apx;
const fltflt dy = py - apy;
const fltflt dz = pz - apz;
const fltflt dx2dy2 = fltflt_fma(dx, dx, dy * dy);
return fltflt_sqrt(fltflt_fma(dz, dz, dx2dy2));
return fltflt_norm3d(dx, dy, dz);
}

template <typename PlatPosType, SarBpComputeType ComputeType, typename strict_compute_t, typename loose_compute_t>
Expand Down Expand Up @@ -140,6 +141,9 @@ __global__ void SarBpFillPhaseLUT(cuda::std::complex<StorageType> *phase_lut, Co
template <SarBpComputeType ComputeType>
using strict_compute_param_t = typename std::conditional<ComputeType == SarBpComputeType::Double || ComputeType == SarBpComputeType::Mixed || ComputeType == SarBpComputeType::FloatFloat, double, float>::type;

template <SarBpComputeType ComputeType>
using strict_or_ff_compute_param_t = typename std::conditional<ComputeType == SarBpComputeType::FloatFloat, fltflt, strict_compute_param_t<ComputeType>>::type;

template <SarBpComputeType ComputeType>
using loose_compute_param_t = typename std::conditional<ComputeType == SarBpComputeType::Double, double, float>::type;

Expand All @@ -154,7 +158,7 @@ struct SarBpSharedMemory<SarBpComputeType::FloatFloat> {
template <SarBpComputeType ComputeType, typename OutImageType, typename InitialImageType, typename RangeProfilesType, typename PlatPosType, typename VoxLocType, typename RangeToMcpType, bool PhaseLUT>
__launch_bounds__(16*16)
__global__ void SarBp(OutImageType output, const InitialImageType initial_image, const __grid_constant__ RangeProfilesType range_profiles, const __grid_constant__ PlatPosType platform_positions, const __grid_constant__ VoxLocType voxel_locations, const __grid_constant__ RangeToMcpType range_to_mcp,
strict_compute_param_t<ComputeType> dr_inv,
strict_or_ff_compute_param_t<ComputeType> dr_inv,
strict_compute_param_t<ComputeType> phase_correction_partial,
cuda::std::complex<loose_compute_param_t<ComputeType>> *phase_lut)
{
Expand All @@ -180,6 +184,7 @@ __global__ void SarBp(OutImageType output, const InitialImageType initial_image,
using voxel_loc_t = typename VoxLocType::value_type;
using compute_t = typename std::conditional<ComputeType == SarBpComputeType::Double, double, float>::type;
using strict_compute_t = typename std::conditional<ComputeType == SarBpComputeType::Double || ComputeType == SarBpComputeType::Mixed, double, float>::type;
using strict_or_ff_compute_t = typename std::conditional<ComputeType == SarBpComputeType::FloatFloat, fltflt, strict_compute_t>::type;
using strict_complex_compute_t = cuda::std::complex<strict_compute_t>;
using loose_compute_t = typename std::conditional<ComputeType == SarBpComputeType::Double, double, float>::type;
using loose_complex_compute_t = cuda::std::complex<loose_compute_t>;
Expand Down Expand Up @@ -221,7 +226,7 @@ __global__ void SarBp(OutImageType output, const InitialImageType initial_image,
};

[[maybe_unused]] const loose_compute_t phase_correction_partial_loose = static_cast<loose_compute_t>(phase_correction_partial);
const auto get_reference_phase = [&phase_lut, &phase_correction_partial, &phase_correction_partial_loose](strict_compute_t diffR, index_t bin_floor_int, loose_compute_t w) -> loose_complex_compute_t {
const auto get_reference_phase = [&phase_lut, &phase_correction_partial, &phase_correction_partial_loose](strict_or_ff_compute_t diffR, index_t bin_floor_int, loose_compute_t w) -> loose_complex_compute_t {
if constexpr (PhaseLUT) {
const loose_complex_compute_t base_phase = phase_lut[bin_floor_int];
float incr_sinx, incr_cosx;
Expand All @@ -231,7 +236,8 @@ __global__ void SarBp(OutImageType output, const InitialImageType initial_image,
base_phase.real() * incr_sinx + base_phase.imag() * incr_cosx
};
} else {
strict_compute_t sinx, cosx;
// With PhaseLUT == false, strict_or_ff_compute_t is either float or double, so we can use sincos[f] directly.
strict_or_ff_compute_t sinx, cosx;
if constexpr (std::is_same_v<strict_compute_t, double>) {
::sincos(phase_correction_partial * diffR, &sinx, &cosx);
} else {
Expand All @@ -243,16 +249,11 @@ __global__ void SarBp(OutImageType output, const InitialImageType initial_image,
}
};

[[maybe_unused]] fltflt dr_inv_fltflt{};
if constexpr (ComputeType == SarBpComputeType::FloatFloat) {
// We could perform this in only a single thread and store the result to shared memory, but
// in initial tests that did not improve performance.
dr_inv_fltflt = static_cast<fltflt>(dr_inv);
}
[[maybe_unused]] const int tid = threadIdx.x + threadIdx.y * blockDim.x;

loose_complex_compute_t accum{};
const loose_compute_t bin_offset = static_cast<loose_compute_t>(0.5) * static_cast<loose_compute_t>(num_range_bins-1);

const loose_compute_t max_bin_f = static_cast<loose_compute_t>(num_range_bins) - static_cast<loose_compute_t>(2.0);
const int num_pulse_blocks = (num_pulses + PULSE_BLOCK_SIZE - 1) / PULSE_BLOCK_SIZE;
for (int block = 0; block < num_pulse_blocks; ++block) {
Expand All @@ -272,7 +273,9 @@ __global__ void SarBp(OutImageType output, const InitialImageType initial_image,
sh_mem.ant_pos[ip][1] = static_cast<fltflt>(platform_positions.operator()(p, 1));
sh_mem.ant_pos[ip][2] = static_cast<fltflt>(platform_positions.operator()(p, 2));
}
sh_mem.ant_pos[ip][3] = static_cast<fltflt>(r_to_mcp(p));
const fltflt rtm = static_cast<fltflt>(r_to_mcp(p));
const fltflt neg_rtm = fltflt{-rtm.hi, -rtm.lo};
sh_mem.ant_pos[ip][3] = fltflt_fma(neg_rtm, dr_inv, bin_offset);
}
__syncthreads();
if (! is_valid) {
Expand All @@ -282,27 +285,34 @@ __global__ void SarBp(OutImageType output, const InitialImageType initial_image,
#pragma unroll 4
for (index_t ip = 0; ip < num_pulses_in_block; ++ip) {
const int p = block * PULSE_BLOCK_SIZE + ip;
[[maybe_unused]] strict_compute_t diffR{};
loose_compute_t bin;
strict_or_ff_compute_t diffR;
loose_compute_t w;
index_t bin_floor_int;
if constexpr (ComputeType == SarBpComputeType::FloatFloat) {
const fltflt diffR_ff = ComputeRangeToPixelFloatFloat(
sh_mem.ant_pos[ip][0], sh_mem.ant_pos[ip][1], sh_mem.ant_pos[ip][2], px, py, pz) - sh_mem.ant_pos[ip][3];
bin = static_cast<loose_compute_t>(diffR_ff * dr_inv_fltflt) + bin_offset;
// diffR is otherwise unused for FloatFloat and thus not set
// This is just the distance to the pixel rather than the differential range to the MCP.
// We use diffR because otherwise we would need to initialize diffR to avoid a
// compiler warning about uninitialized use of diffR.
diffR = ComputeRangeToPixelFloatFloat(
sh_mem.ant_pos[ip][0], sh_mem.ant_pos[ip][1], sh_mem.ant_pos[ip][2], px, py, pz);
// sh_mem.ant_pos[ip][3] is -mcp * dr_inv + bin_offset, so here we compute
// dist * dr_inv + (-mcp * dr_inv + bin_offset) = (dist - mcp) * dr_inv + bin_offset
const fltflt bin = fltflt_fma(diffR, dr_inv, sh_mem.ant_pos[ip][3]);
float floor_hi = ::floorf(bin.hi);
float frac = (bin.hi - floor_hi) + bin.lo;
// bin.lo may push bin over a boundary, in which case floor and frac are incorrect.
// Compute an adjustment based on whether or not the fractional part is outside (0.0, 1.0).
const float adjust = ::floorf(frac); // -1, 0, or 1
bin_floor_int = static_cast<index_t>(floor_hi + adjust);
w = frac - adjust;
} else {
diffR = ComputeRangeToPixel<PlatPosType, ComputeType, strict_compute_t, loose_compute_t>(
platform_positions, p, px, py, pz) - static_cast<strict_compute_t>(r_to_mcp(p));
bin = static_cast<loose_compute_t>(diffR * dr_inv) + bin_offset;
const strict_compute_t bin = diffR * dr_inv + bin_offset;
const strict_compute_t bin_floor = ::floor(bin);
w = static_cast<loose_compute_t>(bin - bin_floor);
bin_floor_int = static_cast<index_t>(bin_floor);
}
if (bin >= static_cast<loose_compute_t>(0.0) && bin < max_bin_f) {
loose_compute_t bin_floor;
if constexpr (std::is_same_v<loose_compute_t, float>) {
bin_floor = ::floorf(bin);
} else {
bin_floor = ::floor(bin);
}
const loose_compute_t w = bin - bin_floor;
const index_t bin_floor_int = static_cast<index_t>(bin_floor);
if (bin_floor_int >= 0 && bin_floor_int < static_cast<index_t>(num_range_bins-1)) {

range_profiles_t sample_lo, sample_hi;

Expand Down
2 changes: 1 addition & 1 deletion include/matx/transforms/sar_bp.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ inline void sar_bp_impl(OutImageType &out, const InitialImageType &initial_image
cuda::std::complex<float> *phase_lut = static_cast<cuda::std::complex<float> *>(workspace);
SarBpFillPhaseLUT<double, float><<<lut_grid, lut_block, 0, stream>>>(phase_lut, params.center_frequency, params.del_r, range_profiles.Size(1));
SarBp<SarBpComputeType::FloatFloat, OutImageType, InitialImageType, RangeProfilesType, PlatPosType, VoxLocType, RangeToMcpType, PhaseLUT><<<grid, block, 0, stream>>>(
out, initial_image, range_profiles, platform_positions, voxel_locations, range_to_mcp, dr_inv, phase_correction_partial, phase_lut);
out, initial_image, range_profiles, platform_positions, voxel_locations, range_to_mcp, static_cast<fltflt>(dr_inv), phase_correction_partial, phase_lut);
} else {
cuda::std::complex<float> *phase_lut = static_cast<cuda::std::complex<float> *>(workspace);
SarBpFillPhaseLUT<float, float><<<lut_grid, lut_block, 0, stream>>>(phase_lut, static_cast<float>(params.center_frequency), static_cast<float>(params.del_r), range_profiles.Size(1));
Expand Down
91 changes: 87 additions & 4 deletions test/00_misc/FloatFloatTests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,13 @@ struct FltFltSqrt {
}
};

// Test functor: applies fltflt_sqrt_fast and converts the fltflt result to double
// so it can be compared in mantissa bits against an fp64 reference.
struct FltFltSqrtFast {
  __MATX_HOST__ __MATX_DEVICE__ double operator()(fltflt a) const
  {
    // Single cast — the original doubled static_cast<double> was redundant.
    return static_cast<double>(fltflt_sqrt_fast(a));
  }
};

struct FltFltAbs {
__MATX_HOST__ __MATX_DEVICE__ double operator()(fltflt a) const
{
Expand Down Expand Up @@ -500,19 +507,95 @@ TYPED_TEST(FltFltExecutorTests, Division) {
}

TYPED_TEST(FltFltExecutorTests, SquareRoot) {
  auto input = make_tensor<fltflt>({});
  auto sqrt_result = make_tensor<double>({});

  // Runs fltflt_sqrt on `val` and checks the result against the fp64 reference,
  // requiring at least `min_bits` matching mantissa bits. Factored out because the
  // assign/run/sync/expect sequence was duplicated for every test value.
  auto check_sqrt = [&](double val, int min_bits) {
    (input = static_cast<fltflt>(val)).run(this->exec);
    (sqrt_result = matx::apply(FltFltSqrt{}, input)).run(this->exec);
    this->exec.sync();
    EXPECT_GE(numMatchingMantissaBits(sqrt_result(), std::sqrt(val)), min_bits)
        << "input value: " << val;
  };

  // Sanity check: plain fp32 sqrt cannot match fp64 beyond its 24-bit mantissa, so
  // the >= 44-bit requirements below genuinely exercise the extended precision.
  const double pi_sqrt_ref_f64 = std::sqrt(std::numbers::pi);
  const float pi_sqrt_ref_f32 = std::sqrt(std::numbers::pi_v<float>);
  EXPECT_LE(numMatchingMantissaBits(pi_sqrt_ref_f32, pi_sqrt_ref_f64), 24);

  // Test case: pi (moderate value)
  check_sqrt(std::numbers::pi, 44);

  // Test case: zero must produce exactly zero
  (input = static_cast<fltflt>(0.0)).run(this->exec);
  (sqrt_result = matx::apply(FltFltSqrt{}, input)).run(this->exec);
  this->exec.sync();
  EXPECT_EQ(sqrt_result(), 0.0);

  // Test case: small number (1.23e-7)
  check_sqrt(1.23e-7, 44);

  // Test case: large number (e * 1e10, 52 active mantissa bits)
  check_sqrt(std::numbers::e * 1e10, 44);

  // Test case: 1e9*pi + sqrt(2) (52 active mantissa bits, ~3.1e9)
  check_sqrt(1e9 * std::numbers::pi + std::sqrt(2.0), 44);
}

TYPED_TEST(FltFltExecutorTests, SquareRootFast) {
  auto input = make_tensor<fltflt>({});
  auto sqrt_result = make_tensor<double>({});

  // Runs fltflt_sqrt_fast on `val` and checks the result against the fp64 reference,
  // requiring at least `min_bits` matching mantissa bits. Factored out because the
  // assign/run/sync/expect sequence was duplicated for every test value.
  auto check_sqrt_fast = [&](double val, int min_bits) {
    (input = static_cast<fltflt>(val)).run(this->exec);
    (sqrt_result = matx::apply(FltFltSqrtFast{}, input)).run(this->exec);
    this->exec.sync();
    EXPECT_GE(numMatchingMantissaBits(sqrt_result(), std::sqrt(val)), min_bits)
        << "input value: " << val;
  };

  // Sanity check: plain fp32 sqrt cannot match fp64 beyond its 24-bit mantissa, so
  // the >= 44-bit requirements below genuinely exercise the extended precision.
  const double pi_sqrt_ref_f64 = std::sqrt(std::numbers::pi);
  const float pi_sqrt_ref_f32 = std::sqrt(std::numbers::pi_v<float>);
  EXPECT_LE(numMatchingMantissaBits(pi_sqrt_ref_f32, pi_sqrt_ref_f64), 24);

  // Test case: pi (moderate value)
  check_sqrt_fast(std::numbers::pi, 44);

  // Test case: zero must produce exactly zero
  (input = static_cast<fltflt>(0.0)).run(this->exec);
  (sqrt_result = matx::apply(FltFltSqrtFast{}, input)).run(this->exec);
  this->exec.sync();
  EXPECT_EQ(sqrt_result(), 0.0);

  // Test case: small number (1.23e-7)
  check_sqrt_fast(1.23e-7, 44);

  // Test case: large number (e * 1e10, 52 active mantissa bits)
  check_sqrt_fast(std::numbers::e * 1e10, 44);

  // Test case: 1e9*pi + sqrt(2) (52 active mantissa bits, ~3.1e9)
  check_sqrt_fast(1e9 * std::numbers::pi + std::sqrt(2.0), 44);
}

TYPED_TEST(FltFltExecutorTests, MatXSqrtOperator) {
Expand Down
Loading