diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index f3cfd5234..709d1c30b 100644 Binary files a/ptx/lib/zluda_ptx_impl.bc and b/ptx/lib/zluda_ptx_impl.bc differ diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp index 92f60e476..ce6d1e159 100644 --- a/ptx/lib/zluda_ptx_impl.cpp +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -14,12 +14,12 @@ #define CONSTANT_SPACE __attribute__((address_space(4))) typedef _Float16 half16 __attribute__((ext_vector_type(16))); -typedef float float8 __attribute__((ext_vector_type(8))); +typedef float float8 __attribute__((ext_vector_type(8))); #define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME #define FUNC_CALL(NAME) __zluda_ptx_impl_##NAME #define ATTR(NAME) __ZLUDA_PTX_IMPL_ATTRIBUTE_##NAME -#define DECLARE_ATTR(TYPE, NAME) \ +#define DECLARE_ATTR(TYPE, NAME) \ extern "C" __attribute__((constant)) CONSTANT_SPACE TYPE ATTR(NAME) \ __device__ @@ -189,8 +189,7 @@ extern "C" BAR_RED_IMPL(and); BAR_RED_IMPL(or); - -typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2))); + typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2))); // shfl.sync opts consists of two values, the warp end ID and the subsection mask. // @@ -219,12 +218,12 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2))); idx = self; \ } \ int32_t output = __builtin_amdgcn_ds_bpermute(idx << 2, (int32_t)input); \ - return {(uint32_t)output, uint32_t(!out_of_bounds)}; \ + return {(uint32_t)output, uint32_t(!out_of_bounds)}; \ } \ \ uint32_t FUNC(shfl_sync_##mode##_b32)(uint32_t input, int32_t delta, uint32_t opts, uint32_t membermask) \ { \ - return __zluda_ptx_impl_shfl_sync_##mode##_b32_pred(input, delta, opts, membermask).x; \ + return __zluda_ptx_impl_shfl_sync_##mode##_b32_pred(input, delta, opts, membermask).x; \ } // We are using the HIP __shfl intrinsics to implement these, rather than the __shfl_sync @@ -493,10 +492,10 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2))); } float FUNC(div_f32_part2)(float x, float y, - float fma_4, - float fma_1, - float fma_3, - uint8_t numerator_scaled_flag) + float fma_4, + float fma_1, + float fma_3, + uint8_t numerator_scaled_flag) { return div_f32_part2(x, y, {fma_4, fma_1, fma_3, numerator_scaled_flag}); } @@ -596,35 +595,34 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2))); REDUX_SYNC_IMPL(min); REDUX_SYNC_IMPL(max); - - __device__ inline static uint32_t load_single_matrix(void SHARED_SPACE * lds_address, uint32_t warp_offset) + __device__ inline static uint32_t load_single_matrix(void SHARED_SPACE *lds_address, uint32_t warp_offset) { uint32_t laneid = __zluda_ptx_impl_sreg_laneid(); int32_t row_address = __builtin_amdgcn_ds_bpermute((int32_t)(warp_offset + (laneid / 4U)) << 2U, (int32_t)lds_address); uint32_t matrix_cell_address = (uint32_t)row_address + ((laneid % 4) * 4); - return *((uint32_t SHARED_SPACE*)matrix_cell_address); + return *((uint32_t SHARED_SPACE *)matrix_cell_address); } - __device__ inline static uint32_t load_single_matrix_trans(void SHARED_SPACE * lds_address, uint32_t warp_offset) + __device__ inline static uint32_t load_single_matrix_trans(void SHARED_SPACE *lds_address, uint32_t warp_offset) { uint32_t laneid = __zluda_ptx_impl_sreg_laneid(); int32_t row_address_lo = __builtin_amdgcn_ds_bpermute((int32_t)(warp_offset + ((laneid % 4U) * 2)) << 2U, (int32_t)lds_address); uint32_t address_lo = (uint32_t)row_address_lo + ((laneid / 4) * 2); - uint16_t lo = *((uint16_t SHARED_SPACE*)address_lo); + uint16_t lo = *((uint16_t SHARED_SPACE *)address_lo); int32_t row_address_hi = __builtin_amdgcn_ds_bpermute((int32_t)(warp_offset + ((laneid % 4U) * 2) + 1) << 2U, (int32_t)lds_address); uint32_t address_hi = (uint32_t)row_address_hi + ((laneid / 4) * 2); - uint16_t hi = *((uint16_t SHARED_SPACE*)address_hi); - return std::bit_cast(ushort2::Native_vec_ { lo, hi }); + uint16_t hi = *((uint16_t SHARED_SPACE *)address_hi); + return std::bit_cast(ushort2::Native_vec_{lo, hi}); } - uint2::Native_vec_ FUNC(ldmatrix_m8n8_x2_b16)(void SHARED_SPACE * address) + uint2::Native_vec_ FUNC(ldmatrix_m8n8_x2_b16)(void SHARED_SPACE *address) { uint32_t x0 = load_single_matrix(address, 0); uint32_t x1 = load_single_matrix(address, 8); return uint2::Native_vec_{x0, x1}; } - uint4::Native_vec_ FUNC(ldmatrix_m8n8_x4_b16)(void SHARED_SPACE * address) + uint4::Native_vec_ FUNC(ldmatrix_m8n8_x4_b16)(void SHARED_SPACE *address) { uint32_t x0 = load_single_matrix(address, 0); uint32_t x1 = load_single_matrix(address, 8); @@ -633,7 +631,7 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2))); return uint4::Native_vec_{x0, x1, x2, x3}; } - uint4::Native_vec_ FUNC(ldmatrix_m8n8_x4_trans_b16)(void SHARED_SPACE * address) + uint4::Native_vec_ FUNC(ldmatrix_m8n8_x4_trans_b16)(void SHARED_SPACE *address) { uint32_t x0 = load_single_matrix_trans(address, 0); uint32_t x1 = load_single_matrix_trans(address, 8); @@ -642,84 +640,103 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2))); return uint4::Native_vec_{x0, x1, x2, x3}; } - static inline __device__ _Float16 top16_as_fp16(uint32_t value) { + static inline __device__ _Float16 top16_as_fp16(uint32_t value) + { uint16_t half_bits = static_cast((value >> 16) & 0xFFFF); - return *reinterpret_cast<_Float16*>(&half_bits); + return *reinterpret_cast<_Float16 *>(&half_bits); } - static inline __device__ _Float16 bottom16_as_fp16(uint32_t value) { + static inline __device__ _Float16 bottom16_as_fp16(uint32_t value) + { uint16_t half_bits = static_cast(value & 0xFFFF); - return *reinterpret_cast<_Float16*>(&half_bits); + return *reinterpret_cast<_Float16 *>(&half_bits); } - static inline __device__ float bpermute_lane(int lane, float x) { + static inline __device__ float bpermute_lane(int lane, float x) + { return __hip_ds_bpermutef(4 * lane, x); } - static inline __device__ uint32_t bpermute_lane(int lane, uint32_t x) { + static inline __device__ uint32_t bpermute_lane(int lane, uint32_t x) + { return __hip_ds_bpermute(4 * lane, x); } - static __device__ half16 shuffle_a(uint4::Native_vec_ a_reg) { + static __device__ half16 shuffle_a(uint4::Native_vec_ a_reg) + { const unsigned lIdx = threadIdx.x; const int lane = lIdx % 16; // Lanes 0-15 (the other 16 lanes are a duplicate in w32 mode) half16 aFrag; - for (int vGPR = 0; vGPR < 8; ++vGPR) { - int cudaChunk = (vGPR / 4) * 2; // will be 0 or 2 - int cudaTID = (vGPR % 4 + lane * 4) % 32; + for (int vGPR = 0; vGPR < 8; ++vGPR) + { + int cudaChunk = (vGPR / 4) * 2; // will be 0 or 2 + int cudaTID = (vGPR % 4 + lane * 4) % 32; uint32_t reg0, reg1; // Select the two consecutive elements from a_reg: - if (cudaChunk == 0) { + if (cudaChunk == 0) + { reg0 = a_reg.x; reg1 = a_reg.y; - } else { // cudaChunk==2 + } + else + { // cudaChunk==2 reg0 = a_reg.z; reg1 = a_reg.w; } uint32_t a_tmp0 = bpermute_lane(cudaTID, reg0); uint32_t a_tmp1 = bpermute_lane(cudaTID, reg1); uint32_t a_Frag_reg = (lane < 8) ? a_tmp0 : a_tmp1; - aFrag[2 * vGPR] = bottom16_as_fp16(a_Frag_reg); + aFrag[2 * vGPR] = bottom16_as_fp16(a_Frag_reg); aFrag[2 * vGPR + 1] = top16_as_fp16(a_Frag_reg); } return aFrag; } - static __device__ half16 shuffle_b(uint2::Native_vec_ b_reg) { + static __device__ half16 shuffle_b(uint2::Native_vec_ b_reg) + { const unsigned lIdx = threadIdx.x; const int lane = lIdx % 16; half16 bFrag; - for (int vGPR = 0; vGPR < 8; ++vGPR) { - int cudaChunk = vGPR / 4; // will be 0 or 1 - int cudaTID = vGPR % 4 + (lane * 4) % 64; + for (int vGPR = 0; vGPR < 8; ++vGPR) + { + int cudaChunk = vGPR / 4; // will be 0 or 1 + int cudaTID = vGPR % 4 + (lane * 4) % 64; uint32_t reg = (cudaChunk == 0) ? b_reg.x : b_reg.y; uint32_t b_Frag_reg = bpermute_lane(cudaTID, reg); - if (lane < 8) { - bFrag[2 * vGPR] = bottom16_as_fp16(b_Frag_reg); + if (lane < 8) + { + bFrag[2 * vGPR] = bottom16_as_fp16(b_Frag_reg); bFrag[2 * vGPR + 1] = top16_as_fp16(b_Frag_reg); - } else { - bFrag[2 * vGPR] = 0.0f; + } + else + { + bFrag[2 * vGPR] = 0.0f; bFrag[2 * vGPR + 1] = 0.0f; } } return bFrag; } - static __device__ float8 shuffle_c(float4::Native_vec_ c_reg) { + static __device__ float8 shuffle_c(float4::Native_vec_ c_reg) + { const int lIdx = (int)threadIdx.x; float8 cFrag; // Loop over the eight vector GPRs. - for (int vGPR = 0; vGPR < 8; ++vGPR) { - int cudaChunk = (vGPR / 4) * 2; // will be 0 or 2: selects which pair of components to use. - int lIdx8 = (lIdx < 8) ? lIdx : lIdx - 8; - int cudaTID = (vGPR % 4) * 8 + lIdx8 / 2; + for (int vGPR = 0; vGPR < 8; ++vGPR) + { + int cudaChunk = (vGPR / 4) * 2; // will be 0 or 2: selects which pair of components to use. + int lIdx8 = (lIdx < 8) ? lIdx : lIdx - 8; + int cudaTID = (vGPR % 4) * 8 + lIdx8 / 2; float ctmp0, ctmp1; - if (cudaChunk == 0) { + if (cudaChunk == 0) + { ctmp0 = bpermute_lane(cudaTID, c_reg.x); ctmp1 = bpermute_lane(cudaTID, c_reg.y); - } else { // cudaChunk == 2 + } + else + { // cudaChunk == 2 ctmp0 = bpermute_lane(cudaTID, c_reg.z); ctmp1 = bpermute_lane(cudaTID, c_reg.w); } @@ -734,41 +751,202 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2))); return cFrag; } - static inline __device__ float4::Native_vec_ shuffle_d(float8 dFrag) { - const int lIdx = (int)threadIdx.x; - float4::Native_vec_ d_out; - - for (int cChunk = 0; cChunk < 4; ++cChunk) { - int r_vGPR = (cChunk / 2) * 4; - int add8 = (lIdx & 0x4) ? 8 : 0; - int r_lIdx = (cChunk % 2) + (lIdx % 8) * 2 + add8; - float d_tmp0 = bpermute_lane(r_lIdx, dFrag[r_vGPR]); - float d_tmp1 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 1]); - float d_tmp2 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 2]); - float d_tmp3 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 3]); - float val; - if (lIdx < 8) { - val = d_tmp0; - } else if (lIdx < 16) { - val = d_tmp1; - } else if (lIdx < 24) { - val = d_tmp2; - } else { - val = d_tmp3; - } - if (cChunk == 0) d_out.x = val; - else if (cChunk == 1) d_out.y = val; - else if (cChunk == 2) d_out.z = val; - else d_out.w = val; + static inline __device__ std::pair select_registers_for_permlane( + uint32_t laneid, + uint32_t lower_regs[4], + uint32_t upper_regs[4], + uint32_t rotation) + { + uint32_t rotated_half_lane_id = (laneid + 16 - rotation) % 16; + bool rotated_half_lower_half = rotated_half_lane_id < 8; + if (rotated_half_lower_half) + { + uint32_t lower = lower_regs[rotated_half_lane_id % 2]; + uint32_t upper = upper_regs[rotated_half_lane_id % 2]; + return std::make_pair(lower, upper); + } + else + { + uint32_t lower = lower_regs[2 + (rotated_half_lane_id % 2)]; + uint32_t upper = upper_regs[2 + (rotated_half_lane_id % 2)]; + return std::make_pair(lower, upper); } - return d_out; } - float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_f16_f16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) { + static __device__ float8 shuffle_d8(float8 registers) + { + uint32_t laneid = __zluda_ptx_impl_sreg_laneid(); + bool low_half = laneid < 16; + // We start with: + // T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15 T16 T17 T18 T19 T20 T21 T22 T23 T24 T25 T26 T27 T28 T29 T30 T31 + // ------- ------- ------- ------- ------- ------- ------- ------- ------- ------- -------- -------- -------- -------- -------- -------- ------- ------- ------- ------- ------- ------- ------- ------- ------- ------- -------- -------- -------- -------- -------- -------- + // (0, 0) (0, 1) (0, 2) (0, 3) (0, 4) (0, 5) (0, 6) (0, 7) (0, 8) (0, 9) (0, 10) (0, 11) (0, 12) (0, 13) (0, 14) (0, 15) (1, 0) (1, 1) (1, 2) (1, 3) (1, 4) (1, 5) (1, 6) (1, 7) (1, 8) (1, 9) (1, 10) (1, 11) (1, 12) (1, 13) (1, 14) (1, 15) + // (2, 0) (2, 1) (2, 2) (2, 3) (2, 4) (2, 5) (2, 6) (2, 7) (2, 8) (2, 9) (2, 10) (2, 11) (2, 12) (2, 13) (2, 14) (2, 15) (3, 0) (3, 1) (3, 2) (3, 3) (3, 4) (3, 5) (3, 6) (3, 7) (3, 8) (3, 9) (3, 10) (3, 11) (3, 12) (3, 13) (3, 14) (3, 15) + // (4, 0) (4, 1) (4, 2) (4, 3) (4, 4) (4, 5) (4, 6) (4, 7) (4, 8) (4, 9) (4, 10) (4, 11) (4, 12) (4, 13) (4, 14) (4, 15) (5, 0) (5, 1) (5, 2) (5, 3) (5, 4) (5, 5) (5, 6) (5, 7) (5, 8) (5, 9) (5, 10) (5, 11) (5, 12) (5, 13) (5, 14) (5, 15) + // (6, 0) (6, 1) (6, 2) (6, 3) (6, 4) (6, 5) (6, 6) (6, 7) (6, 8) (6, 9) (6, 10) (6, 11) (6, 12) (6, 13) (6, 14) (6, 15) (7, 0) (7, 1) (7, 2) (7, 3) (7, 4) (7, 5) (7, 6) (7, 7) (7, 8) (7, 9) (7, 10) (7, 11) (7, 12) (7, 13) (7, 14) (7, 15) + // (8, 0) (8, 1) (8, 2) (8, 3) (8, 4) (8, 5) (8, 6) (8, 7) (8, 8) (8, 9) (8, 10) (8, 11) (8, 12) (8, 13) (8, 14) (8, 15) (9, 0) (9, 1) (9, 2) (9, 3) (9, 4) (9, 5) (9, 6) (9, 7) (9, 8) (9, 9) (9, 10) (9, 11) (9, 12) (9, 13) (9, 14) (9, 15) + // (10, 0) (10, 1) (10, 2) (10, 3) (10, 4) (10, 5) (10, 6) (10, 7) (10, 8) (10, 9) (10, 10) (10, 11) (10, 12) (10, 13) (10, 14) (10, 15) (11, 0) (11, 1) (11, 2) (11, 3) (11, 4) (11, 5) (11, 6) (11, 7) (11, 8) (11, 9) (11, 10) (11, 11) (11, 12) (11, 13) (11, 14) (11, 15) + // (12, 0) (12, 1) (12, 2) (12, 3) (12, 4) (12, 5) (12, 6) (12, 7) (12, 8) (12, 9) (12, 10) (12, 11) (12, 12) (12, 13) (12, 14) (12, 15) (13, 0) (13, 1) (13, 2) (13, 3) (13, 4) (13, 5) (13, 6) (13, 7) (13, 8) (13, 9) (13, 10) (13, 11) (13, 12) (13, 13) (13, 14) (13, 15) + // (14, 0) (14, 1) (14, 2) (14, 3) (14, 4) (14, 5) (14, 6) (14, 7) (14, 8) (14, 9) (14, 10) (14, 11) (14, 12) (14, 13) (14, 14) (14, 15) (15, 0) (15, 1) (15, 2) (15, 3) (15, 4) (15, 5) (15, 6) (15, 7) (15, 8) (15, 9) (15, 10) (15, 11) (15, 12) (15, 13) (15, 14) (15, 15) + + // And the result is: + // T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15 T16 T17 T18 T19 T20 T21 T22 T23 T24 T25 T26 T27 T28 T29 T30 T31 + // ------ ------- ------- ------- ------ ------- ------- ------- ------- -------- -------- -------- ------- -------- -------- -------- ------- -------- -------- -------- ------- -------- -------- -------- ------- -------- -------- -------- ------- -------- -------- -------- + // (0, 0) (0, 2) (0, 4) (0, 6) (1, 0) (1, 2) (1, 4) (1, 6) (2, 0) (2, 2) (2, 4) (2, 6) (3, 0) (3, 2) (3, 4) (3, 6) (4, 0) (4, 2) (4, 4) (4, 6) (5, 0) (5, 2) (5, 4) (5, 6) (6, 0) (6, 2) (6, 4) (6, 6) (7, 0) (7, 2) (7, 4) (7, 6) + // (0, 1) (0, 3) (0, 5) (0, 7) (1, 1) (1, 3) (1, 5) (1, 7) (2, 1) (2, 3) (2, 5) (2, 7) (3, 1) (3, 3) (3, 5) (3, 7) (4, 1) (4, 3) (4, 5) (4, 7) (5, 1) (5, 3) (5, 5) (5, 7) (6, 1) (6, 3) (6, 5) (6, 7) (7, 1) (7, 3) (7, 5) (7, 7) + // (8, 0) (8, 2) (8, 4) (8, 6) (9, 0) (9, 2) (9, 4) (9, 6) (10, 0) (10, 2) (10, 4) (10, 6) (11, 0) (11, 2) (11, 4) (11, 6) (12, 0) (12, 2) (12, 4) (12, 6) (13, 0) (13, 2) (13, 4) (13, 6) (14, 0) (14, 2) (14, 4) (14, 6) (15, 0) (15, 2) (15, 4) (15, 6) + // (8, 1) (8, 3) (8, 5) (8, 7) (9, 1) (9, 3) (9, 5) (9, 7) (10, 1) (10, 3) (10, 5) (10, 7) (11, 1) (11, 3) (11, 5) (11, 7) (12, 1) (12, 3) (12, 5) (12, 7) (13, 1) (13, 3) (13, 5) (13, 7) (14, 1) (14, 3) (14, 5) (14, 7) (15, 1) (15, 3) (15, 5) (15, 7) + // (0, 8) (0, 10) (0, 12) (0, 14) (1, 8) (1, 10) (1, 12) (1, 14) (2, 8) (2, 10) (2, 12) (2, 14) (3, 8) (3, 10) (3, 12) (3, 14) (4, 8) (4, 10) (4, 12) (4, 14) (5, 8) (5, 10) (5, 12) (5, 14) (6, 8) (6, 10) (6, 12) (6, 14) (7, 8) (7, 10) (7, 12) (7, 14) + // (0, 9) (0, 11) (0, 13) (0, 15) (1, 9) (1, 11) (1, 13) (1, 15) (2, 9) (2, 11) (2, 13) (2, 15) (3, 9) (3, 11) (3, 13) (3, 15) (4, 9) (4, 11) (4, 13) (4, 15) (5, 9) (5, 11) (5, 13) (5, 15) (6, 9) (6, 11) (6, 13) (6, 15) (7, 9) (7, 11) (7, 13) (7, 15) + // (8, 8) (8, 10) (8, 12) (8, 14) (9, 8) (9, 10) (9, 12) (9, 14) (10, 8) (10, 10) (10, 12) (10, 14) (11, 8) (11, 10) (11, 12) (11, 14) (12, 8) (12, 10) (12, 12) (12, 14) (13, 8) (13, 10) (13, 12) (13, 14) (14, 8) (14, 10) (14, 12) (14, 14) (15, 8) (15, 10) (15, 12) (15, 14) + // (8, 9) (8, 11) (8, 13) (8, 15) (9, 9) (9, 11) (9, 13) (9, 15) (10, 9) (10, 11) (10, 13) (10, 15) (11, 9) (11, 11) (11, 13) (11, 15) (12, 9) (12, 11) (12, 13) (12, 15) (13, 9) (13, 11) (13, 13) (13, 15) (14, 9) (14, 11) (14, 13) (14, 15) (15, 9) (15, 11) (15, 13) (15, 15) + // We could do all the manipulations with ds_bpermutes, but RDNA + // documentation says "It uses LDS hardware", which to me implies that + // it does reduce available LDS bandwidth for real LDS instructions. + // Typically, MMA operations follow loads from LDS and precede stores + // to LDS, so we do need all the LDS bandwidth we can get. Additionally, + // ds_bpermute has lower throughput than v_permlane(x) or dpp + // Approximate throughput table (measured on RDNA3 using a crappy benchmark): + // * v_permlanex16_b32 ~ 2.5 cycles + // * ds_bpermute_b32 ~ 6 cycles + // * v_mov_b32_dpp (unfolded) ~ 3 cycles + + uint32_t v0 = std::bit_cast(registers[0]); + uint32_t v1 = std::bit_cast(registers[1]); + uint32_t v2 = std::bit_cast(registers[2]); + uint32_t v3 = std::bit_cast(registers[3]); + uint32_t v4 = std::bit_cast(registers[4]); + uint32_t v5 = std::bit_cast(registers[5]); + uint32_t v6 = std::bit_cast(registers[6]); + uint32_t v7 = std::bit_cast(registers[7]); + + // Moving registers to prepare for permlanex16: + // T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15 T16 T17 T18 T19 T20 T21 T22 T23 T24 T25 T26 T27 T28 T29 T30 T31 + // ------- ------- ------- ------- ------- ------- ------- ------- ------- ------- -------- -------- -------- -------- -------- -------- ------- ------- ------- ------- ------- ------- ------- ------- ------- ------- -------- -------- -------- -------- -------- -------- + // (4, 0) (4, 1) (4, 2) (4, 3) (4, 4) (4, 5) (4, 6) (4, 7) (4, 8) (4, 9) (4, 10) (4, 11) (4, 12) (4, 13) (4, 14) (4, 15) (1, 0) (1, 1) (1, 2) (1, 3) (1, 4) (1, 5) (1, 6) (1, 7) (1, 8) (1, 9) (1, 10) (1, 11) (1, 12) (1, 13) (1, 14) (1, 15) + // (6, 0) (6, 1) (6, 2) (6, 3) (6, 4) (6, 5) (6, 6) (6, 7) (6, 8) (6, 9) (6, 10) (6, 11) (6, 12) (6, 13) (6, 14) (6, 15) (3, 0) (3, 1) (3, 2) (3, 3) (3, 4) (3, 5) (3, 6) (3, 7) (3, 8) (3, 9) (3, 10) (3, 11) (3, 12) (3, 13) (3, 14) (3, 15) + // (0, 0) (0, 1) (0, 2) (0, 3) (0, 4) (0, 5) (0, 6) (0, 7) (0, 8) (0, 9) (0, 10) (0, 11) (0, 12) (0, 13) (0, 14) (0, 15) (5, 0) (5, 1) (5, 2) (5, 3) (5, 4) (5, 5) (5, 6) (5, 7) (5, 8) (5, 9) (5, 10) (5, 11) (5, 12) (5, 13) (5, 14) (5, 15) + // (2, 0) (2, 1) (2, 2) (2, 3) (2, 4) (2, 5) (2, 6) (2, 7) (2, 8) (2, 9) (2, 10) (2, 11) (2, 12) (2, 13) (2, 14) (2, 15) (7, 0) (7, 1) (7, 2) (7, 3) (7, 4) (7, 5) (7, 6) (7, 7) (7, 8) (7, 9) (7, 10) (7, 11) (7, 12) (7, 13) (7, 14) (7, 15) + // (12, 0) (12, 1) (12, 2) (12, 3) (12, 4) (12, 5) (12, 6) (12, 7) (12, 8) (12, 9) (12, 10) (12, 11) (12, 12) (12, 13) (12, 14) (12, 15) (9, 0) (9, 1) (9, 2) (9, 3) (9, 4) (9, 5) (9, 6) (9, 7) (9, 8) (9, 9) (9, 10) (9, 11) (9, 12) (9, 13) (9, 14) (9, 15) + // (14, 0) (14, 1) (14, 2) (14, 3) (14, 4) (14, 5) (14, 6) (14, 7) (14, 8) (14, 9) (14, 10) (14, 11) (14, 12) (14, 13) (14, 14) (14, 15) (11, 0) (11, 1) (11, 2) (11, 3) (11, 4) (11, 5) (11, 6) (11, 7) (11, 8) (11, 9) (11, 10) (11, 11) (11, 12) (11, 13) (11, 14) (11, 15) + // (8, 0) (8, 1) (8, 2) (8, 3) (8, 4) (8, 5) (8, 6) (8, 7) (8, 8) (8, 9) (8, 10) (8, 11) (8, 12) (8, 13) (8, 14) (8, 15) (13, 0) (13, 1) (13, 2) (13, 3) (13, 4) (13, 5) (13, 6) (13, 7) (13, 8) (13, 9) (13, 10) (13, 11) (13, 12) (13, 13) (13, 14) (13, 15) + // (10, 0) (10, 1) (10, 2) (10, 3) (10, 4) (10, 5) (10, 6) (10, 7) (10, 8) (10, 9) (10, 10) (10, 11) (10, 12) (10, 13) (10, 14) (10, 15) (15, 0) (15, 1) (15, 2) (15, 3) (15, 4) (15, 5) (15, 6) (15, 7) (15, 8) (15, 9) (15, 10) (15, 11) (15, 12) (15, 13) (15, 14) (15, 15) + if (low_half) + { + std::swap(v0, v2); + std::swap(v1, v3); + std::swap(v4, v6); + std::swap(v5, v7); + } + + // Transfer halves with permlanex16: + // T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15 T16 T17 T18 T19 T20 T21 T22 T23 T24 T25 T26 T27 T28 T29 T30 T31 + // ------- ------- ------- ------- ------- ------- ------- ------- ------- ------- -------- -------- -------- -------- -------- -------- ------- ------- ------- ------- ------- ------- ------- ------- ------- ------- -------- -------- -------- -------- -------- -------- + // (1, 0) (1, 1) (1, 2) (1, 3) (1, 4) (1, 5) (1, 6) (1, 7) (1, 8) (1, 9) (1, 10) (1, 11) (1, 12) (1, 13) (1, 14) (1, 15) (4, 0) (4, 1) (4, 2) (4, 3) (4, 4) (4, 5) (4, 6) (4, 7) (4, 8) (4, 9) (4, 10) (4, 11) (4, 12) (4, 13) (4, 14) (4, 15) + // (3, 0) (3, 1) (3, 2) (3, 3) (3, 4) (3, 5) (3, 6) (3, 7) (3, 8) (3, 9) (3, 10) (3, 11) (3, 12) (3, 13) (3, 14) (3, 15) (6, 0) (6, 1) (6, 2) (6, 3) (6, 4) (6, 5) (6, 6) (6, 7) (6, 8) (6, 9) (6, 10) (6, 11) (6, 12) (6, 13) (6, 14) (6, 15) + // (0, 0) (0, 1) (0, 2) (0, 3) (0, 4) (0, 5) (0, 6) (0, 7) (0, 8) (0, 9) (0, 10) (0, 11) (0, 12) (0, 13) (0, 14) (0, 15) (5, 0) (5, 1) (5, 2) (5, 3) (5, 4) (5, 5) (5, 6) (5, 7) (5, 8) (5, 9) (5, 10) (5, 11) (5, 12) (5, 13) (5, 14) (5, 15) + // (2, 0) (2, 1) (2, 2) (2, 3) (2, 4) (2, 5) (2, 6) (2, 7) (2, 8) (2, 9) (2, 10) (2, 11) (2, 12) (2, 13) (2, 14) (2, 15) (7, 0) (7, 1) (7, 2) (7, 3) (7, 4) (7, 5) (7, 6) (7, 7) (7, 8) (7, 9) (7, 10) (7, 11) (7, 12) (7, 13) (7, 14) (7, 15) + // (9, 0) (9, 1) (9, 2) (9, 3) (9, 4) (9, 5) (9, 6) (9, 7) (9, 8) (9, 9) (9, 10) (9, 11) (9, 12) (9, 13) (9, 14) (9, 15) (12, 0) (12, 1) (12, 2) (12, 3) (12, 4) (12, 5) (12, 6) (12, 7) (12, 8) (12, 9) (12, 10) (12, 11) (12, 12) (12, 13) (12, 14) (12, 15) + // (11, 0) (11, 1) (11, 2) (11, 3) (11, 4) (11, 5) (11, 6) (11, 7) (11, 8) (11, 9) (11, 10) (11, 11) (11, 12) (11, 13) (11, 14) (11, 15) (14, 0) (14, 1) (14, 2) (14, 3) (14, 4) (14, 5) (14, 6) (14, 7) (14, 8) (14, 9) (14, 10) (14, 11) (14, 12) (14, 13) (14, 14) (14, 15) + // (8, 0) (8, 1) (8, 2) (8, 3) (8, 4) (8, 5) (8, 6) (8, 7) (8, 8) (8, 9) (8, 10) (8, 11) (8, 12) (8, 13) (8, 14) (8, 15) (13, 0) (13, 1) (13, 2) (13, 3) (13, 4) (13, 5) (13, 6) (13, 7) (13, 8) (13, 9) (13, 10) (13, 11) (13, 12) (13, 13) (13, 14) (13, 15) + // (10, 0) (10, 1) (10, 2) (10, 3) (10, 4) (10, 5) (10, 6) (10, 7) (10, 8) (10, 9) (10, 10) (10, 11) (10, 12) (10, 13) (10, 14) (10, 15) (15, 0) (15, 1) (15, 2) (15, 3) (15, 4) (15, 5) (15, 6) (15, 7) (15, 8) (15, 9) (15, 10) (15, 11) (15, 12) (15, 13) (15, 14) (15, 15) + constexpr uint32_t permlanex16_mask_lo = 0b0111'0110'0101'0100'0011'0010'0001'0000; + constexpr uint32_t permlanex16_mask_hi = 0b1111'1110'1101'1100'1011'1010'1001'1000; + v0 = __builtin_amdgcn_permlanex16(v0, v0, permlanex16_mask_lo, permlanex16_mask_hi, true, true); + v1 = __builtin_amdgcn_permlanex16(v1, v1, permlanex16_mask_lo, permlanex16_mask_hi, true, true); + v4 = __builtin_amdgcn_permlanex16(v4, v4, permlanex16_mask_lo, permlanex16_mask_hi, true, true); + v5 = __builtin_amdgcn_permlanex16(v5, v5, permlanex16_mask_lo, permlanex16_mask_hi, true, true); + + // Readjusting rows in the lower half to match the ordering of the upper half: + // T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15 T16 T17 T18 T19 T20 T21 T22 T23 T24 T25 T26 T27 T28 T29 T30 T31 + // ------- ------- ------- ------- ------- ------- ------- ------- ------- ------- -------- -------- -------- -------- -------- -------- ------- ------- ------- ------- ------- ------- ------- ------- ------- ------- -------- -------- -------- -------- -------- -------- + // (0, 0) (0, 1) (0, 2) (0, 3) (0, 4) (0, 5) (0, 6) (0, 7) (0, 8) (0, 9) (0, 10) (0, 11) (0, 12) (0, 13) (0, 14) (0, 15) (4, 0) (4, 1) (4, 2) (4, 3) (4, 4) (4, 5) (4, 6) (4, 7) (4, 8) (4, 9) (4, 10) (4, 11) (4, 12) (4, 13) (4, 14) (4, 15) + // (2, 0) (2, 1) (2, 2) (2, 3) (2, 4) (2, 5) (2, 6) (2, 7) (2, 8) (2, 9) (2, 10) (2, 11) (2, 12) (2, 13) (2, 14) (2, 15) (6, 0) (6, 1) (6, 2) (6, 3) (6, 4) (6, 5) (6, 6) (6, 7) (6, 8) (6, 9) (6, 10) (6, 11) (6, 12) (6, 13) (6, 14) (6, 15) + // (1, 0) (1, 1) (1, 2) (1, 3) (1, 4) (1, 5) (1, 6) (1, 7) (1, 8) (1, 9) (1, 10) (1, 11) (1, 12) (1, 13) (1, 14) (1, 15) (5, 0) (5, 1) (5, 2) (5, 3) (5, 4) (5, 5) (5, 6) (5, 7) (5, 8) (5, 9) (5, 10) (5, 11) (5, 12) (5, 13) (5, 14) (5, 15) + // (3, 0) (3, 1) (3, 2) (3, 3) (3, 4) (3, 5) (3, 6) (3, 7) (3, 8) (3, 9) (3, 10) (3, 11) (3, 12) (3, 13) (3, 14) (3, 15) (7, 0) (7, 1) (7, 2) (7, 3) (7, 4) (7, 5) (7, 6) (7, 7) (7, 8) (7, 9) (7, 10) (7, 11) (7, 12) (7, 13) (7, 14) (7, 15) + // (8, 0) (8, 1) (8, 2) (8, 3) (8, 4) (8, 5) (8, 6) (8, 7) (8, 8) (8, 9) (8, 10) (8, 11) (8, 12) (8, 13) (8, 14) (8, 15) (12, 0) (12, 1) (12, 2) (12, 3) (12, 4) (12, 5) (12, 6) (12, 7) (12, 8) (12, 9) (12, 10) (12, 11) (12, 12) (12, 13) (12, 14) (12, 15) + // (10, 0) (10, 1) (10, 2) (10, 3) (10, 4) (10, 5) (10, 6) (10, 7) (10, 8) (10, 9) (10, 10) (10, 11) (10, 12) (10, 13) (10, 14) (10, 15) (14, 0) (14, 1) (14, 2) (14, 3) (14, 4) (14, 5) (14, 6) (14, 7) (14, 8) (14, 9) (14, 10) (14, 11) (14, 12) (14, 13) (14, 14) (14, 15) + // (9, 0) (9, 1) (9, 2) (9, 3) (9, 4) (9, 5) (9, 6) (9, 7) (9, 8) (9, 9) (9, 10) (9, 11) (9, 12) (9, 13) (9, 14) (9, 15) (13, 0) (13, 1) (13, 2) (13, 3) (13, 4) (13, 5) (13, 6) (13, 7) (13, 8) (13, 9) (13, 10) (13, 11) (13, 12) (13, 13) (13, 14) (13, 15) + // (11, 0) (11, 1) (11, 2) (11, 3) (11, 4) (11, 5) (11, 6) (11, 7) (11, 8) (11, 9) (11, 10) (11, 11) (11, 12) (11, 13) (11, 14) (11, 15) (15, 0) (15, 1) (15, 2) (15, 3) (15, 4) (15, 5) (15, 6) (15, 7) (15, 8) (15, 9) (15, 10) (15, 11) (15, 12) (15, 13) (15, 14) (15, 15) + if (low_half) + { + std::swap(v0, v2); + std::swap(v1, v3); + std::swap(v4, v6); + std::swap(v5, v7); + } + + // Rotate to avoid a value required by separate permlane threads being held in the same thread, but different registers: + // T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15 + // ------- ------- ------- ------- ------- ------- ------- ------- ------- ------- -------- -------- -------- -------- -------- -------- + // [0, 0] <0, 1> [0, 2] <0, 3> [0, 4] <0, 5> [0, 6] <0, 7> {0, 8} (0, 9) {0, 10} (0, 11) {0, 12} (0, 13) {0, 14} (0, 15) // rotate 0 + // (2, 15) [2, 0] <2, 1> [2, 2] <2, 3> [2, 4] <2, 5> [2, 6] <2, 7> {2, 8} (2, 9) {2, 10} (2, 11) {2, 12} (2, 13) {2, 14} // rotate 1 + // {1, 8} (1, 9) {1, 10} (1, 11) {1, 12} (1, 13) {1, 14} (1, 15) [1, 0] <1, 1> [1, 2] <1, 3> [1, 4] <1, 5> [1, 6] <1, 7> // rotate 8 + // <3, 7> {3, 8} (3, 9) {3, 10} (3, 11) {3, 12} (3, 13) {3, 14} (3, 15) [3, 0] <3, 1> [3, 2] <3, 3> [3, 4] <3, 5> [3, 6] // rotate 9 + // T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15 + // ------- ------- ------- ------- ------- ------- ------- ------- ------- ------- -------- -------- -------- -------- -------- -------- + // [8, 0] <8, 1> [8, 2] <8, 3> [8, 4] <8, 5> [8, 6] <8, 7> {8, 8} (8, 9) {8, 10} (8, 11) {8, 12} (8, 13) {8, 14} (8, 15) // rotate 0 + // (10, 15) [10, 0] <10, 1> [10, 2] <10, 3> [10, 4] <10, 5> [10, 6] <10, 7> {10, 8} (10, 9) {10, 10} (10, 11) {10, 12} (10, 13) {10, 14} // rotate 1 + // {9, 8} (9, 9) {9, 10} (9, 11) {9, 12} (9, 13) {9, 14} (9, 15) [9, 0] <9, 1> [9, 2] <9, 3> [9, 4] <9, 5> [9, 6] <9, 7> // rotate 8 + // <11, 7> {11, 8} (11, 9) {11, 10} (11, 11) {11, 12} (11, 13) {11, 14} (11, 15) [11, 0] <11, 1> [11, 2] <11, 3> [11, 4] <11, 5> [11, 6] // rotate 9 + v1 = std::bit_cast(__builtin_amdgcn_mov_dpp(std::bit_cast(v1), 0x121, 0xf, 0xf, true)); + v2 = std::bit_cast(__builtin_amdgcn_mov_dpp(std::bit_cast(v2), 0x128, 0xf, 0xf, true)); + v3 = std::bit_cast(__builtin_amdgcn_mov_dpp(std::bit_cast(v3), 0x129, 0xf, 0xf, true)); + v5 = std::bit_cast(__builtin_amdgcn_mov_dpp(std::bit_cast(v5), 0x121, 0xf, 0xf, true)); + v6 = std::bit_cast(__builtin_amdgcn_mov_dpp(std::bit_cast(v6), 0x128, 0xf, 0xf, true)); + v7 = std::bit_cast(__builtin_amdgcn_mov_dpp(std::bit_cast(v7), 0x129, 0xf, 0xf, true)); + + // Move values into correct registers: + // T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15 + // ------- ------- ------- ------- ------- ------- ------- ------- ------- ------- -------- -------- -------- -------- -------- -------- + // [0, 0] [2, 0] [0, 2] [2, 2] [0, 4] [2, 4] [0, 6] [2, 6] [1, 0] [3, 0] [1, 2] [3, 2] [1, 4] [3, 4] [1, 6] [3, 6] + // <3, 7> <0, 1> <2, 1> <0, 3> <2, 3> <0, 5> <2, 5> <0, 7> <2, 7> <1, 1> <3, 1> <1, 3> <3, 3> <1, 5> <3, 5> <1, 7> + // [8, 0] [10, 0] [8, 2] [10, 2] [8, 4] [10, 4] [8, 6] [10, 6] [9, 0] [11, 0] [9, 2] [11, 2] [9, 4] [11, 4] [9, 6] [11, 6] + // <11, 7> <8, 1> <10, 1> <8, 3> <10, 3> <8, 5> <10, 5> <8, 7> <10, 7> <9, 1> <11, 1> <9, 3> <11, 3> <9, 5> <11, 5> <9, 7> + // {1, 8} {3, 8} {1, 10} {3, 10} {1, 12} {3, 12} {1, 14} {3, 14} {0, 8} {2, 8} {0, 10} {2, 10} {0, 12} {2, 12} {0, 14} {2, 14} + // (2, 15) (1, 9) (3, 9) (1, 11) (3, 11) (1, 13) (3, 13) (1, 15) (3, 15) (0, 9) (2, 9) (0, 11) (2, 11) (0, 13) (2, 13) (0, 15) + // {9, 8} {11, 8} {9, 10} {11, 10} {9, 12} {11, 12} {9, 14} {11, 14} {8, 8} {10, 8} {8, 10} {10, 10} {8, 12} {10, 12} {8, 14} {10, 14} + // (10, 15) (9, 9) (11, 9) (9, 11) (11, 11) (9, 13) (11, 13) (9, 15) (11, 15) (8, 9) (10, 9) (8, 11) (10, 11) (8, 13) (10, 13) (8, 15) + uint32_t lower_regs[4] = {v0, v1, v2, v3}; + uint32_t upper_regs[4] = {v4, v5, v6, v7}; + std::tie(v0, v2) = select_registers_for_permlane(laneid, lower_regs, upper_regs, 0); + std::tie(v1, v3) = select_registers_for_permlane(laneid, lower_regs, upper_regs, 1); + std::tie(v4, v6) = select_registers_for_permlane(laneid, lower_regs, upper_regs, 8); + std::tie(v5, v7) = select_registers_for_permlane(laneid, lower_regs, upper_regs, 9); + + // Do permlane operations to finalize shuffle: + v0 = __builtin_amdgcn_permlane16(v0, v0, 0xeca86420, 0xfdb97531, true, true); + v1 = __builtin_amdgcn_permlane16(v1, v1, 0xfdb97531, 0x0eca8642, true, true); + v2 = __builtin_amdgcn_permlane16(v2, v2, 0xeca86420, 0xfdb97531, true, true); + v3 = __builtin_amdgcn_permlane16(v3, v3, 0xfdb97531, 0x0eca8642, true, true); + v4 = __builtin_amdgcn_permlane16(v4, v4, 0x6420eca8, 0x7531fdb9, true, true); + v5 = __builtin_amdgcn_permlane16(v5, v5, 0x7531fdb9, 0x6420eca8, true, true); + v6 = __builtin_amdgcn_permlane16(v6, v6, 0x6420eca8, 0x7531fdb9, true, true); + v7 = __builtin_amdgcn_permlane16(v7, v7, 0x7531fdb9, 0x6420eca8, true, true); + + return float8{ + std::bit_cast(v0), + std::bit_cast(v1), + std::bit_cast(v2), + std::bit_cast(v3), + std::bit_cast(v4), + std::bit_cast(v5), + std::bit_cast(v6), + std::bit_cast(v7)}; + } + + static inline __device__ float4::Native_vec_ shuffle_d(float8 dFrag) + { + auto result = shuffle_d8(dFrag); + return float4::Native_vec_{result.x, result.y, result.z, result.w}; + } + + float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_f16_f16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) + { // Reshuffle from Nvidia-like register layout to AMD layout: - half16 aFrag = shuffle_a(a_reg); - half16 bFrag = shuffle_b(b_reg); - float8 cFrag = shuffle_c(c_reg); + half16 aFrag = shuffle_a(a_reg); + half16 bFrag = shuffle_b(b_reg); + float8 cFrag = shuffle_c(c_reg); // Call the (built‐in) 16x16 MMA instruction. It returns a float8. float8 dFrag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(aFrag, bFrag, cFrag); @@ -779,11 +957,12 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2))); return d_out; } - float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_bf16_bf16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) { + float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_bf16_bf16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) + { // Reshuffle from Nvidia-like register layout to AMD layout: - half16 aFrag = shuffle_a(a_reg); - half16 bFrag = shuffle_b(b_reg); - float8 cFrag = shuffle_c(c_reg); + half16 aFrag = shuffle_a(a_reg); + half16 bFrag = shuffle_b(b_reg); + float8 cFrag = shuffle_c(c_reg); // Call the (built‐in) 16x16 MMA instruction. It returns a float8. float8 dFrag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(aFrag, bFrag, cFrag); diff --git a/zluda/src/impl/event.rs b/zluda/src/impl/event.rs index 4bb40df09..573048955 100644 --- a/zluda/src/impl/event.rs +++ b/zluda/src/impl/event.rs @@ -20,3 +20,7 @@ pub(crate) unsafe fn record(event: hipEvent_t, stream: hipStream_t) -> hipError_ pub(crate) unsafe fn synchronize(event: hipEvent_t) -> hipError_t { hipEventSynchronize(event) } + +pub(crate) unsafe fn elapsed_time(ms: *mut f32, start: hipEvent_t, end: hipEvent_t) -> hipError_t { + hipEventElapsedTime(ms, start, end) +} diff --git a/zluda/src/lib.rs b/zluda/src/lib.rs index ba2f3ab88..1616160c4 100644 --- a/zluda/src/lib.rs +++ b/zluda/src/lib.rs @@ -102,6 +102,7 @@ cuda_macros::cuda_function_declarations!( cuEventCreate, cuEventDestroy_v2, cuEventQuery, + cuEventElapsedTime, cuEventRecord, cuEventSynchronize, cuFuncGetAttribute,