diff --git a/software/gemmini-rocc-tests b/software/gemmini-rocc-tests index 13e7e1fc..22ae851b 160000 --- a/software/gemmini-rocc-tests +++ b/software/gemmini-rocc-tests @@ -1 +1 @@ -Subproject commit 13e7e1fce1a8d332eea563c14130136ef0533b16 +Subproject commit 22ae851b317b394ee7b326df95029bc5250fddc5 diff --git a/software/libgemmini b/software/libgemmini index 4be22079..cf27fa93 160000 --- a/software/libgemmini +++ b/software/libgemmini @@ -1 +1 @@ -Subproject commit 4be220794cfdb834e8ecc2ee7becdf8632cc268c +Subproject commit cf27fa93c23c841a1b784b843347f8a99c6495a0 diff --git a/src/main/scala/gemmini/LoopConv.scala b/src/main/scala/gemmini/LoopConv.scala index bc87ae10..a46a4576 100644 --- a/src/main/scala/gemmini/LoopConv.scala +++ b/src/main/scala/gemmini/LoopConv.scala @@ -16,6 +16,9 @@ class LoopConvOuterBounds(val large_iterator_bitwidth: Int, val small_iterator_b val in_channels = UInt(large_iterator_bitwidth.W) val out_channels = UInt(large_iterator_bitwidth.W) val out_dim = UInt(large_iterator_bitwidth.W) + val out_stride = UInt(large_iterator_bitwidth.W) //stride for output activation + val in_stride = UInt(large_iterator_bitwidth.W) //stride for input activation + val weight_stride = UInt(large_iterator_bitwidth.W) //stride for weight val pool_out_dim = UInt(small_iterator_bitwidth.W) val stride = UInt(tiny_iterator_bitwidth.W) val padding = UInt(tiny_iterator_bitwidth.W) @@ -272,11 +275,11 @@ class LoopConvLdInput(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitw val icol_padded = icol +& undilated(lpad).zext val is_zeros = irow < 0.S || irow >= irows_unpadded.zext || icol < 0.S || icol >= icols_unpadded.zext - val dram_stride = Mux(req.trans_input_3120, batch_size * (input_w/8).U, in_channels * (input_w/8).U) + val dram_stride = Mux(req.trans_input_3120, batch_size * (input_w/8).U, in_stride * (input_w/8).U) // Addresses val dram_offset = Mux(req.trans_input_3120, (((ich * in_dim * in_dim +& irow*in_dim +& icol) * batches +& b) * (input_w/8).U).asUInt, - (((b * in_dim * in_dim +& irow*in_dim +& icol) * in_channels +& ich) * (input_w/8).U).asUInt) + (((b * in_dim * in_dim +& irow*in_dim +& icol) * in_stride +& ich) * (input_w/8).U).asUInt) val dram_addr = Mux(is_zeros, 0.U, req.dram_addr + LoopConv.castDramOffset(dram_offset)) val spad_addr = Mux(req.trans_input_3120, // To prevent Verilator errors, we replace some "/ block_size.U" calls here with ">> log2Up(block_size)" @@ -333,7 +336,7 @@ class LoopConvLdInput(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitw io.idle := state === idle && !command_p.io.busy io.loop_id := req.loop_id - command_p.io.in.valid := state =/= idle && !io.wait_for_prev_loop + command_p.io.in.valid := state =/= idle && !io.wait_for_prev_loop && (req.dram_addr =/= 0.U) command_p.io.in.bits.cmd := Mux(state === config, config_cmd, mvin_cmd) command_p.io.in.bits.dram_addr := dram_addr command_p.io.in.bits.spad_addr := spad_addr @@ -355,7 +358,9 @@ class LoopConvLdInput(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitw } // Sending outputs - when(command_p.io.in.fire) { + when(req.dram_addr === 0.U){ + state := idle + }.elsewhen(command_p.io.in.fire) { when (state === config) { state := ld }.otherwise { @@ -442,7 +447,7 @@ class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bit out_channels_per_bank * kcols * krows * kchs) val addr_start = req.addr_end - B_rows - val dram_stride = MuxCase(out_channels, Seq( + val dram_stride = MuxCase(weight_stride, Seq( req.dw -> 1.U, req.trans_weight_1203 -> (kernel_dim * kernel_dim * out_channels), req.trans_weight_0132 -> in_channels @@ -455,7 +460,7 @@ class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bit val kch = Reg(UInt(large_iterator_bitwidth.W)) // Addresses - val dram_offset = MuxCase(((krow*kernel_dim*in_channels +& kcol*in_channels +& kch) * out_channels +& och) * (input_w/8).U, Seq( + val dram_offset = MuxCase(((krow*kernel_dim*in_channels +& kcol*in_channels +& kch) * weight_stride +& och) * (input_w/8).U, Seq( req.dw -> (krow * kernel_dim +& kcol) * (input_w/8).U, req.trans_weight_1203 -> (((kch*kernel_dim*kernel_dim +& krow*kernel_dim +& kcol) * out_channels +& och) * (input_w/8).U), req.trans_weight_0132 -> (((krow*kernel_dim*out_channels +& kcol*out_channels +& och) * in_channels +& kch) * (input_w/8).U) @@ -512,7 +517,7 @@ class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bit io.idle := state === idle && !command_p.io.busy io.loop_id := req.loop_id - command_p.io.in.valid := state =/= idle && !io.wait_for_prev_loop + command_p.io.in.valid := state =/= idle && !io.wait_for_prev_loop && (req.dram_addr =/= 0.U) command_p.io.in.bits.cmd := Mux(state === config, config_cmd, mvin_cmd) command_p.io.in.bits.dram_addr := dram_addr command_p.io.in.bits.spad_addr := spad_addr @@ -534,7 +539,9 @@ class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bit } // Sending outputs - when(command_p.io.in.fire) { + when(req.dram_addr === 0.U){ + state := idle + }.elsewhen(command_p.io.in.fire) { when (state === config) { state := ld }.otherwise { @@ -880,11 +887,11 @@ class LoopConvSt(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth: // Addresses val dram_offset = Mux(req.trans_output_1203, ((orow*out_dim*batch_size +& ocol*batch_size +& b) * out_channels +& och) * (input_w/8).U, - ((b*out_dim*out_dim +& orow*out_dim +& ocol) * out_channels +& och) * (input_w/8).U) + ((b*out_dim*out_dim +& orow*out_dim +& ocol) * out_stride +& och) * (input_w/8).U) val dram_addr = req.dram_addr + LoopConv.castDramOffset(dram_offset) val spad_addr = acc_addr_start +& (och / block_size.U(och.getWidth.W)) * batches * orows * ocols +& b * orows * ocols +& orow * ocols +& ocol - val pool_dram_addr = req.dram_addr + ((b * pool_out_dim * pool_out_dim) * out_channels + och) * (input_w/8).U + val pool_dram_addr = req.dram_addr + ((b * pool_out_dim * pool_out_dim) * out_stride + och) * (input_w/8).U val pool_spad_addr = acc_addr_start +& (och / block_size.U(och.getWidth.W)) * batches * orows * ocols +& b * orows * ocols // Sizes @@ -933,7 +940,7 @@ class LoopConvSt(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth: val pre_pool_config_cmd_rs2 = Wire(config_mvout_rs2_t.cloneType) pre_pool_config_cmd_rs2 := DontCare pre_pool_config_cmd_rs2.acc_scale := ACC_SCALE_NO_CHANGE - pre_pool_config_cmd_rs2.stride := out_channels * (input_w / 8).U + pre_pool_config_cmd_rs2.stride := out_stride * (input_w / 8).U pre_pool_config_cmd.rs2 := pre_pool_config_cmd_rs2.asUInt val post_pool_config_cmd = Wire(new RoCCCommand) @@ -949,7 +956,7 @@ class LoopConvSt(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth: val post_pool_config_cmd_rs2 = Wire(config_mvout_rs2_t.cloneType) post_pool_config_cmd_rs2 := DontCare post_pool_config_cmd_rs2.acc_scale := ACC_SCALE_NO_CHANGE - post_pool_config_cmd_rs2.stride := out_channels * (input_w / 8).U + post_pool_config_cmd_rs2.stride := out_stride * (input_w / 8).U post_pool_config_cmd.rs2 := post_pool_config_cmd_rs2.asUInt val pool_cmd = Wire(new RoCCCommand) @@ -1070,6 +1077,8 @@ class LoopConvState(val block_size: Int, val large_iterator_bitwidth: Int, val s val dw = Bool() val max_pixels_per_row = UInt(small_iterator_bitwidth.W) + val a_ex_spad_id = UInt(2.W) + val b_ex_spad_id = UInt(2.W) val configured = Bool() @@ -1306,11 +1315,14 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size: is (LOOP_CONV_WS_CONFIG_4) { loop_being_configured.inner_bounds.orows := cmd.bits.cmd.rs1(63, 48) loop_being_configured.inner_bounds.prad := cmd.bits.cmd.rs1(47, 32) - loop_being_configured.inner_bounds.pupad := cmd.bits.cmd.rs1(31, 16) - loop_being_configured.inner_bounds.pdpad := cmd.bits.cmd.rs1(15, 0) + loop_being_configured.inner_bounds.pupad := cmd.bits.cmd.rs1(31, 21) + loop_being_configured.inner_bounds.pdpad := cmd.bits.cmd.rs1(20, 10) + loop_being_configured.outer_bounds.kernel_dilation := cmd.bits.cmd.rs1(9, 0) loop_being_configured.inner_bounds.ocols := cmd.bits.cmd.rs2(15, 0) - loop_being_configured.outer_bounds.kernel_dilation := cmd.bits.cmd.rs2(31, 16) + loop_being_configured.outer_bounds.in_stride := cmd.bits.cmd.rs2(63, 48) + loop_being_configured.outer_bounds.weight_stride := cmd.bits.cmd.rs2(47, 32) + loop_being_configured.outer_bounds.out_stride := cmd.bits.cmd.rs2(31, 16) } is (LOOP_CONV_WS_CONFIG_5) { @@ -1334,6 +1346,9 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size: !has_first_layer_optimizations.B || config_max_pixels_per_row === 0.U, 1.U, config_max_pixels_per_row) + loop_being_configured.a_ex_spad_id := cmd.bits.cmd.rs1(19, 18) + loop_being_configured.b_ex_spad_id := cmd.bits.cmd.rs1(17, 16) + loop_being_configured.wrot180 := has_training_convs.B && cmd.bits.cmd.rs1(1) loop_being_configured.input_dilated := has_training_convs.B && cmd.bits.cmd.rs2(2) loop_being_configured.trans_output_1203 := has_training_convs.B && cmd.bits.cmd.rs1(2) @@ -1387,7 +1402,7 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size: ld_input.io.req.bits.outer_bounds := loop_requesting_ld_input.outer_bounds ld_input.io.req.bits.inner_bounds := loop_requesting_ld_input.inner_bounds ld_input.io.req.bits.derived_params := loop_requesting_ld_input.derived_params() - ld_input.io.req.bits.addr_start := loop_requesting_ld_input.a_addr_start + ld_input.io.req.bits.addr_start := Mux(loop_requesting_ld_input.a_ex_spad_id === 0.U, loop_requesting_ld_input.a_addr_start, (loop_requesting_ld_input.a_ex_spad_id - 1.U) * (max_addr / concurrent_loops).U) ld_input.io.req.bits.dram_addr := loop_requesting_ld_input.input_dram_addr ld_input.io.req.bits.downsample := loop_requesting_ld_input.downsample ld_input.io.req.bits.max_pixels_per_row := loop_requesting_ld_input.max_pixels_per_row @@ -1407,7 +1422,7 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size: ld_weights.io.req.bits.outer_bounds := loop_requesting_ld_weights.outer_bounds ld_weights.io.req.bits.inner_bounds := loop_requesting_ld_weights.inner_bounds ld_weights.io.req.bits.derived_params := loop_requesting_ld_weights.derived_params() - ld_weights.io.req.bits.addr_end := loop_requesting_ld_weights.b_addr_end + ld_weights.io.req.bits.addr_end := Mux(loop_requesting_ld_weights.b_ex_spad_id === 0.U, loop_requesting_ld_weights.b_addr_end, (loop_requesting_ld_weights.b_ex_spad_id) * (max_addr / concurrent_loops).U) ld_weights.io.req.bits.dram_addr := loop_requesting_ld_weights.weights_dram_addr ld_weights.io.req.bits.trans_weight_1203 := loop_requesting_ld_weights.trans_weight_1203 ld_weights.io.req.bits.trans_weight_0132 := loop_requesting_ld_weights.trans_weight_0132 @@ -1426,8 +1441,8 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size: ex.io.req.bits.outer_bounds := loop_requesting_ex.outer_bounds ex.io.req.bits.inner_bounds := loop_requesting_ex.inner_bounds ex.io.req.bits.derived_params := loop_requesting_ex.derived_params() - ex.io.req.bits.a_addr_start := loop_requesting_ex.a_addr_start - ex.io.req.bits.b_addr_end := loop_requesting_ex.b_addr_end + ex.io.req.bits.a_addr_start := Mux(loop_requesting_ex.a_ex_spad_id === 0.U, loop_requesting_ex.a_addr_start, (loop_requesting_ex.a_ex_spad_id - 1.U) * (max_addr / concurrent_loops).U) + ex.io.req.bits.b_addr_end := Mux(loop_requesting_ex.b_ex_spad_id === 0.U, loop_requesting_ex.b_addr_end, (loop_requesting_ex.b_ex_spad_id) * (max_addr / concurrent_loops).U) ex.io.req.bits.c_addr_start := ex_c_addr_start ex.io.req.bits.wrot180 := loop_requesting_ex.wrot180 ex.io.req.bits.downsample := loop_requesting_ex.downsample diff --git a/src/main/scala/gemmini/LoopMatmul.scala b/src/main/scala/gemmini/LoopMatmul.scala index c9e6fed3..e53d8363 100644 --- a/src/main/scala/gemmini/LoopMatmul.scala +++ b/src/main/scala/gemmini/LoopMatmul.scala @@ -86,12 +86,14 @@ class LoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In io.k := k io.idle := state === idle - io.cmd.valid := state =/= idle && !io.rob_overloaded + io.cmd.valid := state =/= idle && !io.rob_overloaded && req.dram_addr =/= 0.U io.cmd.bits := mvin_cmd io.loop_id := req.loop_id - when (io.cmd.fire) { + when(req.dram_addr === 0.U){ + state := idle + }.elsewhen(io.cmd.fire) { // The order here is k, j, i val i_blocks = Mux(req.transpose, max_blocks, 1.U) val k_blocks = Mux(req.transpose, 1.U, max_blocks) @@ -194,12 +196,14 @@ class LoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In io.j := j io.idle := state === idle - io.cmd.valid := state =/= idle && !io.rob_overloaded + io.cmd.valid := state =/= idle && !io.rob_overloaded && req.dram_addr =/= 0.U io.cmd.bits := mvin_cmd io.loop_id := req.loop_id - when (io.cmd.fire) { + when(req.dram_addr === 0.U){ + state := idle + }.elsewhen(io.cmd.fire) { // The order here is k, j, i val j_blocks = Mux(req.transpose, 1.U, max_blocks) val k_blocks = Mux(req.transpose, max_blocks, 1.U) @@ -698,6 +702,8 @@ class LoopMatmulState(val iterator_bitwidth: Int, val coreMaxAddrBits: Int, val val full_c = Bool() val ex_accumulate = Bool() + val a_ex_spad_id = UInt(2.W) + val b_ex_spad_id = UInt(2.W) val configured = Bool() val running = Bool() @@ -896,6 +902,8 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size loop_being_configured.low_d := cmd.bits.cmd.rs1(2) loop_being_configured.act := cmd.bits.cmd.rs1(8+Activation.bitwidth-1, 8) // TODO magic numbers + loop_being_configured.a_ex_spad_id := cmd.bits.cmd.rs1(19, 18) + loop_being_configured.b_ex_spad_id := cmd.bits.cmd.rs1(17, 16) loop_being_configured.a_transpose := cmd.bits.cmd.rs2(0) loop_being_configured.b_transpose := cmd.bits.cmd.rs2(1) @@ -920,7 +928,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size ldA.io.req.bits.dram_addr := loop_requesting_ldA.a_dram_addr ldA.io.req.bits.dram_stride := loop_requesting_ldA.a_dram_stride ldA.io.req.bits.transpose := loop_requesting_ldA.a_transpose - ldA.io.req.bits.addr_start := loop_requesting_ldA.a_addr_start + ldA.io.req.bits.addr_start := Mux(loop_requesting_ldA.a_ex_spad_id === 0.U, loop_requesting_ldA.a_addr_start, (loop_requesting_ldA.a_ex_spad_id - 1.U) * (max_addr / concurrent_loops).U) ldA.io.req.bits.loop_id := loop_requesting_ldA_id ldA.io.req.valid := !loop_requesting_ldA.lda_started && loop_requesting_ldA.configured @@ -939,7 +947,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size ldB.io.req.bits.dram_addr := loop_requesting_ldB.b_dram_addr ldB.io.req.bits.dram_stride := loop_requesting_ldB.b_dram_stride ldB.io.req.bits.transpose := loop_requesting_ldB.b_transpose - ldB.io.req.bits.addr_end := loop_requesting_ldB.b_addr_end + ldB.io.req.bits.addr_end := Mux(loop_requesting_ldB.b_ex_spad_id === 0.U, loop_requesting_ldB.b_addr_end, (loop_requesting_ldB.b_ex_spad_id) * (max_addr / concurrent_loops).U) ldB.io.req.bits.loop_id := loop_requesting_ldB_id ldB.io.req.valid := !loop_requesting_ldB.ldb_started && loop_requesting_ldB.configured @@ -958,8 +966,8 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size ex.io.req.bits.pad_k := loop_requesting_ex.pad_k ex.io.req.bits.pad_i := loop_requesting_ex.pad_i ex.io.req.bits.accumulate := loop_requesting_ex.ex_accumulate - ex.io.req.bits.a_addr_start := loop_requesting_ex.a_addr_start - ex.io.req.bits.b_addr_end := loop_requesting_ex.b_addr_end + ex.io.req.bits.a_addr_start := Mux(loop_requesting_ex.a_ex_spad_id === 0.U, loop_requesting_ex.a_addr_start, (loop_requesting_ex.a_ex_spad_id - 1.U) * (max_addr / concurrent_loops).U) + ex.io.req.bits.b_addr_end := Mux(loop_requesting_ex.b_ex_spad_id === 0.U, loop_requesting_ex.b_addr_end, (loop_requesting_ex.b_ex_spad_id) * (max_addr / concurrent_loops).U) ex.io.req.bits.a_tranpose := loop_requesting_ex.a_transpose ex.io.req.bits.b_tranpose := loop_requesting_ex.b_transpose ex.io.req.bits.c_addr_start := ex_c_addr_start