diff --git a/software/gemmini-rocc-tests b/software/gemmini-rocc-tests index 000d2fd3..6c45e1bd 160000 --- a/software/gemmini-rocc-tests +++ b/software/gemmini-rocc-tests @@ -1 +1 @@ -Subproject commit 000d2fd3e472103cb2a2c91e3d0afedc85b3738b +Subproject commit 6c45e1bd08da4caaaf01462c65da270f5d6b0bd2 diff --git a/src/main/scala/gemmini/Configs.scala b/src/main/scala/gemmini/Configs.scala index cd258650..6bdb6f6c 100644 --- a/src/main/scala/gemmini/Configs.scala +++ b/src/main/scala/gemmini/Configs.scala @@ -146,6 +146,12 @@ object GemminiConfigs { hardcode_d_to_garbage_addr = false, mesh_output_delay = 1, + + ld_ooo = false, + ex_ooo = true, + st_ooo = true, + + use_preload_filter = true, ) val chipConfig = defaultConfig.copy(sp_capacity=CapacityInKilobytes(64), acc_capacity=CapacityInKilobytes(32), dataflow=Dataflow.WS, @@ -160,6 +166,24 @@ object GemminiConfigs { ) val leanConfig = defaultConfig.copy(dataflow=Dataflow.WS, max_in_flight_reqs = 64, acc_read_full_width = false, ex_read_from_acc = false, ex_write_to_spad = false, hardcode_d_to_garbage_addr = true) + + val synthesize_for_rob_ooo = leanConfig.copy(ld_ooo = false, ex_ooo = true, st_ooo = true, lean_ooo_rob = true) // Module ROB + val synthesize_for_rob_in_order = leanConfig.copy(ld_ooo = false, ex_ooo = false, st_ooo = false, lean_ooo_rob = false) // Module ROB + + val synthesize_for_microthreads_coarse_16_ooo = leanConfig.copy(ld_ooo = false, ex_ooo = true, st_ooo = true, lean_ooo_rob = true, ex_total_k_portions = 16, ex_fine_grained_interleaving = false) // Module LoopMatmul + val synthesize_for_microthreads_coarse_8_ooo = leanConfig.copy(ld_ooo = false, ex_ooo = true, st_ooo = true, lean_ooo_rob = true, ex_total_k_portions = 8, ex_fine_grained_interleaving = false) // Module LoopMatmul + val synthesize_for_microthreads_coarse_4_ooo = leanConfig.copy(ld_ooo = false, ex_ooo = true, st_ooo = true, lean_ooo_rob = true, ex_total_k_portions = 4, ex_fine_grained_interleaving = false) // Module LoopMatmul + 
val synthesize_for_microthreads_coarse_2_ooo = leanConfig.copy(ld_ooo = false, ex_ooo = true, st_ooo = true, lean_ooo_rob = true, ex_total_k_portions = 2, ex_fine_grained_interleaving = false) // Module LoopMatmul + + val synthesize_for_microthreads_fine_16_ooo = leanConfig.copy(ld_ooo = false, ex_ooo = true, st_ooo = true, lean_ooo_rob = true, ex_total_k_portions = 16, ex_fine_grained_interleaving = true) // Module LoopMatmul + val synthesize_for_microthreads_fine_8_ooo = leanConfig.copy(ld_ooo = false, ex_ooo = true, st_ooo = true, lean_ooo_rob = true, ex_total_k_portions = 8, ex_fine_grained_interleaving = true) // Module LoopMatmul + val synthesize_for_microthreads_fine_4_ooo = leanConfig.copy(ld_ooo = false, ex_ooo = true, st_ooo = true, lean_ooo_rob = true, ex_total_k_portions = 4, ex_fine_grained_interleaving = true) // Module LoopMatmul + val synthesize_for_microthreads_fine_2_ooo = leanConfig.copy(ld_ooo = false, ex_ooo = true, st_ooo = true, lean_ooo_rob = true, ex_total_k_portions = 2, ex_fine_grained_interleaving = true) // Module LoopMatmul + + val synthesize_for_microthreads_1_in_order = leanConfig.copy(ld_ooo = false, ex_ooo = false, st_ooo = false, lean_ooo_rob = true, ex_total_k_portions = 1, ex_fine_grained_interleaving = false) // Module LoopMatmul + + val synthesize_for_weightA_ooo = leanConfig.copy(ld_ooo = false, ex_ooo = true, st_ooo = true, staticWeightAEnabled = true, lean_weightA = true) // Module WeightedArbiter + val synthesize_for_weightA_in_order = leanConfig.copy(ld_ooo = false, ex_ooo = false, st_ooo = false, staticWeightAEnabled = false, lean_weightA = false) // Module WeightedArbiter } /** diff --git a/src/main/scala/gemmini/ConfigsFP.scala b/src/main/scala/gemmini/ConfigsFP.scala index 2762c93d..41954b8b 100644 --- a/src/main/scala/gemmini/ConfigsFP.scala +++ b/src/main/scala/gemmini/ConfigsFP.scala @@ -72,6 +72,12 @@ object GemminiFPConfigs { hardcode_d_to_garbage_addr = false, mesh_output_delay = 0, + + ld_ooo = false, + ex_ooo
= false, + st_ooo = false, + + use_preload_filter = true, ) //FP32 Single Precision Configuration diff --git a/src/main/scala/gemmini/Controller.scala b/src/main/scala/gemmini/Controller.scala index cc28697d..46b3ff79 100644 --- a/src/main/scala/gemmini/Controller.scala +++ b/src/main/scala/gemmini/Controller.scala @@ -17,6 +17,16 @@ class GemminiCmd(rob_entries: Int)(implicit p: Parameters) extends Bundle { val cmd = new RoCCCommand val rob_id = UDValid(UInt(log2Up(rob_entries).W)) + val i = UInt(16.W) // TODO magic numbers + val j = UInt(16.W) // TODO magic numbers + val k = UInt(16.W) // TODO magic numbers + val max_i = UInt(16.W) // TODO magic numbers + val max_j = UInt(16.W) // TODO magic numbers + val max_k = UInt(16.W) // TODO magic numbers + val use_iterators = Bool() + + val ex_k_portion = UInt(8.W) // TODO magic numbers + override def cloneType: this.type = new GemminiCmd(rob_entries).asInstanceOf[this.type] } @@ -105,7 +115,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data] */ // Incoming commands and ROB - val rob = Module(new ROB(outer.config, new RoCCCommand)) + val rob = Module(new ROB(outer.config, new RoCCCommand, new GemminiCmd(rob_entries))) val raw_cmd = Queue(io.cmd) @@ -123,9 +133,10 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data] // val (unrolled_cmd, loop_matmul_unroller_busy) = LoopMatmul(unrolled_cmd_after_conv, rob.io.ld_utilization, rob.io.st_utilization, rob.io.ex_utilization, - val (loop_cmd, loop_matmul_unroller_busy) = LoopMatmul(conv_cmd, rob.io.ld_utilization, rob.io.st_utilization, rob.io.ex_utilization, - meshRows*tileRows, coreMaxAddrBits, rob_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries, - inputType.getWidth, accType.getWidth, dma_maxbytes) + val (loop_cmd, loop_matmul_unroller_busy) = LoopMatmul(conv_cmd, rob.io.ld_utilization, rob.io.st_utilization, rob.io.ex_utilization, rob.io.ex_k_portion_utilizations, + meshRows*tileRows, 
coreMaxAddrBits, rob_entries, rob_full_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries, + inputType.getWidth, accType.getWidth, dma_maxbytes, new GemminiCmd(rob_entries), ex_total_k_portions, ex_fine_grained_interleaving, local_addr_t, lean_weightA, lean_ooo_rob, + staticWeightAEnabled) val unrolled_cmd = Queue(loop_cmd) unrolled_cmd.ready := false.B @@ -170,13 +181,11 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data] tiler.io.issue.load.ready := false.B tiler.io.issue.store.ready := false.B tiler.io.issue.exec.ready := false.B - */ rob.io.issue.ld.ready := false.B rob.io.issue.st.ready := false.B rob.io.issue.ex.ready := false.B - /* when (is_cisc_mode) { load_controller.io.cmd <> tiler.io.issue.load store_controller.io.cmd <> tiler.io.issue.store @@ -203,23 +212,28 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data] } */ - load_controller.io.cmd.valid := rob.io.issue.ld.valid - rob.io.issue.ld.ready := load_controller.io.cmd.ready - load_controller.io.cmd.bits.cmd := rob.io.issue.ld.cmd - load_controller.io.cmd.bits.cmd.inst.funct := rob.io.issue.ld.cmd.inst.funct - load_controller.io.cmd.bits.rob_id.push(rob.io.issue.ld.rob_id) + val (rob_issue_ld, rob_issue_ex) = PreloadFilter(outer.config, new RoCCCommand, rob.io.issue.ld, rob.io.issue.ex) + + load_controller.io.cmd.valid := rob_issue_ld.valid + rob_issue_ld.ready := load_controller.io.cmd.ready + load_controller.io.cmd.bits := DontCare + load_controller.io.cmd.bits.cmd := rob_issue_ld.cmd + load_controller.io.cmd.bits.cmd.inst.funct := rob_issue_ld.cmd.inst.funct + load_controller.io.cmd.bits.rob_id.push(rob_issue_ld.rob_id) store_controller.io.cmd.valid := rob.io.issue.st.valid rob.io.issue.st.ready := store_controller.io.cmd.ready + store_controller.io.cmd.bits := DontCare store_controller.io.cmd.bits.cmd := rob.io.issue.st.cmd store_controller.io.cmd.bits.cmd.inst.funct := rob.io.issue.st.cmd.inst.funct 
store_controller.io.cmd.bits.rob_id.push(rob.io.issue.st.rob_id) - ex_controller.io.cmd.valid := rob.io.issue.ex.valid - rob.io.issue.ex.ready := ex_controller.io.cmd.ready - ex_controller.io.cmd.bits.cmd := rob.io.issue.ex.cmd - ex_controller.io.cmd.bits.cmd.inst.funct := rob.io.issue.ex.cmd.inst.funct - ex_controller.io.cmd.bits.rob_id.push(rob.io.issue.ex.rob_id) + ex_controller.io.cmd.valid := rob_issue_ex.valid + rob_issue_ex.ready := ex_controller.io.cmd.ready + ex_controller.io.cmd.bits := DontCare + ex_controller.io.cmd.bits.cmd := rob_issue_ex.cmd + ex_controller.io.cmd.bits.cmd.inst.funct := rob_issue_ex.cmd.inst.funct + ex_controller.io.cmd.bits.rob_id.push(rob_issue_ex.rob_id) // Wire up scratchpad to controllers spad.module.io.dma.read <> load_controller.io.dma @@ -353,7 +367,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data] // val config_cmd_type = cmd.bits.rs1(1,0) // TODO magic numbers //val funct = unrolled_cmd.bits.inst.funct - val risc_funct = unrolled_cmd.bits.inst.funct + val risc_funct = unrolled_cmd.bits.cmd.inst.funct val is_flush = risc_funct === FLUSH_CMD /* @@ -365,7 +379,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data] when (is_flush) { // val skip = compressed_cmd.bits.rs1(0) - val skip = unrolled_cmd.bits.rs1(0) + val skip = unrolled_cmd.bits.cmd.rs1(0) tlb.io.exp.flush_skip := skip tlb.io.exp.flush_retry := !skip diff --git a/src/main/scala/gemmini/DMACommandTracker.scala b/src/main/scala/gemmini/DMACommandTracker.scala index a687e918..2b85472e 100644 --- a/src/main/scala/gemmini/DMACommandTracker.scala +++ b/src/main/scala/gemmini/DMACommandTracker.scala @@ -6,7 +6,7 @@ import chisel3.util._ // This module is meant to go inside the Load controller, where it can track which commands are currently // in flight and which are completed -class DMACommandTracker[T <: Data](val nCmds: Int, val maxBytes: Int, tag_t: => T) extends Module { +class DMACommandTracker[T <: Data](val nCmds: Int, val 
maxBytes: Int, tag_t: => T, prng_seed: Int, proportion_of_slow_accesses_out_of_128: Int, stall_delay: Int) extends Module { def cmd_id_t = UInt((log2Ceil(nCmds) max 1).W) val io = IO(new Bundle { @@ -56,6 +56,8 @@ class DMACommandTracker[T <: Data](val nCmds: Int, val maxBytes: Int, tag_t: => val tag = tag_t.cloneType val bytes_left = UInt(log2Up(maxBytes+1).W) + val stall_cycles = UInt(32.W) // TODO magic number + def init(dummy: Int = 0): Unit = { valid := false.B } @@ -73,9 +75,9 @@ class DMACommandTracker[T <: Data](val nCmds: Int, val maxBytes: Int, tag_t: => io.busy := cmd_valids.reduce(_ || _) val cmd_completed_id = MuxCase(0.U, cmds.zipWithIndex.map { case (cmd, i) => - (cmd.valid && cmd.bytes_left === 0.U) -> i.U + (cmd.valid && cmd.bytes_left === 0.U && cmd.stall_cycles === 0.U) -> i.U }) - io.cmd_completed.valid := cmds.map(cmd => cmd.valid && cmd.bytes_left === 0.U).reduce(_ || _) + io.cmd_completed.valid := cmds.map(cmd => cmd.valid && cmd.bytes_left === 0.U && cmd.stall_cycles === 0.U).reduce(_ || _) io.cmd_completed.bits.cmd_id := cmd_completed_id io.cmd_completed.bits.tag := cmds(cmd_completed_id).tag @@ -83,6 +85,11 @@ class DMACommandTracker[T <: Data](val nCmds: Int, val maxBytes: Int, tag_t: => cmds(next_empty_alloc).valid := true.B cmds(next_empty_alloc).tag := io.alloc.bits.tag cmds(next_empty_alloc).bytes_left := io.alloc.bits.bytes_to_read + + val random_number = random.GaloisLFSR.maxPeriod(width=8, seed=Some(prng_seed)) + + cmds(next_empty_alloc).stall_cycles := Mux(random_number < proportion_of_slow_accesses_out_of_128.U, + stall_delay.U, 0.U) } when (io.request_returned.fire()) { @@ -97,6 +104,12 @@ class DMACommandTracker[T <: Data](val nCmds: Int, val maxBytes: Int, tag_t: => cmds(io.cmd_completed.bits.cmd_id).valid := false.B } + cmds.foreach { cmd => + when (cmd.valid && cmd.bytes_left === 0.U && cmd.stall_cycles > 0.U) { + cmd.stall_cycles := cmd.stall_cycles - 1.U + } + } + when (reset.asBool()) { cmds.foreach(_.init()) } diff --git 
a/src/main/scala/gemmini/DSEConfigs.scala b/src/main/scala/gemmini/DSEConfigs.scala index b71477bb..e2897c90 100644 --- a/src/main/scala/gemmini/DSEConfigs.scala +++ b/src/main/scala/gemmini/DSEConfigs.scala @@ -74,6 +74,12 @@ object DSEBaseConfig { max_in_flight_reqs = 16, mesh_output_delay = 1, + + ld_ooo = false, + ex_ooo = false, + st_ooo = false, + + use_preload_filter = true, ) } diff --git a/src/main/scala/gemmini/ExIUnroller.scala b/src/main/scala/gemmini/ExIUnroller.scala new file mode 100644 index 00000000..657bcc16 --- /dev/null +++ b/src/main/scala/gemmini/ExIUnroller.scala @@ -0,0 +1,93 @@ +package gemmini + +import chisel3._ +import chisel3.util._ +import chisel3.experimental._ +import freechips.rocketchip.tile.RoCCCommand +import chipsalliance.rocketchip.config.Parameters +import GemminiISA._ +import Util._ + +class ExIUnroller[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConfig[T, U, V])(implicit p: Parameters) extends Module { + import config._ + + val block_rows = meshRows * tileRows + val block_cols = meshColumns * tileColumns + + val io = IO(new Bundle { + val in = Flipped(Decoupled(new GemminiCmd(rob_entries))) + val out = Decoupled(new GemminiCmd(rob_entries)) + }) + + object State extends ChiselEnum { + val preload, compute = Value + } + import State._ + val state = RegInit(preload) + + val (q, len) = MultiHeadedQueue(io.in, entries=3, heads=2, maxpop=2) + + val first_cmd_is_preload = q.bits(0).cmd.inst.funct === PRELOAD_CMD + + val total_I = q.bits(0).cmd.rs2(63, 48).asUInt() // This is only valid if first_cmd_is_preload === true.B // TODO magic numbers + val I_sent = RegInit(0.U(16.W)) // TODO magic number + val last_send = total_I -& I_sent <= block_rows.U + + val must_unroll = first_cmd_is_preload && total_I > block_rows.U + + val J_blocks = Cat(q.bits(0).cmd.inst.opcode, q.bits(0).cmd.inst.rs1, q.bits(0).cmd.inst.rs2, q.bits(0).cmd.inst.rd) + val K_blocks = Cat(q.bits(1).cmd.inst.opcode, q.bits(1).cmd.inst.rs1, 
q.bits(1).cmd.inst.rs2, q.bits(1).cmd.inst.rd) + val I_block = I_sent / block_rows.U + + val preload_cmd_with_bounded_i = WireInit(q.bits(0)) + preload_cmd_with_bounded_i.cmd.rs2 := (minOf(total_I -& I_sent, block_rows.U) << 48) | + (q.bits(0).cmd.rs2(47, 32) << 32) | + (q.bits(0).cmd.rs2(31, 0).asTypeOf(local_addr_t) + I_block * J_blocks * block_rows.U).asUInt() + preload_cmd_with_bounded_i.rob_id.valid := last_send && q.bits(0).rob_id.valid + + val compute_cmd_with_bounded_i = WireInit(q.bits(1)) + compute_cmd_with_bounded_i.cmd.rs1 := (minOf(total_I -& I_sent, block_rows.U) << 48) | + (q.bits(1).cmd.rs1(47, 32) << 32) | + (q.bits(1).cmd.rs1(31, 0).asTypeOf(local_addr_t) + I_block * K_blocks * block_rows.U).asUInt() + compute_cmd_with_bounded_i.cmd.rs2 := (minOf(total_I -& I_sent, block_rows.U) << 48) | + (q.bits(1).cmd.rs2(47, 32) << 32) | + (q.bits(1).cmd.rs2(31, 0).asTypeOf(local_addr_t) + I_block * J_blocks * block_rows.U).asUInt() + compute_cmd_with_bounded_i.rob_id.valid := last_send && q.bits(1).rob_id.valid + + when (I_sent > 0.U) { + preload_cmd_with_bounded_i.cmd.rs1 := (block_rows.U << 48) | (block_cols.U << 32) | GARBAGE_ADDR + compute_cmd_with_bounded_i.cmd.inst.funct := COMPUTE_AND_STAY_CMD + } + when (q.bits(0).cmd.rs2(31, 0).asTypeOf(local_addr_t).is_garbage()) { + preload_cmd_with_bounded_i.cmd.rs2 := (block_rows.U << 48) | (block_cols.U << 32) | GARBAGE_ADDR + } + when (q.bits(1).cmd.rs1(31, 0).asTypeOf(local_addr_t).is_garbage()) { + compute_cmd_with_bounded_i.cmd.rs1 := (block_rows.U << 48) | (block_cols.U << 32) | GARBAGE_ADDR + } + when (q.bits(1).cmd.rs2(31, 0).asTypeOf(local_addr_t).is_garbage() || (dataflow == Dataflow.WS && hardcode_d_to_garbage_addr).B) { + compute_cmd_with_bounded_i.cmd.rs2 := (block_rows.U << 48) | (block_cols.U << 32) | GARBAGE_ADDR + } + + io.out.valid := Mux(must_unroll, (q.valid(0) && state === preload) || (q.valid(1) && state === compute), q.valid(0)) + io.out.bits := Mux(must_unroll, Mux(state === preload, 
preload_cmd_with_bounded_i, compute_cmd_with_bounded_i), q.bits(0)) + + q.pop := Mux(io.out.fire(), Mux(must_unroll, Mux(state === compute && last_send, 2.U, 0.U), 1.U), 0.U) + + // Control the state + when (io.out.fire() && must_unroll) { + state := state.next + } + + // Control I_sent + when (io.out.fire() && must_unroll && state === compute) { + I_sent := floorAdd(I_sent, block_rows.U, total_I) + } +} + +object ExIUnroller { + def apply[T <: Data : Arithmetic, U <: Data, V <: Data](in: ReadyValidIO[GemminiCmd], config: GemminiArrayConfig[T, U, V])(implicit p: Parameters) = { + val mod = Module(new ExIUnroller(config)) + mod.io.in <> in + mod.io.out + } +} diff --git a/src/main/scala/gemmini/ExecuteController.scala b/src/main/scala/gemmini/ExecuteController.scala index d8a1bac9..1744422d 100644 --- a/src/main/scala/gemmini/ExecuteController.scala +++ b/src/main/scala/gemmini/ExecuteController.scala @@ -59,12 +59,13 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In } } - val unrolled_cmd = TransposePreloadUnroller(io.cmd, config) + val unrolled_cmd = TransposePreloadUnroller(ExIUnroller(io.cmd, config), config) val cmd_q_heads = 3 assert(ex_queue_length >= cmd_q_heads) // val (cmd, _) = MultiHeadedQueue(io.cmd, ex_queue_length, cmd_q_heads) - val (cmd, _) = MultiHeadedQueue(unrolled_cmd, ex_queue_length, cmd_q_heads) + // val (cmd, _) = MultiHeadedQueue(unrolled_cmd, ex_queue_length, cmd_q_heads) + val (cmd, _) = MultiHeadedQueue(unrolled_cmd, rob_full_entries, cmd_q_heads) // TODO this should be ex_queue_length cmd.pop := 0.U io.solitary_preload := cmd.valid(0) && cmd.bits(0).cmd.inst.funct === PRELOAD_CMD && !cmd.valid(1) @@ -784,7 +785,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In mesh_cntl_signals_q.io.enq.bits.a_transpose := a_transpose mesh_cntl_signals_q.io.enq.bits.bd_transpose := bd_transpose - mesh_cntl_signals_q.io.enq.bits.rob_id.valid := !performing_single_mul && 
!c_address_rs2.is_garbage() + mesh_cntl_signals_q.io.enq.bits.rob_id.valid := cmd.bits(preload_cmd_place).rob_id.valid && !performing_single_mul && !c_address_rs2.is_garbage() mesh_cntl_signals_q.io.enq.bits.rob_id.bits := cmd.bits(preload_cmd_place).rob_id.bits mesh_cntl_signals_q.io.enq.bits.dataflow := current_dataflow @@ -963,15 +964,16 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In //val complete_lock = RegInit(false.B) //Seah: added for WS accumulator - when(mesh.io.resp.fire() && mesh.io.resp.bits.tag.rob_id.valid) { + when(mesh.io.resp.fire()) { output_counter := wrappingAdd(output_counter, 1.U, w_total_output_rows) val last = mesh.io.resp.bits.last - when(last) { - mesh_completed_rob_id_fire := true.B - io.completed.valid := true.B + when(last && mesh.io.resp.bits.tag.rob_id.valid) { + mesh_completed_rob_id_fire := mesh.io.resp.bits.tag.rob_id.valid + io.completed.valid := mesh.io.resp.bits.tag.rob_id.valid io.completed.bits := mesh.io.resp.bits.tag.rob_id.bits } + start_array_outputting := !is_garbage_addr } diff --git a/src/main/scala/gemmini/GemminiConfigs.scala b/src/main/scala/gemmini/GemminiConfigs.scala index 8d6db34f..fc9e75d7 100644 --- a/src/main/scala/gemmini/GemminiConfigs.scala +++ b/src/main/scala/gemmini/GemminiConfigs.scala @@ -66,6 +66,26 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( mesh_output_delay: Int, + ld_ooo: Boolean, + ex_ooo: Boolean, + st_ooo: Boolean, + + use_preload_filter: Boolean, + + prng_seed: Int = 1, // ALON: You can change the PRNG seed here + proportion_of_slow_accesses_out_of_128: Int = 10, // ALON: The number of memory accesses (out of 128) that are slow. You can also make this 0 + stall_delay: Int = 1000, // ALON: How many cycles should we wait for a slow memory access? You can also make this 0 + delay_lds: Boolean = false, // ALON: Should loads be stalled? + delay_sts: Boolean = false, // ALON: Should stores be stalled? 
+ + ex_total_k_portions: Int = 1, // ALON: You can change this to any number of k-portions that you would like + ex_fine_grained_interleaving: Boolean = true, // ALON: If this is true, then we use the newer ("finer") interleaving strategy + + lean_ooo_rob: Boolean = false, // No garbage preloads + lean_weightA: Boolean = false, // Only static weightA supported + + staticWeightAEnabled: Boolean = true, + headerFileName: String = "gemmini_params.h" ) { val sp_width = meshColumns * tileColumns * inputType.getWidth diff --git a/src/main/scala/gemmini/GemminiISA.scala b/src/main/scala/gemmini/GemminiISA.scala index 46a23010..3d08e83a 100644 --- a/src/main/scala/gemmini/GemminiISA.scala +++ b/src/main/scala/gemmini/GemminiISA.scala @@ -1,6 +1,7 @@ package gemmini import chisel3._ +import chisel3.util._ object GemminiISA { // funct values @@ -54,4 +55,96 @@ object GemminiISA { // dataflow configuration //========================================================================== val GARBAGE_ADDR = "hffffffff".U(32.W) + + private val register_len = 64 + + object LoadCmd { + class Rs1(val coreMaxAddrBits: Int) extends Bundle { + val garbage = UInt((register_len - coreMaxAddrBits).W) + val dram_addr = UInt(coreMaxAddrBits.W) + } + + class Rs2(local_addr_t: LocalAddr) extends Bundle { + private val maxLocalAddrBits = local_addr_t.maxLocalAddrBits + + val garbage1 = UInt(((16 - maxLocalAddrBits) max 0).W) + val rows = UInt((16 min maxLocalAddrBits).W) + val garbage2 = UInt(((16 - maxLocalAddrBits) max 0).W) + val cols = UInt((16 min maxLocalAddrBits).W) + val spad_addr = local_addr_t.cloneType + + override def cloneType: Rs2.this.type = new Rs2(local_addr_t).asInstanceOf[this.type] + } + } + + object StoreCmd { + class Rs1(val coreMaxAddrBits: Int) extends Bundle { + val garbage = UInt((register_len - coreMaxAddrBits).W) + val dram_addr = UInt(coreMaxAddrBits.W) + } + + class Rs2(local_addr_t: LocalAddr) extends Bundle { + private val maxLocalAddrBits =
local_addr_t.maxLocalAddrBits + + val garbage1 = UInt(((16 - maxLocalAddrBits) max 0).W) + val rows = UInt((16 min maxLocalAddrBits).W) + val garbage2 = UInt(((16 - maxLocalAddrBits) max 0).W) + val cols = UInt((16 min maxLocalAddrBits).W) + val spad_addr = local_addr_t.cloneType + + override def cloneType: Rs2.this.type = new Rs2(local_addr_t).asInstanceOf[this.type] + } + } + + object PreloadCmd { + class Rs1(local_addr_t: LocalAddr, block_size: Int) extends Bundle { + private val maxLocalAddrBits = local_addr_t.maxLocalAddrBits + + val garbage1 = UInt((16 - log2Up(block_size+1)).W) + val bd_rows = UInt(log2Up(block_size+1).W) + val garbage2 = UInt((16 - log2Up(block_size+1)).W) + val bd_cols = UInt(log2Up(block_size+1).W) + val bd = local_addr_t.cloneType + + override def cloneType: Rs1.this.type = new Rs1(local_addr_t, block_size).asInstanceOf[this.type] + } + + class Rs2(local_addr_t: LocalAddr, block_size: Int, max_block_len: Int) extends Bundle { + private val maxLocalAddrBits = local_addr_t.maxLocalAddrBits + + val garbage1 = UInt((16 - log2Up(max_block_len*block_size+1)).W) + val c_rows = UInt(log2Up(max_block_len*block_size+1).W) + val garbage2 = UInt((16 - log2Up(max_block_len*block_size+1)).W) + val c_cols = UInt(log2Up(max_block_len*block_size+1).W) + val c = local_addr_t.cloneType + + override def cloneType: Rs2.this.type = new Rs2(local_addr_t, block_size, max_block_len).asInstanceOf[this.type] + } + } + + object ComputeCmd { + class Rs1(local_addr_t: LocalAddr, block_size: Int, max_block_len: Int) extends Bundle { + private val maxLocalAddrBits = local_addr_t.maxLocalAddrBits + + val garbage1 = UInt((16 - log2Up(max_block_len*block_size+1)).W) + val a_rows = UInt(log2Up(max_block_len*block_size+1).W) + val garbage2 = UInt((16 - log2Up(max_block_len*block_size+1)).W) + val a_cols = UInt(log2Up(max_block_len*block_size+1).W) + val a = local_addr_t.cloneType + + override def cloneType: Rs1.this.type = new Rs1(local_addr_t, block_size, 
max_block_len).asInstanceOf[this.type] + } + + class Rs2(local_addr_t: LocalAddr, block_size: Int) extends Bundle { + private val maxLocalAddrBits = local_addr_t.maxLocalAddrBits + + val garbage1 = UInt((16 - log2Up(block_size+1)).W) + val bd_rows = UInt(log2Up(block_size+1).W) + val garbage2 = UInt((16 - log2Up(block_size+1)).W) + val bd_cols = UInt(log2Up(block_size+1).W) + val bd = local_addr_t.cloneType + + override def cloneType: Rs2.this.type = new Rs2(local_addr_t, block_size).asInstanceOf[this.type] + } + } } diff --git a/src/main/scala/gemmini/LoadController.scala b/src/main/scala/gemmini/LoadController.scala index a89a219e..db2d9df6 100644 --- a/src/main/scala/gemmini/LoadController.scala +++ b/src/main/scala/gemmini/LoadController.scala @@ -77,7 +77,9 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig (block_cols * config.accType.getWidth / 8) val maxBytesInMatRequest = block_rows * maxBytesInRowRequest - val cmd_tracker = Module(new DMACommandTracker(nCmds, maxBytesInMatRequest, deps_t)) + val cmd_tracker = Module(new DMACommandTracker(nCmds, maxBytesInMatRequest, deps_t, prng_seed = prng_seed, + proportion_of_slow_accesses_out_of_128 = if (delay_lds) proportion_of_slow_accesses_out_of_128 else 0, + stall_delay = stall_delay)) io.busy := cmd.valid || cmd_tracker.io.busy diff --git a/src/main/scala/gemmini/LocalAddr.scala b/src/main/scala/gemmini/LocalAddr.scala index b003fd7b..e6233493 100644 --- a/src/main/scala/gemmini/LocalAddr.scala +++ b/src/main/scala/gemmini/LocalAddr.scala @@ -8,7 +8,7 @@ class LocalAddr(sp_banks: Int, sp_bank_entries: Int, acc_banks: Int, acc_bank_en private val spAddrBits = log2Ceil(sp_banks * sp_bank_entries) private val accAddrBits = log2Ceil(acc_banks * acc_bank_entries) - private val maxAddrBits = spAddrBits max accAddrBits + val maxLocalAddrBits = spAddrBits max accAddrBits private val spBankBits = log2Up(sp_banks) private val spBankRowBits = log2Up(sp_bank_entries) @@ -19,9 +19,9 @@ class 
LocalAddr(sp_banks: Int, sp_bank_entries: Int, acc_banks: Int, acc_bank_en val is_acc_addr = Bool() val accumulate = Bool() val read_full_acc_row = Bool() - val garbage = UInt(((localAddrBits - maxAddrBits - 4) max 0).W) - val garbage_bit = if (localAddrBits - maxAddrBits >= 4) UInt(1.W) else UInt(0.W) - val data = UInt(maxAddrBits.W) + val garbage = UInt(((localAddrBits - maxLocalAddrBits - 4) max 0).W) + val garbage_bit = if (localAddrBits - maxLocalAddrBits >= 4) UInt(1.W) else UInt(0.W) + val data = UInt(maxLocalAddrBits.W) def sp_bank(dummy: Int = 0) = if (spAddrBits == spBankRowBits) 0.U else data(spAddrBits - 1, spBankRowBits) def sp_row(dummy: Int = 0) = data(spBankRowBits - 1, 0) @@ -57,6 +57,10 @@ class LocalAddr(sp_banks: Int, sp_bank_entries: Int, acc_banks: Int, acc_bank_en is_acc_addr === other.is_acc_addr && Mux(is_acc_addr, full_acc_addr() > other.full_acc_addr(), full_sp_addr() > other.full_sp_addr()) + def ===(other: LocalAddr) = + is_acc_addr === other.is_acc_addr && + Mux(is_acc_addr, full_acc_addr() === other.full_acc_addr(), full_sp_addr() === other.full_sp_addr()) + def add_with_overflow(other: UInt): Tuple2[LocalAddr, Bool] = { require(isPow2(sp_bank_entries)) // TODO remove this requirement require(isPow2(acc_bank_entries)) // TODO remove this requirement @@ -66,7 +70,7 @@ class LocalAddr(sp_banks: Int, sp_bank_entries: Int, acc_banks: Int, acc_bank_en val overflow = Mux(is_acc_addr, sum(accAddrBits), sum(spAddrBits)) val result = WireInit(this) - result.data := sum(maxAddrBits - 1, 0) + result.data := sum(maxLocalAddrBits - 1, 0) (result, overflow) } @@ -76,7 +80,7 @@ class LocalAddr(sp_banks: Int, sp_bank_entries: Int, acc_banks: Int, acc_bank_en accumulate := true.B read_full_acc_row := true.B garbage_bit := 1.U - data := ~(0.U(maxAddrBits.W)) + data := ~(0.U(maxLocalAddrBits.W)) } override def cloneType: LocalAddr.this.type = new LocalAddr(sp_banks, sp_bank_entries, acc_banks, acc_bank_entries).asInstanceOf[this.type] diff --git 
a/src/main/scala/gemmini/LoopMatmul.scala b/src/main/scala/gemmini/LoopMatmul.scala index af9e3061..1a85a7fa 100644 --- a/src/main/scala/gemmini/LoopMatmul.scala +++ b/src/main/scala/gemmini/LoopMatmul.scala @@ -23,11 +23,11 @@ class LoopMatmulLdAReq(val block_size: Int, val coreMaxAddrBits: Int, val iterat } class LoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: Int, max_addr: Int, input_w: Int, - max_block_len: Int, concurrent_loops: Int) + max_block_len: Int, concurrent_loops: Int, cmd_t: GemminiCmd, local_addr_t: LocalAddr) (implicit p: Parameters) extends Module { val io = IO(new Bundle { val req = Flipped(Decoupled(new LoopMatmulLdAReq(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, concurrent_loops))) - val cmd = Decoupled(Output(new RoCCCommand)) + val cmd = Decoupled(Output(cmd_t)) val i = Output(UInt(iterator_bitwidth.W)) val k = Output(UInt(iterator_bitwidth.W)) val idle = Output(Bool()) @@ -66,11 +66,22 @@ class LoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In val cols = (blocks * block_size.U) - Mux(col_iterator + blocks >= max_col_iterator, col_pad, 0.U) val rows = block_size.U - Mux(row_iterator === max_row_iterator-1.U, row_pad, 0.U) + val mvin_cmd_rs1 = Wire(new GemminiISA.LoadCmd.Rs1(coreMaxAddrBits)) + mvin_cmd_rs1 := DontCare + mvin_cmd_rs1.dram_addr := dram_addr + + val mvin_cmd_rs2 = Wire(new GemminiISA.LoadCmd.Rs2(local_addr_t)) + mvin_cmd_rs2 := DontCare + mvin_cmd_rs2.rows := rows + mvin_cmd_rs2.cols := cols + mvin_cmd_rs2.spad_addr := 0.U.asTypeOf(local_addr_t) + mvin_cmd_rs2.spad_addr.data := sp_addr + val mvin_cmd = Wire(new RoCCCommand) mvin_cmd := DontCare mvin_cmd.inst.funct := LOAD_CMD - mvin_cmd.rs1 := dram_addr - mvin_cmd.rs2 := (rows << 48).asUInt() | (cols << 32).asUInt() | sp_addr + mvin_cmd.rs1 := mvin_cmd_rs1.asUInt() + mvin_cmd.rs2 := mvin_cmd_rs2.asUInt() io.req.ready := state === idle io.i := i @@ -78,7 +89,16 @@ class LoopMatmulLdA(block_size: Int, coreMaxAddrBits: 
Int, iterator_bitwidth: In io.idle := state === idle io.cmd.valid := state =/= idle && !io.rob_overloaded - io.cmd.bits := mvin_cmd + io.cmd.bits.cmd := mvin_cmd + io.cmd.bits.rob_id := DontCare + io.cmd.bits.i := i + io.cmd.bits.j := DontCare + io.cmd.bits.k := k + io.cmd.bits.max_i := req.max_i + io.cmd.bits.max_j := DontCare + io.cmd.bits.max_k := req.max_k + io.cmd.bits.use_iterators := true.B + io.cmd.bits.ex_k_portion := DontCare io.loop_id := req.loop_id @@ -121,11 +141,11 @@ class LoopMatmulLdBReq(val block_size: Int, val coreMaxAddrBits: Int, val iterat } class LoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: Int, max_addr: Int, input_w: Int, - max_block_len: Int, concurrent_loops: Int) + max_block_len: Int, concurrent_loops: Int, cmd_t: GemminiCmd, local_addr_t: LocalAddr) (implicit p: Parameters) extends Module { val io = IO(new Bundle { val req = Flipped(Decoupled(new LoopMatmulLdBReq(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, concurrent_loops))) - val cmd = Decoupled(Output(new RoCCCommand)) + val cmd = Decoupled(Output(cmd_t)) val k = Output(UInt(iterator_bitwidth.W)) val j = Output(UInt(iterator_bitwidth.W)) @@ -167,11 +187,22 @@ class LoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In val cols = (blocks * block_size.U) - Mux(col_iterator + blocks >= max_col_iterator, col_pad, 0.U) val rows = block_size.U - Mux(max_row_iterator === max_row_iterator-1.U, row_pad, 0.U) + val mvin_cmd_rs1 = Wire(new GemminiISA.LoadCmd.Rs1(coreMaxAddrBits)) + mvin_cmd_rs1 := DontCare + mvin_cmd_rs1.dram_addr := dram_addr + + val mvin_cmd_rs2 = Wire(new GemminiISA.LoadCmd.Rs2(local_addr_t)) + mvin_cmd_rs2 := DontCare + mvin_cmd_rs2.rows := rows + mvin_cmd_rs2.cols := cols + mvin_cmd_rs2.spad_addr := 0.U.asTypeOf(local_addr_t) + mvin_cmd_rs2.spad_addr.data := sp_addr + val mvin_cmd = Wire(new RoCCCommand) mvin_cmd := DontCare mvin_cmd.inst.funct := LOAD2_CMD - mvin_cmd.rs1 := dram_addr - mvin_cmd.rs2 := (rows << 
48).asUInt() | (cols << 32).asUInt() | sp_addr + mvin_cmd.rs1 := mvin_cmd_rs1.asUInt() + mvin_cmd.rs2 := mvin_cmd_rs2.asUInt() io.req.ready := state === idle io.k := k @@ -179,7 +210,16 @@ class LoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In io.idle := state === idle io.cmd.valid := state =/= idle && !io.rob_overloaded - io.cmd.bits := mvin_cmd + io.cmd.bits.cmd := mvin_cmd + io.cmd.bits.rob_id := DontCare + io.cmd.bits.i := DontCare + io.cmd.bits.j := j + io.cmd.bits.k := k + io.cmd.bits.max_i := DontCare + io.cmd.bits.max_j := req.max_j + io.cmd.bits.max_k := req.max_k + io.cmd.bits.use_iterators := true.B + io.cmd.bits.ex_k_portion := DontCare io.loop_id := req.loop_id @@ -222,11 +262,12 @@ class LoopMatmulLdDReq(val block_size: Int, val coreMaxAddrBits: Int, val iterat } class LoopMatmulLdD(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: Int, max_acc_addr: Int, input_w: Int, - acc_w: Int, max_block_len: Int, max_block_len_acc: Int, concurrent_loops: Int) + acc_w: Int, max_block_len: Int, max_block_len_acc: Int, concurrent_loops: Int, cmd_t: GemminiCmd, + local_addr_t: LocalAddr) (implicit p: Parameters) extends Module { val io = IO(new Bundle { val req = Flipped(Decoupled(new LoopMatmulLdDReq(block_size, coreMaxAddrBits, iterator_bitwidth, max_acc_addr, concurrent_loops))) - val cmd = Decoupled(Output(new RoCCCommand)) + val cmd = Decoupled(Output(cmd_t)) val idle = Output(Bool()) val rob_overloaded = Input(Bool()) @@ -248,27 +289,46 @@ class LoopMatmulLdD(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In val j = Reg(UInt(iterator_bitwidth.W)) val i = Reg(UInt(iterator_bitwidth.W)) - val acc_addr_start = (BigInt(1) << 31).U | req.addr_start - val dram_addr = Mux(req.low_d, req.dram_addr + (i * req.dram_stride + j) * block_size.U * (input_w/8).U, req.dram_addr + (i * req.dram_stride + j) * block_size.U * (acc_w/8).U) - val sp_addr = acc_addr_start + (i * req.max_j + j) * block_size.U + val sp_addr = req.addr_start 
+ (i * req.max_j + j) * block_size.U val blocks = Mux(j + max_blocks <= req.max_j, max_blocks, req.max_j-j) val cols = (blocks * block_size.U) - Mux(j + blocks >= req.max_j, req.pad_j, 0.U) val rows = block_size.U - Mux(i === req.max_i-1.U, req.pad_i, 0.U) + val mvin_cmd_rs1 = Wire(new GemminiISA.LoadCmd.Rs1(coreMaxAddrBits)) + mvin_cmd_rs1 := DontCare + mvin_cmd_rs1.dram_addr := dram_addr + + val mvin_cmd_rs2 = Wire(new GemminiISA.LoadCmd.Rs2(local_addr_t)) + mvin_cmd_rs2 := DontCare + mvin_cmd_rs2.rows := rows + mvin_cmd_rs2.cols := cols + mvin_cmd_rs2.spad_addr := 0.U.asTypeOf(local_addr_t) + mvin_cmd_rs2.spad_addr.is_acc_addr := true.B + mvin_cmd_rs2.spad_addr.data := sp_addr + val mvin_cmd = Wire(new RoCCCommand) mvin_cmd := DontCare mvin_cmd.inst.funct := LOAD3_CMD - mvin_cmd.rs1 := dram_addr - mvin_cmd.rs2 := (rows << 48).asUInt() | (cols << 32).asUInt() | sp_addr + mvin_cmd.rs1 := mvin_cmd_rs1.asUInt() + mvin_cmd.rs2 := mvin_cmd_rs2.asUInt() io.req.ready := state === idle io.idle := state === idle // The order here is k, j, i io.cmd.valid := state =/= idle && !io.rob_overloaded && req.dram_addr =/= 0.U - io.cmd.bits := mvin_cmd + io.cmd.bits.cmd := mvin_cmd + io.cmd.bits.rob_id := DontCare + io.cmd.bits.i := i + io.cmd.bits.j := j + io.cmd.bits.k := DontCare + io.cmd.bits.max_i := req.max_i + io.cmd.bits.max_j := req.max_j + io.cmd.bits.max_k := DontCare + io.cmd.bits.use_iterators := true.B + io.cmd.bits.ex_k_portion := DontCare io.loop_id := req.loop_id @@ -305,6 +365,7 @@ class LoopMatmulExecuteReq(val block_size: Int, val coreMaxAddrBits: Int, val it val pad_i = UInt(log2Up(block_size).W) val a_tranpose = Bool() val b_tranpose = Bool() + val ooo = Bool() val accumulate = Bool() val a_addr_start = UInt(log2Up(max_addr).W) val b_addr_end = UInt(log2Up(max_addr).W) @@ -312,13 +373,26 @@ class LoopMatmulExecuteReq(val block_size: Int, val coreMaxAddrBits: Int, val it val loop_id = UInt(log2Up(concurrent_loops).W) } -class LoopMatmulExecute(block_size: Int, 
coreMaxAddrBits: Int, iterator_bitwidth: Int, max_addr: Int, max_acc_addr: Int, concurrent_loops: Int) +class LoopMatmulExecute(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: Int, max_addr: Int, max_acc_addr: Int, max_block_len: Int, concurrent_loops: Int, cmd_t: GemminiCmd, total_k_portions: Int, k_portion: Int, fine_grained_interleaving: Boolean, local_addr_t: LocalAddr) (implicit p: Parameters) extends Module { val GARBAGE_ADDR = (~0.U(32.W)).asUInt() + val rocc_cmd_t = new RoCCCommand + class blocks_holder_t extends Bundle { + val opcode = UInt(rocc_cmd_t.inst.opcode.getWidth.W) + val rs1 = UInt(rocc_cmd_t.inst.rs1.getWidth.W) + val rs2 = UInt(rocc_cmd_t.inst.rs2.getWidth.W) + val rd = UInt(rocc_cmd_t.inst.rd.getWidth.W) + + override def cloneType: blocks_holder_t.this.type = (new blocks_holder_t).asInstanceOf[this.type] + } + val io = IO(new Bundle { val req = Flipped(Decoupled(new LoopMatmulExecuteReq(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, max_acc_addr, concurrent_loops))) - val cmd = Decoupled(Output(new RoCCCommand)) + val cmd = Decoupled(Output(cmd_t)) + + val req_out = Output(new LoopMatmulExecuteReq(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, max_acc_addr, concurrent_loops)) + val is_pre = Output(Bool()) val k = Output(UInt(iterator_bitwidth.W)) val j = Output(UInt(iterator_bitwidth.W)) @@ -335,6 +409,10 @@ class LoopMatmulExecute(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth val idle = Output(Bool()) val rob_overloaded = Input(Bool()) + val must_send_compute = Output(Bool()) + + val can_send_command = Output(Bool()) + val loop_id = Output(UInt(log2Up(concurrent_loops).W)) }) @@ -345,15 +423,28 @@ class LoopMatmulExecute(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth val state = RegInit(idle) val req = Reg(new LoopMatmulExecuteReq(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, max_acc_addr, concurrent_loops)) + io.req_out := req + + // val max_i_blocks = Mux(req.a_tranpose, 1.U, 
Mux(req.max_i <= max_block_len.U, req.max_i, max_block_len.U)) + // val max_i_blocks = Mux(req.max_i <= max_block_len.U, req.max_i, max_block_len.U) + val max_i_blocks = max_block_len.U + val lower_k_bound = if (fine_grained_interleaving) { (max_block_len * k_portion).U } else { (req.max_k / total_k_portions.U) * k_portion.U } + val upper_k_bound = if (fine_grained_interleaving || k_portion == total_k_portions - 1) { req.max_k } else { (req.max_k / total_k_portions.U) * (k_portion + 1).U } + + /* val d_addr_start = (BigInt(1) << 31).U | req.c_addr_start val c_addr_start = (BigInt(3) << 30).U | req.c_addr_start val b_addr_start = req.b_addr_end - req.max_k * req.max_j * block_size.U + */ val k = Reg(UInt(iterator_bitwidth.W)) val j = Reg(UInt(iterator_bitwidth.W)) val i = Reg(UInt(iterator_bitwidth.W)) + val i_blocks = Mux(i + max_i_blocks <= req.max_i, max_i_blocks, req.max_i-i) + + /* val a_row = Mux(req.a_tranpose, k, i) val a_col = Mux(req.a_tranpose, i, k) val b_row = Mux(req.b_tranpose, j, k) @@ -368,71 +459,285 @@ class LoopMatmulExecute(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth val c_addr = c_addr_start + (i * req.max_j + j) * block_size.U val a_cols = block_size.U - Mux(k === req.max_k - 1.U, req.pad_k, 0.U) - val a_rows = block_size.U - Mux(i === req.max_i - 1.U, req.pad_i, 0.U) + val a_rows = i_blocks * block_size.U - Mux(i + max_i_blocks >= req.max_i, req.pad_i, 0.U) val b_cols = block_size.U - Mux(j === req.max_j - 1.U, req.pad_j, 0.U) val b_rows = block_size.U - Mux(k === req.max_k - 1.U, req.pad_k, 0.U) val c_cols = block_size.U - Mux(j === req.max_j - 1.U, req.pad_j, 0.U) - val c_rows = block_size.U - Mux(i === req.max_i - 1.U, req.pad_i, 0.U) + val c_rows = i_blocks * block_size.U - Mux(i + max_i_blocks >= req.max_i, req.pad_i, 0.U) + + val pre_addr_is_not_garbage = i === 0.U || req.ooo + val out_addr_accumulates = req.accumulate || k =/= 0.U + + val pre_addr = Mux(pre_addr_is_not_garbage, b_addr, GARBAGE_ADDR) + val out_addr = 
Mux(out_addr_accumulates, c_addr, d_addr) + */ - val pre_addr = Mux(i === 0.U, b_addr, GARBAGE_ADDR) - val out_addr = Mux(req.accumulate || k =/= 0.U, c_addr, d_addr) + /* + val j_blocks_holder = req.max_j.asTypeOf(new blocks_holder_t) + val k_blocks_holder = req.max_k.asTypeOf(new blocks_holder_t) + + val pre_cmd_rs1 = Wire(new GemminiISA.PreloadCmd.Rs1(local_addr_t, block_size)) + pre_cmd_rs1 := DontCare + pre_cmd_rs1.bd_rows := b_rows + pre_cmd_rs1.bd_cols := b_cols + pre_cmd_rs1.bd := 0.U.asTypeOf(local_addr_t) + pre_cmd_rs1.bd.data := pre_addr + + when (!pre_addr_is_not_garbage) { + pre_cmd_rs1.bd.make_this_garbage() + } + + val pre_cmd_rs2 = Wire(new GemminiISA.PreloadCmd.Rs2(local_addr_t, block_size, max_block_len)) + pre_cmd_rs2 := DontCare + pre_cmd_rs2.c_rows := c_rows + pre_cmd_rs2.c_cols := c_cols + pre_cmd_rs2.c := 0.U.asTypeOf(local_addr_t) + pre_cmd_rs2.c.is_acc_addr := true.B + pre_cmd_rs2.c.accumulate := out_addr_accumulates + pre_cmd_rs2.c.data := out_addr val pre_cmd = Wire(new RoCCCommand) pre_cmd := DontCare pre_cmd.inst.funct := PRELOAD_CMD - pre_cmd.rs1 := pre_addr | (b_cols << 32).asUInt() | (b_rows << 48).asUInt() - pre_cmd.rs2 := out_addr | (c_cols << 32).asUInt() | (c_rows << 48).asUInt() + pre_cmd.rs1 := pre_cmd_rs1.asUInt() + pre_cmd.rs2 := pre_cmd_rs2.asUInt() + pre_cmd.inst.opcode := j_blocks_holder.opcode + pre_cmd.inst.rs1 := j_blocks_holder.rs1 + pre_cmd.inst.rs2 := j_blocks_holder.rs2 + pre_cmd.inst.rd := j_blocks_holder.rd + + val comp_cmd_rs1 = Wire(new GemminiISA.ComputeCmd.Rs1(local_addr_t, block_size, max_block_len)) + comp_cmd_rs1 := DontCare + comp_cmd_rs1.a_rows := a_rows + comp_cmd_rs1.a_cols := a_cols + comp_cmd_rs1.a := 0.U.asTypeOf(local_addr_t) + comp_cmd_rs1.a.data := a_addr val comp_cmd = Wire(new RoCCCommand()) comp_cmd := DontCare - comp_cmd.inst.funct := Mux(i === 0.U, COMPUTE_AND_FLIP_CMD, COMPUTE_AND_STAY_CMD) - comp_cmd.rs1 := a_addr | (a_cols << 32).asUInt() | (a_rows << 48).asUInt() + comp_cmd.inst.funct := 
Mux(i === 0.U || req.ooo, COMPUTE_AND_FLIP_CMD, COMPUTE_AND_STAY_CMD) + comp_cmd.rs1 := comp_cmd_rs1.asUInt() comp_cmd.rs2 := GARBAGE_ADDR | (block_size.U << 32).asUInt() | (block_size.U << 48).asUInt() + comp_cmd.inst.opcode := k_blocks_holder.opcode + comp_cmd.inst.rs1 := k_blocks_holder.rs1 + comp_cmd.inst.rs2 := k_blocks_holder.rs2 + comp_cmd.inst.rd := k_blocks_holder.rd + */ io.req.ready := state === idle io.k := k io.j := j io.i := i io.idle := state === idle + io.must_send_compute := state === comp // The order here is k, j, i - val lda_ahead = io.lda_completed || io.ld_ka > k || (io.ld_ka === k && io.ld_i > i) + // val lda_ahead = io.lda_completed || io.ld_ka > k || (io.ld_ka === k && io.ld_i >= i + i_blocks) + val lda_ahead = io.lda_completed || io.ld_ka > k || (io.ld_ka === k && io.ld_i >= i + max_i_blocks) val ldb_ahead = io.ldb_completed || io.ld_kb > k || (io.ld_ka === k && io.ld_j > j) val ldd_ahead = io.ldd_completed val ld_ahead = lda_ahead && ldb_ahead && ldd_ahead + io.can_send_command := state =/= idle && ld_ahead + io.cmd.valid := state =/= idle && !io.rob_overloaded && ld_ahead - io.cmd.bits := Mux(state === pre, pre_cmd, comp_cmd) + io.cmd.bits := 0.U.asTypeOf(io.cmd.bits) + /* + io.cmd.bits.cmd := Mux(state === pre, pre_cmd, comp_cmd) + io.cmd.bits.rob_id := DontCare + io.cmd.bits.i := i + io.cmd.bits.j := j + io.cmd.bits.k := k + io.cmd.bits.max_i := req.max_i + io.cmd.bits.max_j := req.max_j + io.cmd.bits.max_k := req.max_k + io.cmd.bits.use_iterators := true.B + */ + io.cmd.bits.ex_k_portion := k_portion.U io.loop_id := req.loop_id + io.is_pre := state === pre + when (io.cmd.fire()) { when (state === pre) { state := comp }.otherwise { - val next_i = floorAdd(i, 1.U, req.max_i) + // val jump_k = fine_grained_interleaving.B && (k +& 1.U) % max_block_len.U === 0.U + val jump_k = fine_grained_interleaving.B && (k % max_block_len.U) === (max_block_len-1).U + val k_it = Mux(jump_k, (total_k_portions * max_block_len - max_block_len + 1).U, 1.U) 
+ + val next_i = floorAdd(i, max_i_blocks, req.max_i) val next_j = floorAdd(j, 1.U, req.max_j, next_i === 0.U) - val next_k = floorAdd(k, 1.U, req.max_k, next_j === 0.U && next_i === 0.U) + // val next_k = floorAdd(k, 1.U, req.max_k, next_j === 0.U && next_i === 0.U) + // val next_k = floorAdd(k, k_it, upper_k_bound, next_j === 0.U && next_i === 0.U, min=lower_k_bound) + val next_k = floorAdd(k, k_it, upper_k_bound, next_j === 0.U && next_i === 0.U) k := next_k j := next_j i := next_i state := Mux(next_k === 0.U && next_j === 0.U && next_i === 0.U, idle, pre) + // state := Mux(next_k === lower_k_bound && next_j === 0.U && next_i === 0.U, idle, pre) } } when (io.req.fire()) { req := io.req.bits - state := pre j := 0.U - k := 0.U + // k := 0.U + k := (if (fine_grained_interleaving) { lower_k_bound } else { (io.req.bits.max_k / total_k_portions.U) * k_portion.U }) i := 0.U + + when (!fine_grained_interleaving.B || lower_k_bound < io.req.bits.max_k) { + state := pre + } } assert(!(state =/= idle && req.a_tranpose && req.b_tranpose)) } +class LoopMatmulExecuteAddrGenerator(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: Int, max_addr: Int, max_acc_addr: Int, max_block_len: Int, concurrent_loops: Int, cmd_t: GemminiCmd, total_k_portions: Int, fine_grained_interleaving: Boolean, local_addr_t: LocalAddr, no_garbage_preload: Boolean) + (implicit p: Parameters) extends Module { + val GARBAGE_ADDR = (~0.U(32.W)).asUInt() + + val rocc_cmd_t = new RoCCCommand + class blocks_holder_t extends Bundle { + val opcode = UInt(rocc_cmd_t.inst.opcode.getWidth.W) + val rs1 = UInt(rocc_cmd_t.inst.rs1.getWidth.W) + val rs2 = UInt(rocc_cmd_t.inst.rs2.getWidth.W) + val rd = UInt(rocc_cmd_t.inst.rd.getWidth.W) + + override def cloneType: blocks_holder_t.this.type = (new blocks_holder_t).asInstanceOf[this.type] + } + + val io = IO(new Bundle { + val req = Input(new LoopMatmulExecuteReq(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, max_acc_addr, concurrent_loops)) + val k 
= Input(UInt(iterator_bitwidth.W)) + val j = Input(UInt(iterator_bitwidth.W)) + val i = Input(UInt(iterator_bitwidth.W)) + val is_pre = Input(Bool()) + val k_portion = Input(UInt(log2Up(total_k_portions).W)) + + val cmd = Output(cmd_t) + }) + + val req = io.req + val is_pre = io.is_pre + val k_portion = io.k_portion + + // val max_i_blocks = Mux(req.a_tranpose, 1.U, Mux(req.max_i <= max_block_len.U, req.max_i, max_block_len.U)) + val max_i_blocks = max_block_len.U + + val d_addr_start = (BigInt(1) << 31).U | req.c_addr_start + val c_addr_start = (BigInt(3) << 30).U | req.c_addr_start + val b_addr_start = req.b_addr_end - req.max_k * req.max_j * block_size.U + + val k = io.k + val j = io.j + val i = io.i + + val i_blocks = Mux(i + max_i_blocks <= req.max_i, max_i_blocks, req.max_i-i) + + /* + val a_row = Mux(req.a_tranpose, k, i) + val a_col = Mux(req.a_tranpose, i, k) + val b_row = Mux(req.b_tranpose, j, k) + val b_col = Mux(req.b_tranpose, k, j) + + val a_max_col = Mux(req.a_tranpose, req.max_i, req.max_k) + val b_max_col = Mux(req.b_tranpose, req.max_k, req.max_j) + */ + + val a_row = i + val a_col = k + val b_row = k + val b_col = j + + val a_max_col = req.max_k + val b_max_col = req.max_j + + val a_addr = req.a_addr_start + (a_row * a_max_col + a_col) * block_size.U + val b_addr = b_addr_start + (b_row * b_max_col + b_col) * block_size.U + val d_addr = d_addr_start + (i * req.max_j + j) * block_size.U + val c_addr = c_addr_start + (i * req.max_j + j) * block_size.U + + val a_cols = block_size.U - Mux(k === req.max_k - 1.U, req.pad_k, 0.U) + val a_rows = i_blocks * block_size.U - Mux(i + max_i_blocks >= req.max_i, req.pad_i, 0.U) + val b_cols = block_size.U - Mux(j === req.max_j - 1.U, req.pad_j, 0.U) + val b_rows = block_size.U - Mux(k === req.max_k - 1.U, req.pad_k, 0.U) + val c_cols = block_size.U - Mux(j === req.max_j - 1.U, req.pad_j, 0.U) + val c_rows = i_blocks * block_size.U - Mux(i + max_i_blocks >= req.max_i, req.pad_i, 0.U) + + val 
pre_addr_is_not_garbage = i === 0.U || req.ooo || no_garbage_preload.B + val out_addr_accumulates = req.accumulate || k =/= 0.U + + val pre_addr = Mux(pre_addr_is_not_garbage, b_addr, GARBAGE_ADDR) + val out_addr = Mux(out_addr_accumulates, c_addr, d_addr) + + val j_blocks_holder = req.max_j.asTypeOf(new blocks_holder_t) + val k_blocks_holder = req.max_k.asTypeOf(new blocks_holder_t) + + val pre_cmd_rs1 = Wire(new GemminiISA.PreloadCmd.Rs1(local_addr_t, block_size)) + pre_cmd_rs1 := DontCare + pre_cmd_rs1.bd_rows := b_rows + pre_cmd_rs1.bd_cols := b_cols + pre_cmd_rs1.bd := 0.U.asTypeOf(local_addr_t) + pre_cmd_rs1.bd.data := pre_addr + + when (!pre_addr_is_not_garbage) { + pre_cmd_rs1.bd.make_this_garbage() + } + + val pre_cmd_rs2 = Wire(new GemminiISA.PreloadCmd.Rs2(local_addr_t, block_size, max_block_len)) + pre_cmd_rs2 := DontCare + pre_cmd_rs2.c_rows := c_rows + pre_cmd_rs2.c_cols := c_cols + pre_cmd_rs2.c := 0.U.asTypeOf(local_addr_t) + pre_cmd_rs2.c.is_acc_addr := true.B + pre_cmd_rs2.c.accumulate := out_addr_accumulates + pre_cmd_rs2.c.data := out_addr + + val pre_cmd = Wire(new RoCCCommand) + pre_cmd := DontCare + pre_cmd.inst.funct := PRELOAD_CMD + pre_cmd.rs1 := pre_cmd_rs1.asUInt() + pre_cmd.rs2 := pre_cmd_rs2.asUInt() + pre_cmd.inst.opcode := j_blocks_holder.opcode + pre_cmd.inst.rs1 := j_blocks_holder.rs1 + pre_cmd.inst.rs2 := j_blocks_holder.rs2 + pre_cmd.inst.rd := j_blocks_holder.rd + + val comp_cmd_rs1 = Wire(new GemminiISA.ComputeCmd.Rs1(local_addr_t, block_size, max_block_len)) + comp_cmd_rs1 := DontCare + comp_cmd_rs1.a_rows := a_rows + comp_cmd_rs1.a_cols := a_cols + comp_cmd_rs1.a := 0.U.asTypeOf(local_addr_t) + comp_cmd_rs1.a.data := a_addr + + val comp_cmd = Wire(new RoCCCommand()) + comp_cmd := DontCare + comp_cmd.inst.funct := Mux(i === 0.U || req.ooo, COMPUTE_AND_FLIP_CMD, COMPUTE_AND_STAY_CMD) + comp_cmd.rs1 := comp_cmd_rs1.asUInt() + comp_cmd.rs2 := GARBAGE_ADDR | (block_size.U << 32).asUInt() | (block_size.U << 48).asUInt() + 
comp_cmd.inst.opcode := k_blocks_holder.opcode + comp_cmd.inst.rs1 := k_blocks_holder.rs1 + comp_cmd.inst.rs2 := k_blocks_holder.rs2 + comp_cmd.inst.rd := k_blocks_holder.rd + + io.cmd.cmd := Mux(is_pre, pre_cmd, comp_cmd) + io.cmd.rob_id := DontCare + io.cmd.i := i + io.cmd.j := j + io.cmd.k := k + io.cmd.max_i := req.max_i + io.cmd.max_j := req.max_j + io.cmd.max_k := req.max_k + io.cmd.use_iterators := true.B + io.cmd.ex_k_portion := k_portion +} + + // StC class LoopMatmulStCReq(val block_size: Int, val coreMaxAddrBits: Int, val iterator_bitwidth: Int, val max_acc_addr: Int, val concurrent_loops: Int) extends Bundle { @@ -448,11 +753,11 @@ class LoopMatmulStCReq(val block_size: Int, val coreMaxAddrBits: Int, val iterat val loop_id = UInt(log2Up(concurrent_loops).W) } -class LoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, max_block_len: Int, concurrent_loops: Int) +class LoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, max_block_len: Int, concurrent_loops: Int, cmd_t: GemminiCmd, local_addr_t: LocalAddr) (implicit p: Parameters) extends Module { val io = IO(new Bundle { val req = Flipped(Decoupled(new LoopMatmulStCReq(block_size, coreMaxAddrBits, iterator_bitwidth, max_acc_addr, concurrent_loops))) - val cmd = Decoupled(Output(new RoCCCommand)) + val cmd = Decoupled(Output(cmd_t)) val ex_k = Input(UInt(iterator_bitwidth.W)) val ex_j = Input(UInt(iterator_bitwidth.W)) @@ -481,20 +786,31 @@ class LoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In val j = Reg(UInt(iterator_bitwidth.W)) val i = Reg(UInt(iterator_bitwidth.W)) - val acc_addr_start = (BigInt(1) << 31).U | (req.full_c << 29.U).asUInt() | req.addr_start - val dram_addr = Mux(req.full_c, req.dram_addr + (i * req.dram_stride + j) * block_size.U * (acc_w/8).U, req.dram_addr + (i * req.dram_stride + j) * block_size.U * (input_w/8).U) - val 
sp_addr = acc_addr_start + (i * req.max_j + j) * block_size.U + val sp_addr = req.addr_start + (i * req.max_j + j) * block_size.U val blocks = Mux(j + max_blocks <= req.max_j, max_blocks, req.max_j-j) val cols = (blocks * block_size.U) - Mux(j + blocks >= req.max_j, req.pad_j, 0.U) val rows = block_size.U - Mux(i === req.max_i-1.U, req.pad_i, 0.U) + val mvout_cmd_rs1 = Wire(new GemminiISA.LoadCmd.Rs1(coreMaxAddrBits)) + mvout_cmd_rs1 := DontCare + mvout_cmd_rs1.dram_addr := dram_addr + + val mvout_cmd_rs2 = Wire(new GemminiISA.LoadCmd.Rs2(local_addr_t)) + mvout_cmd_rs2 := DontCare + mvout_cmd_rs2.rows := rows + mvout_cmd_rs2.cols := cols + mvout_cmd_rs2.spad_addr := 0.U.asTypeOf(local_addr_t) + mvout_cmd_rs2.spad_addr.is_acc_addr := true.B + mvout_cmd_rs2.spad_addr.read_full_acc_row := req.full_c + mvout_cmd_rs2.spad_addr.data := sp_addr + val mvout_cmd = Wire(new RoCCCommand) mvout_cmd := DontCare mvout_cmd.inst.funct := STORE_CMD - mvout_cmd.rs1 := dram_addr - mvout_cmd.rs2 := (rows << 48).asUInt() | (cols << 32).asUInt() | sp_addr + mvout_cmd.rs1 := mvout_cmd_rs1.asUInt() + mvout_cmd.rs2 := mvout_cmd_rs2.asUInt() io.req.ready := state === idle io.j := j @@ -509,7 +825,16 @@ class LoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In ((io.ex_j === j + blocks - 1.U) && io.ex_i > i))) io.cmd.valid := state =/= idle && !io.rob_overloaded && ex_ahead && req.dram_addr =/= 0.U - io.cmd.bits := mvout_cmd + io.cmd.bits.cmd := mvout_cmd + io.cmd.bits.rob_id := DontCare + io.cmd.bits.i := i + io.cmd.bits.j := j + io.cmd.bits.k := DontCare + io.cmd.bits.max_i := req.max_i + io.cmd.bits.max_j := req.max_j + io.cmd.bits.max_k := DontCare + io.cmd.bits.use_iterators := true.B + io.cmd.bits.ex_k_portion := DontCare io.loop_id := req.loop_id @@ -564,6 +889,7 @@ class LoopMatmulState(val iterator_bitwidth: Int, val coreMaxAddrBits: Int, val val ex_accumulate = Bool() val weightA = UInt(8.W) // TODO magic numbers + val ooo = Bool() val configured = Bool() @@ 
-605,19 +931,21 @@ class LoopMatmulState(val iterator_bitwidth: Int, val coreMaxAddrBits: Int, val } } -class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: Int, max_exs: Int, max_sts: Int, - max_addr: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, dma_max_bytes: Int) +class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, rob_full_entries: Int, max_lds: Int, max_exs: Int, max_sts: Int, + max_addr: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, dma_max_bytes: Int, cmd_t: GemminiCmd, ex_total_k_portions: Int, ex_fine_grained_interleaving: Boolean, local_addr_t: LocalAddr, lean_weightA: Boolean, lean_ooo_rob: Boolean, staticWeightAEnabled: Boolean) (implicit p: Parameters) extends Module { val iterator_bitwidth = 16 + val iterator_bitwidth_ceiled = 16 min (local_addr_t.maxLocalAddrBits - log2Up(block_size) + 1) val max_block_len = (dma_max_bytes / (block_size * input_w / 8)) max 1 val max_block_len_acc = (dma_max_bytes / (block_size * acc_w / 8)) max 1 val io = IO(new Bundle { val in = Flipped(Decoupled(new RoCCCommand)) - val out = Decoupled(new RoCCCommand) + val out = Decoupled(cmd_t) val ld_utilization = Input(UInt(log2Up(rob_size+1).W)) val st_utilization = Input(UInt(log2Up(rob_size+1).W)) val ex_utilization = Input(UInt(log2Up(rob_size+1).W)) + val ex_k_portion_utilizations = Input(Vec(ex_total_k_portions, UInt(log2Up(rob_size+1).W))) val busy = Output(Bool()) }) @@ -635,11 +963,14 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: val loop_being_configured = loops(loop_being_configured_id) // Create inner modules - val ldA = Module(new LoopMatmulLdA(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, input_w, max_block_len, concurrent_loops)) - val ldB = Module(new LoopMatmulLdB(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, input_w, max_block_len, concurrent_loops)) - val ldD = Module(new LoopMatmulLdD(block_size, coreMaxAddrBits, iterator_bitwidth, 
max_acc_addr, input_w, acc_w, max_block_len, max_block_len_acc, concurrent_loops)) - val ex = Module(new LoopMatmulExecute(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, max_acc_addr, concurrent_loops)) - val stC = Module(new LoopMatmulStC(block_size, coreMaxAddrBits, iterator_bitwidth, max_acc_addr, input_w, acc_w, max_block_len, concurrent_loops)) + val ldA = Module(new LoopMatmulLdA(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, input_w, max_block_len, concurrent_loops, cmd_t, local_addr_t)) + val ldB = Module(new LoopMatmulLdB(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, input_w, max_block_len, concurrent_loops, cmd_t, local_addr_t)) + val ldD = Module(new LoopMatmulLdD(block_size, coreMaxAddrBits, iterator_bitwidth, max_acc_addr, input_w, acc_w, max_block_len, max_block_len_acc, concurrent_loops, cmd_t, local_addr_t)) + // val ex = Module(new LoopMatmulExecute(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, max_acc_addr, max_block_len, concurrent_loops, cmd_t)) + val exs = (0 until ex_total_k_portions).map { i => + Module(new LoopMatmulExecute(block_size, coreMaxAddrBits, iterator_bitwidth_ceiled, max_addr, max_acc_addr, max_block_len, concurrent_loops, cmd_t, total_k_portions = ex_total_k_portions, k_portion = i, fine_grained_interleaving = ex_fine_grained_interleaving, local_addr_t)) + } + val stC = Module(new LoopMatmulStC(block_size, coreMaxAddrBits, iterator_bitwidth, max_acc_addr, input_w, acc_w, max_block_len, concurrent_loops, cmd_t, local_addr_t)) // Create command queue val cmd = Queue(io.in) @@ -647,18 +978,67 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: io.busy := cmd.valid || loop_configured // Create ld arbiters - val ldab_arb = Module(new WeightedArbiter(new RoCCCommand(), maxWeightA=255)) // TODO magic numbers + val ldab_arb = Module(new WeightedArbiter(cmd_t, maxWeightA=255, staticWeightAEnabled=staticWeightAEnabled, onlyStaticWeightA=lean_weightA)) // TODO 
magic numbers ldab_arb.io.inA <> ldA.io.cmd ldab_arb.io.inB <> ldB.io.cmd val ab_loads_on_same_loop = ldA.io.loop_id === ldB.io.loop_id ldab_arb.io.forceA := !ab_loads_on_same_loop && ldA.io.loop_id === head_loop_id ldab_arb.io.forceB := !ab_loads_on_same_loop && ldB.io.loop_id === head_loop_id ldab_arb.io.weightA := head_loop.weightA + ldab_arb.io.inA_idle := ldA.io.idle + ldab_arb.io.inB_idle := ldB.io.idle + ldab_arb.io.inA_k := ldA.io.k + ldab_arb.io.inA_i := ldA.io.i + ldab_arb.io.inB_k := ldB.io.k + ldab_arb.io.inB_j := ldB.io.j + + // Create ex arbiters + // ALON: This is the arbiter between the k-portions. You could try out an RR arbiter instead. Right now, we're using Chisel's default arbiter which is a priority arbiter that prioritizes the earliest k-portions + class ExAddrGeneratorInput extends Bundle { + val req = exs.head.io.req_out.cloneType + val k = exs.head.io.k.cloneType + val j = exs.head.io.j.cloneType + val i = exs.head.io.i.cloneType + val is_pre = Bool() + val k_portion = UInt(log2Up(ex_total_k_portions).W) + + override def cloneType: ExAddrGeneratorInput.this.type = new ExAddrGeneratorInput().asInstanceOf[this.type] + } + val ex_arb = Module(new ExArbiter(new ExAddrGeneratorInput, ex_total_k_portions, ex_fine_grained_interleaving, iterator_bitwidth_ceiled)) + (ex_arb.io.in, ex_arb.io.k, exs).zipped.foreach { case (in, k, ex) => + in.valid := ex.io.cmd.valid + ex.io.cmd.ready := in.ready + in.bits.req := ex.io.req_out + in.bits.k := ex.io.k + in.bits.j := ex.io.j + in.bits.i := ex.io.i + in.bits.is_pre := ex.io.is_pre + in.bits.k_portion := ex.io.cmd.bits.ex_k_portion + k := ex.io.k + } + val ex_addr_generator = Module(new LoopMatmulExecuteAddrGenerator(block_size, coreMaxAddrBits, iterator_bitwidth_ceiled, max_addr, max_acc_addr, max_block_len, concurrent_loops, cmd_t, total_k_portions = ex_total_k_portions, fine_grained_interleaving = ex_fine_grained_interleaving, local_addr_t, lean_ooo_rob)) + ex_addr_generator.io.req := 
ex_arb.io.out.bits.req + ex_addr_generator.io.k := ex_arb.io.out.bits.k + ex_addr_generator.io.j := ex_arb.io.out.bits.j + ex_addr_generator.io.i := ex_arb.io.out.bits.i + ex_addr_generator.io.is_pre := ex_arb.io.out.bits.is_pre + ex_addr_generator.io.k_portion := ex_arb.io.out.bits.k_portion + /* + val ex_arb = Module(new ExArbiter(cmd_t, ex_total_k_portions, ex_fine_grained_interleaving)) + (ex_arb.io.in, ex_arb.io.k, exs).zipped.foreach { case (in, k, ex) => + in <> ex.io.cmd + k := ex.io.k + } + */ // Create global arbiter - val arb = Module(new Arbiter(new RoCCCommand(), 4)) + val arb = Module(new Arbiter(cmd_t, 4)) arb.io.in(0) <> stC.io.cmd - arb.io.in(1) <> ex.io.cmd + // arb.io.in(1) <> ex.io.cmd + // arb.io.in(1) <> ex_arb.io.out + arb.io.in(1).valid := ex_arb.io.out.valid + arb.io.in(1).bits := ex_addr_generator.io.cmd + ex_arb.io.out.ready := arb.io.in(1).ready arb.io.in(2) <> ldD.io.cmd arb.io.in(3) <> ldab_arb.io.out val unrolled_cmd = arb.io.out @@ -668,8 +1048,15 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: val is_loop_config_cmd = cmd.bits.inst.funct >= LOOP_WS_CONFIG_BOUNDS && cmd.bits.inst.funct <= LOOP_WS_CONFIG_STRIDES_DC val is_loop_cmd = is_loop_run_cmd || is_loop_config_cmd - io.out.bits := Mux(loop_configured, unrolled_cmd.bits, cmd.bits) - io.out.bits.status := cmd.bits.status // TODO This is not guaranteed to be the correct fix! We must fix this + // io.out.bits := Mux(loop_configured, unrolled_cmd.bits, cmd.bits) + when (loop_configured) { + io.out.bits := unrolled_cmd.bits + }.otherwise { + io.out.bits := DontCare + io.out.bits.cmd := cmd.bits + io.out.bits.use_iterators := false.B + } + io.out.bits.cmd.status := cmd.bits.status // TODO This is not guaranteed to be the correct fix! 
We must fix this io.out.valid := Mux(loop_configured, unrolled_cmd.valid, cmd.valid && !is_loop_config_cmd && !is_loop_run_cmd) cmd.ready := Mux(is_loop_cmd, !loop_being_configured.configured, !loop_configured && io.out.ready) @@ -678,11 +1065,40 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: // Wire up overloaded signals ldA.io.rob_overloaded := io.ld_utilization >= max_lds.U ldB.io.rob_overloaded := io.ld_utilization >= max_lds.U - ex.io.rob_overloaded := io.ex_utilization >= max_exs.U + // ex.io.rob_overloaded := io.ex_utilization >= max_exs.U + (exs zip io.ex_k_portion_utilizations).zipWithIndex.foreach { case ((ex, k_util), id) => + /* + A k-portion is inactive iff it has finished sending all its matmul commands, or if it can't send any matmul commands + currently because the loads that it needs haven't been sent out yet + */ + + val other_exs = exs.filter(_ != ex) + val must_wait_for_other_compute = if (ex_total_k_portions == 1) { false.B } else { other_exs.map(_.io.must_send_compute).reduce(_ || _) } + + val limits = if (ex_total_k_portions == 1) { Seq(rob_full_entries) } else { (1 to ex_total_k_portions).map(i => rob_full_entries / i) } + val limits_uint = VecInit(limits.map(_.U)) + val first_limits = VecInit(limits.map(l => if (ex_fine_grained_interleaving) l.U else (l * 1.5).toInt.U)) // ALON: You can scale the earliest k-portion's limit by any scalar factor (e.g. 
1.25) that you would like + + val active_exs = PopCount(exs.map(_.io.can_send_command)) + val earliest_k_portion = MuxCase((ex_total_k_portions - 1).U, (0 until ex_total_k_portions).map { i => + exs(i).io.can_send_command -> i.U + }) + + val default_k_util_limit = Mux(id.U === earliest_k_portion, first_limits(active_exs), limits_uint(active_exs)) + val max_k_util_limit = (rob_full_entries - 2).U + + // If we've just send a preload, then we should just send the next compute, without worrying about k_util + val k_util_limit = Mux(ex.io.must_send_compute || default_k_util_limit > max_k_util_limit, max_k_util_limit, + default_k_util_limit) + + // ALON: You can change "k_util_limit" to any limit (e.g. 12.U) that you would like + ex.io.rob_overloaded := io.ex_utilization >= max_exs.U || k_util >= k_util_limit || must_wait_for_other_compute + } ldD.io.rob_overloaded := io.ld_utilization >= max_lds.U stC.io.rob_overloaded := io.st_utilization >= max_sts.U // Wire up iterator inputs + /* ex.io.lda_completed := (ldA.io.loop_id =/= ex.io.loop_id) || ldA.io.idle ex.io.ldb_completed := (ldB.io.loop_id =/= ex.io.loop_id) || ldB.io.idle ex.io.ldd_completed := (ldD.io.loop_id =/= ex.io.loop_id) || ldD.io.idle @@ -690,11 +1106,49 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: ex.io.ld_kb := ldB.io.k ex.io.ld_j := ldB.io.j ex.io.ld_i := ldA.io.i + */ + exs.foreach { ex => + ex.io.lda_completed := (ldA.io.loop_id =/= ex.io.loop_id) || ldA.io.idle + ex.io.ldb_completed := (ldB.io.loop_id =/= ex.io.loop_id) || ldB.io.idle + ex.io.ldd_completed := (ldD.io.loop_id =/= ex.io.loop_id) || ldD.io.idle + ex.io.ld_ka := ldA.io.k + ex.io.ld_kb := ldB.io.k + ex.io.ld_j := ldB.io.j + ex.io.ld_i := ldA.io.i + } + /* stC.io.ex_completed := (ex.io.loop_id =/= stC.io.loop_id) || ex.io.idle stC.io.ex_k := ex.io.k stC.io.ex_j := ex.io.j stC.io.ex_i := ex.io.i + */ + val exs_completed = exs.map(ex => (ex.io.loop_id =/= stC.io.loop_id) || ex.io.idle) + 
stC.io.ex_completed := exs_completed.reduce(_ && _) + if (ex_fine_grained_interleaving) { + // TODO getting the index here is very inefficient + val max_k = exs.map(_.io.k).reduce(maxOf) + val ks_maxed = (exs zip exs_completed).map { case (ex, completed) => Mux(completed, max_k, ex.io.k) } + val min_k = ks_maxed.reduce(minOf) + val min_k_index = WireInit(0.U(log2Up(ex_total_k_portions).W)) + ks_maxed.zipWithIndex.foreach { case (k, i) => + when (k === min_k) { + min_k_index := i.U + } + } + + val ks = VecInit(exs.map(_.io.k)) + val js = VecInit(exs.map(_.io.j)) + val is = VecInit(exs.map(_.io.i)) + + stC.io.ex_k := ks(min_k_index) + stC.io.ex_j := js(min_k_index) + stC.io.ex_i := is(min_k_index) + } else { + stC.io.ex_k := MuxCase(exs.last.io.k, (exs_completed zip exs).init.map { case (ex_completed, ex) => (!ex_completed) -> ex.io.k }) + stC.io.ex_j := MuxCase(exs.last.io.j, (exs_completed zip exs).init.map { case (ex_completed, ex) => (!ex_completed) -> ex.io.j }) + stC.io.ex_i := MuxCase(exs.last.io.i, (exs_completed zip exs).init.map { case (ex_completed, ex) => (!ex_completed) -> ex.io.i }) + } val loops_configured = RegInit(0.U(16.W)) dontTouch(loops_configured) @@ -741,6 +1195,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: loop_being_configured.b_transpose := cmd.bits.rs2(1) loop_being_configured.weightA := cmd.bits.rs1(15, 8) // TODO magic numbers + loop_being_configured.ooo := lean_ooo_rob.B || cmd.bits.rs2(2) // TODO magic numbers loop_being_configured.configured := true.B @@ -794,6 +1249,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: val loop_requesting_ex_id = Mux(head_loop.ex_started, tail_loop_id, head_loop_id) val loop_requesting_ex = loops(loop_requesting_ex_id) + /* ex.io.req.bits.max_j := loop_requesting_ex.max_j ex.io.req.bits.max_k := loop_requesting_ex.max_k ex.io.req.bits.max_i := loop_requesting_ex.max_i @@ -805,6 +1261,7 @@ class LoopMatmul(block_size: Int, 
coreMaxAddrBits: Int, rob_size: Int, max_lds: ex.io.req.bits.b_addr_end := loop_requesting_ex.b_addr_end ex.io.req.bits.a_tranpose := loop_requesting_ex.a_transpose ex.io.req.bits.b_tranpose := loop_requesting_ex.b_transpose + ex.io.req.bits.ooo := loop_requesting_ex.ooo ex.io.req.bits.c_addr_start := ex_c_addr_start ex.io.req.bits.loop_id := loop_requesting_ex_id @@ -819,6 +1276,36 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: ex_c_addr_start := floorAdd(ex_c_addr_start, (max_acc_addr / concurrent_loops).U, max_acc_addr.U) } } + */ + exs.foreach { ex => + ex.io.req.bits.max_j := loop_requesting_ex.max_j + ex.io.req.bits.max_k := loop_requesting_ex.max_k + ex.io.req.bits.max_i := loop_requesting_ex.max_i + ex.io.req.bits.pad_j := loop_requesting_ex.pad_j + ex.io.req.bits.pad_k := loop_requesting_ex.pad_k + ex.io.req.bits.pad_i := loop_requesting_ex.pad_i + ex.io.req.bits.accumulate := loop_requesting_ex.ex_accumulate + ex.io.req.bits.a_addr_start := loop_requesting_ex.a_addr_start + ex.io.req.bits.b_addr_end := loop_requesting_ex.b_addr_end + ex.io.req.bits.a_tranpose := loop_requesting_ex.a_transpose + ex.io.req.bits.b_tranpose := loop_requesting_ex.b_transpose + ex.io.req.bits.ooo := loop_requesting_ex.ooo + ex.io.req.bits.c_addr_start := ex_c_addr_start + ex.io.req.bits.loop_id := loop_requesting_ex_id + + ex.io.req.valid := !loop_requesting_ex.ex_started && loop_requesting_ex.lda_started && + loop_requesting_ex.ldb_started && loop_requesting_ex.ldd_started && loop_requesting_ex.configured && + exs.map(_.io.req.ready).reduce(_ && _) // TODO ready-valid loop + + when (ex.io.req.fire()) { + loop_requesting_ex.running := true.B + loop_requesting_ex.ex_started := true.B + + when (loop_requesting_ex.c_dram_addr =/= 0.U) { + ex_c_addr_start := floorAdd(ex_c_addr_start, (max_acc_addr / concurrent_loops).U, max_acc_addr.U) + } + } + } val loop_requesting_ldD_id = Mux(head_loop.ldd_started, tail_loop_id, head_loop_id) val 
loop_requesting_ldD = loops(loop_requesting_ldD_id) @@ -876,9 +1363,14 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: loops(ldB.io.loop_id).ldb_completed := true.B } + /* when (ex.io.idle && loops(ex.io.loop_id).running && loops(ex.io.loop_id).ex_started) { loops(ex.io.loop_id).ex_completed := true.B } + */ + when (exs.map(_.io.idle).reduce(_ && _) && loops(exs.head.io.loop_id).running && loops(exs.head.io.loop_id).ex_started) { + loops(exs.head.io.loop_id).ex_completed := true.B + } when (ldD.io.idle && loops(ldD.io.loop_id).running && loops(ldD.io.loop_id).ldd_started) { loops(ldD.io.loop_id).ldd_completed := true.B @@ -904,16 +1396,35 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: } object LoopMatmul { - def apply(in: DecoupledIO[RoCCCommand], ld_utilization: UInt, st_utilization: UInt, ex_utilization: UInt, - block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: Int, max_exs: Int, max_sts: Int, - max_addr: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, dma_max_bytes: Int) - (implicit p: Parameters): Tuple2[DecoupledIO[RoCCCommand], Bool] = { - val mod = Module(new LoopMatmul(block_size, coreMaxAddrBits, rob_size, max_lds, max_exs, max_sts, - max_addr, max_acc_addr, input_w, acc_w, dma_max_bytes)) + def apply(in: DecoupledIO[RoCCCommand], ld_utilization: UInt, st_utilization: UInt, ex_utilization: UInt, ex_k_utilizations: Vec[UInt], + block_size: Int, coreMaxAddrBits: Int, rob_size: Int, rob_full_entries: Int, max_lds: Int, max_exs: Int, max_sts: Int, + max_addr: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, dma_max_bytes: Int, cmd_t: GemminiCmd, ex_total_k_portions: Int, ex_fine_grained_interleaving: Boolean, + local_addr_t: LocalAddr, lean_weightA: Boolean, lean_ooo_rob: Boolean, staticWeightAEnabled: Boolean) + (implicit p: Parameters): Tuple2[DecoupledIO[GemminiCmd], Bool] = { + val mod = Module(new LoopMatmul(block_size, coreMaxAddrBits, rob_size, rob_full_entries, 
max_lds, max_exs, max_sts, + max_addr, max_acc_addr, input_w, acc_w, dma_max_bytes, cmd_t, ex_total_k_portions, ex_fine_grained_interleaving, + local_addr_t, lean_weightA, lean_ooo_rob, staticWeightAEnabled)) mod.io.in <> in mod.io.ld_utilization := ld_utilization mod.io.st_utilization := st_utilization mod.io.ex_utilization := ex_utilization + mod.io.ex_k_portion_utilizations := ex_k_utilizations (mod.io.out, mod.io.busy) } } + +class ExArbiter[T <: Data](gen: T, n: Int, ex_fine_grained: Boolean, iterator_bitwidth: Int) extends Module { + val io = IO(new Bundle { + val in = Flipped(Vec(n, Decoupled(gen))) + val out = Decoupled(gen) + val k = Input(Vec(n, UInt(iterator_bitwidth.W))) + }) + + val chosen = (io.in zip io.k).zipWithIndex.foldLeft(0.U) { case (acc, ((in, k), i)) => + if (ex_fine_grained) Mux(io.in(acc).valid, Mux(in.valid && k < io.k(acc), i.U, acc), i.U) + else Mux(io.in(acc).valid, acc, i.U) + } + + io.in.foreach(_.ready := false.B) + io.out <> io.in(chosen) +} diff --git a/src/main/scala/gemmini/MeshWithDelays.scala b/src/main/scala/gemmini/MeshWithDelays.scala index acab135d..1b071a28 100644 --- a/src/main/scala/gemmini/MeshWithDelays.scala +++ b/src/main/scala/gemmini/MeshWithDelays.scala @@ -113,10 +113,15 @@ class MeshWithDelays[T <: Data: Arithmetic, U <: TagQueueTag with Data] val last_fire = fire_counter === total_fires - 1.U && input_next_row_into_spatial_array + val preloads = RegInit(0.U(32.W)) + dontTouch(preloads) + when (io.req.fire()) { req.push(io.req.bits) in_prop := io.req.bits.pe_control.propagate ^ in_prop matmul_id := wrappingAdd(matmul_id, 1.U, max_simultaneous_matmuls) + + preloads := preloads + io.req.bits.pe_control.propagate }.elsewhen (last_fire) { req.valid := req.bits.flush > 1.U req.bits.flush := req.bits.flush - 1.U diff --git a/src/main/scala/gemmini/PreloadFilter.scala b/src/main/scala/gemmini/PreloadFilter.scala new file mode 100644 index 00000000..129a4312 --- /dev/null +++ 
b/src/main/scala/gemmini/PreloadFilter.scala @@ -0,0 +1,165 @@ +package gemmini + +import chisel3._ +import chisel3.util._ +import freechips.rocketchip.tile.RoCCCommand +import GemminiISA._ +import Util._ + +class PreloadFilter[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConfig[T, U, V], cmd_t: RoCCCommand) extends Module { + import config._ + + val io = IO(new Bundle { + val in_ld = Flipped(new ROBIssue(cmd_t, rob_entries)) + val in_ex = Flipped(new ROBIssue(cmd_t, rob_entries)) + + val out_ld = new ROBIssue(cmd_t, rob_entries) + val out_ex = new ROBIssue(cmd_t, rob_entries) + }) + + val block_cols = meshColumns * tileColumns + val block_rows = meshRows * tileRows + val block_size = block_rows max block_cols + + class AddressRangeT extends Bundle { + // TODO maybe this should be merged with OpT in ROB.scala? + val start = local_addr_t.cloneType + val end = local_addr_t.cloneType + val wraps_around = Bool() + + def overlaps(other: AddressRangeT): Bool = { + ((other.start <= start && (start < other.end || other.wraps_around)) || + (start <= other.start && (other.start < end || wraps_around))) && + !(start.is_garbage() || other.start.is_garbage()) // TODO the "is_garbage" check might not really be necessary + } + + def ===(other: AddressRangeT): Bool = { + start === other.start && end === other.end && wraps_around === other.wraps_around + } + + def make_this_garbage(dummy: Int=0): Unit = { + start.make_this_garbage() + } + + def is_garbage(dummy: Int=0): Bool = start.is_garbage() + } + + val df = if (dataflow == Dataflow.BOTH) Reg(UInt(1.W)) else dataflow.id.U // TODO magic numbers + val b_transposed = Reg(Bool()) + val preloaded_address = Reg(new AddressRangeT) + val ld_block_strides = Reg(Vec(load_states, UInt(block_stride_bits.W))) + val last_preload_was_filtered = RegInit(false.B) + + val ex_set_only_strides = io.in_ex.cmd.rs1(7) // TODO magic numbers + val ex_is_config = io.in_ex.cmd.inst.funct === CONFIG_CMD && 
io.in_ex.cmd.rs1(1,0).asUInt() === CONFIG_EX && !ex_set_only_strides // TODO magic numbers + val ex_config_dataflow = io.in_ex.cmd.rs1(2) // TODO magic numbers + val ex_config_b_transposed = io.in_ex.cmd.rs1(9) // TODO magic numbers + val ex_is_preload = io.in_ex.cmd.inst.funct === PRELOAD_CMD + val ex_is_compute = io.in_ex.cmd.inst.funct === COMPUTE_AND_STAY_CMD || io.in_ex.cmd.inst.funct === COMPUTE_AND_FLIP_CMD + val ex_preload_rows = { + val default_rows = io.in_ex.cmd.rs1(48 + log2Up(block_size) - 1, 48).asUInt() // TODO magic numbers + val default_cols = io.in_ex.cmd.rs1(32 + log2Up(block_size) - 1, 32).asUInt() // TODO magic numbers + Mux(b_transposed, default_cols, default_rows) + } + val ex_preload_addr = { + val start = io.in_ex.cmd.rs1(31, 0).asTypeOf(local_addr_t) // TODO magic numbers + val (end, wraps_around) = start.add_with_overflow(ex_preload_rows) + + val addr = Wire(new AddressRangeT) + addr.start := start + addr.end := end + addr.wraps_around := wraps_around + + if (!ex_read_from_acc) { + start.is_acc_addr := false.B + end.is_acc_addr := false.B + } + + addr + } + val should_filter_preload = ex_is_preload && df === Dataflow.WS.id.U && preloaded_address === ex_preload_addr + + val ld_is_config = io.in_ld.cmd.inst.funct === CONFIG_CMD + val ld_id = Mux(ld_is_config, io.in_ld.cmd.rs1(4,3).asUInt(), // TODO magic numbers + MuxCase(0.U, Seq((io.in_ld.cmd.inst.funct === LOAD2_CMD) -> 1.U, + (io.in_ld.cmd.inst.funct === LOAD3_CMD) -> 2.U))) + val ld_config_block_stride = io.in_ld.cmd.rs1(31, 16).asUInt() // TODO magic numbers + val ld_total_rows = { + val block_stride = ld_block_strides(ld_id) + val ld_cols = io.in_ld.cmd.rs2(32 + mvin_cols_bits - 1, 32).asUInt() // TODO magic numbers + val ld_rows = io.in_ld.cmd.rs2(48 + mvin_rows_bits - 1, 48).asUInt() // TODO magic numbers + val ld_mats = ld_cols / block_cols.U + (ld_cols % block_cols.U =/= 0.U) + ((ld_mats - 1.U) * block_stride) + ld_rows + } + val ld_addr = { + val start = io.in_ld.cmd.rs2(31, 
0).asTypeOf(local_addr_t) // TODO magic numbers + val (end, wraps_around) = start.add_with_overflow(ld_total_rows) + + val addr = Wire(new AddressRangeT) + addr.start := start + addr.end := end + addr.wraps_around := wraps_around + + addr + } + + // Set all state registers + when (io.in_ld.fire()) { + when (ld_is_config) { + ld_block_strides(ld_id) := ld_config_block_stride + }.elsewhen(preloaded_address.overlaps(ld_addr)) { + preloaded_address.make_this_garbage() + } + } + + when (io.in_ex.fire()) { + when (ex_is_config) { + if (dataflow == Dataflow.BOTH) { + df := ex_config_dataflow + } + b_transposed := ex_config_b_transposed + + when (b_transposed =/= ex_config_b_transposed) { + preloaded_address.make_this_garbage() + } + }.elsewhen(ex_is_preload && !ex_preload_addr.is_garbage()) { + preloaded_address := ex_preload_addr + } + + when (should_filter_preload) { + last_preload_was_filtered := true.B + }.elsewhen(ex_is_compute) { + last_preload_was_filtered := false.B + } + } + + // Set outputs + io.out_ld <> io.in_ld + io.out_ex <> io.in_ex + + when (should_filter_preload) { + io.out_ex.cmd.rs1 := (block_rows.U << 48) | (block_cols.U << 32) | GARBAGE_ADDR // TODO magic numbers + }.elsewhen(ex_is_compute && last_preload_was_filtered) { + io.out_ex.cmd.inst.funct := COMPUTE_AND_STAY_CMD + } + + when (reset.toBool()) { + preloaded_address.make_this_garbage() + } + + assert(!(io.in_ex.valid && io.in_ld.valid && !ld_is_config && ex_is_preload && ex_preload_addr.overlaps(ld_addr))) +} + +object PreloadFilter{ + def apply[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConfig[T, U, V], cmd_t: RoCCCommand, ld_issue: ROBIssue[RoCCCommand], ex_issue: ROBIssue[RoCCCommand]) = { + if (config.use_preload_filter) { + val preload_filter = Module(new PreloadFilter(config, cmd_t)) + preload_filter.io.in_ld <> ld_issue + preload_filter.io.in_ex <> ex_issue + (preload_filter.io.out_ld, preload_filter.io.out_ex) + } else { + (ld_issue, ex_issue) + } + } +} + diff 
--git a/src/main/scala/gemmini/ROB.scala b/src/main/scala/gemmini/ROB.scala index de02780b..1c7873b0 100644 --- a/src/main/scala/gemmini/ROB.scala +++ b/src/main/scala/gemmini/ROB.scala @@ -21,14 +21,14 @@ class ROBIssue[T <: Data](cmd_t: T, rob_entries: Int) extends Bundle { } // TODO we don't need to store the full command in here. We should be able to release the command directly into the relevant controller and only store the associated metadata in the ROB. This would reduce the size considerably -class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConfig[T, U, V], cmd_t: RoCCCommand) extends Module { +class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConfig[T, U, V], cmd_t: RoCCCommand, gemmini_cmd_t: GemminiCmd) extends Module { import config._ val block_rows = tileRows * meshRows val block_cols = tileColumns * meshColumns val io = IO(new Bundle { - val alloc = Flipped(Decoupled(cmd_t.cloneType)) + val alloc = Flipped(Decoupled(gemmini_cmd_t.cloneType)) val completed = Flipped(Valid(UInt(log2Up(rob_entries).W))) @@ -42,6 +42,8 @@ class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConf val st_utilization = Output(UInt(log2Up(rob_entries+1).W)) val ex_utilization = Output(UInt(log2Up(rob_entries+1).W)) + val ex_k_portion_utilizations = Output(Vec(ex_total_k_portions, UInt(log2Up(rob_entries+1).W))) + val busy = Output(Bool()) val solitary_preload = Input(Bool()) // TODO very hacky. from ExecuteController, to prevent infinite fence stalls. 
remove later @@ -56,10 +58,41 @@ class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConf val end = local_addr_t.cloneType val wraps_around = Bool() - def overlaps(other: OpT): Bool = { - ((other.start <= start && (start < other.end || other.wraps_around)) || + val i = UInt(16.W) + val j = UInt(16.W) + val k = UInt(16.W) + val i_len = UInt(16.W) + val j_len = UInt(16.W) + val k_len = UInt(16.W) + val use_iterators = Bool() + + def overlaps(other: OpT, check_accumulates: Boolean=false, compare_i_and_k: Bool=false.B, compare_i_and_j: Bool=false.B): Bool = { + val without_iterators = ((other.start <= start && (start < other.end || other.wraps_around)) || (start <= other.start && (other.start < end || wraps_around))) && + (!check_accumulates.B || + !(start.is_acc_addr && start.accumulate && other.start.is_acc_addr && other.start.accumulate)) && + !(start.is_garbage() || other.start.is_garbage()) // TODO the "is_garbage" check might not really be necessary + + val with_iterators_ik = ((other.i <= i && (i < other.i + other.i_len)) || + (i <= other.i && other.i < i + i_len)) && + ((other.k <= k && (k < other.k + other.k_len)) || + (k <= other.k && other.k < k + k_len)) && + (!check_accumulates.B || + !(start.is_acc_addr && start.accumulate && other.start.is_acc_addr && other.start.accumulate)) && !(start.is_garbage() || other.start.is_garbage()) // TODO the "is_garbage" check might not really be necessary + + val with_iterators_ij = ((other.i <= i && (i < other.i + other.i_len)) || + (i <= other.i && other.i < i + i_len)) && + ((other.j <= j && (j < other.j + other.j_len)) || + (j <= other.j && other.j < j + j_len)) && + (!check_accumulates.B || + !(start.is_acc_addr && start.accumulate && other.start.is_acc_addr && other.start.accumulate)) && + !(start.is_garbage() || other.start.is_garbage()) // TODO the "is_garbage" check might not really be necessary + + assert(!(compare_i_and_j && compare_i_and_k)) + assert(!compare_i_and_j || !compare_i_and_k 
|| (use_iterators && other.use_iterators)) + + Mux(compare_i_and_k, with_iterators_ik, Mux(compare_i_and_j, with_iterators_ij, without_iterators)) } } @@ -92,13 +125,19 @@ class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConf val deps = Vec(rob_entries, Bool()) def ready(dummy: Int = 0): Bool = !deps.reduce(_ || _) + val ex_k_portion = UInt(log2Up(ex_total_k_portions).W) + + // Signals that are necessary for OoO operation + val waiting_for_compute_inst = Bool() + // Debugging signals val allocated_at = UInt(instructions_allocated.getWidth.W) + val stall_cycles_before_issue = UInt(32.W) // TODO magic number } val full_entries = Reg(Vec(rob_full_entries, UDValid(new Entry))) val partial_entries = Reg(Vec(rob_partial_entries, UDValid(new Entry))) - val entries = full_entries ++ partial_entries + val entries = full_entries ++ partial_entries // WARNING: The last_allocated_preload code below assumes that full_entries comes before the partial_entries val empty = !entries.map(_.valid).reduce(_ || _) val full = entries.map(_.valid).reduce(_ && _) @@ -109,7 +148,6 @@ class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConf val solitary_preload = utilization === 1.U && entries.map(e => e.valid && e.bits.cmd.inst.funct === PRELOAD_CMD).reduce(_ || _) io.busy := !empty && !(solitary_preload && io.solitary_preload) - // Config values set by programmer val a_stride = Reg(UInt(16.W)) // TODO magic numbers val c_stride = Reg(UInt(16.W)) // TODO magic numbers @@ -117,6 +155,16 @@ class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConf val ld_block_strides = Reg(Vec(load_states, UInt(block_stride_bits.W))) val st_block_stride = block_rows.U val pooling_is_enabled = Reg(Bool()) + val current_dataflow = if (dataflow == Dataflow.BOTH) Reg(UInt(1.W)) else dataflow.id.U // TODO magic number + + val ex_ooo_is_enabled = ex_ooo.B && current_dataflow === Dataflow.WS.id.U + + // Registers to help keep OOO execute 
working properly + val last_allocated_preload = Reg(UDValid(UInt(log2Up(rob_entries).W))) + val last_allocated_garbage_preload = Reg(UDValid(UInt(log2Up(rob_entries).W))) + val last_allocated_preload_being_updated = WireInit(false.B) + val last_allocated_garbage_preload_being_updated = WireInit(false.B) + val last_ex_issued_was_preload = RegInit(false.B) val new_entry = Wire(new Entry) new_entry := DontCare @@ -154,7 +202,8 @@ class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConf io.alloc.ready := false.B when (io.alloc.valid) { val spAddrBits = 32 - val cmd = io.alloc.bits + val gemmini_cmd = io.alloc.bits + val cmd = io.alloc.bits.cmd val funct = cmd.inst.funct val funct_is_compute = funct === COMPUTE_AND_STAY_CMD || funct === COMPUTE_AND_FLIP_CMD val config_cmd_type = cmd.rs1(1,0) // TODO magic numbers @@ -164,15 +213,40 @@ class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConf new_entry.is_config := funct === CONFIG_CMD + new_entry.ex_k_portion := io.alloc.bits.ex_k_portion + + new_entry.stall_cycles_before_issue := 0.U + val op1 = Wire(UDValid(new OpT)) op1.valid := false.B op1.bits := DontCare + op1.bits.i := gemmini_cmd.i + op1.bits.j := gemmini_cmd.j + op1.bits.k := gemmini_cmd.k + op1.bits.i_len := 1.U + op1.bits.j_len := 1.U + op1.bits.k_len := 1.U + op1.bits.use_iterators := gemmini_cmd.use_iterators val op2 = Wire(UDValid(new OpT)) op2.valid := false.B op2.bits := DontCare + op2.bits.i := gemmini_cmd.i + op2.bits.j := gemmini_cmd.j + op2.bits.k := gemmini_cmd.k + op2.bits.i_len := 1.U + op2.bits.j_len := 1.U + op2.bits.k_len := 1.U + op2.bits.use_iterators := gemmini_cmd.use_iterators val dst = Wire(UDValid(new OpT)) dst.valid := false.B dst.bits := DontCare + dst.bits.i := gemmini_cmd.i + dst.bits.j := gemmini_cmd.j + dst.bits.k := gemmini_cmd.k + dst.bits.i_len := 1.U + dst.bits.j_len := 1.U + dst.bits.k_len := 1.U + dst.bits.use_iterators := gemmini_cmd.use_iterators assert(!(op1.valid && op2.valid 
&& dst.valid)) new_entry.opa_is_dst := dst.valid @@ -192,19 +266,37 @@ class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConf op1.bits.end := op1.bits.start + preload_rows op1.bits.wraps_around := op1.bits.start.add_with_overflow(preload_rows)._2 }.otherwise { - val rows = cmd.rs1(48 + log2Up(block_rows + 1) - 1, 48) + val rows = cmd.rs1(48 + mvin_cols_bits - 1, 48) + val k = gemmini_cmd.max_k + + val mats = rows / block_rows.U + (rows % block_rows.U =/= 0.U) + val total_rows = ((mats - 1.U) * k * block_rows.U) + Mux(rows % block_rows.U === 0.U, block_rows.U, rows % block_rows.U) + val cols = cmd.rs1(32 + log2Up(block_cols + 1) - 1, 32) - val compute_rows = Mux(a_transpose, cols, rows) * a_stride + val compute_rows = Mux(a_transpose, cols, total_rows) * a_stride + op1.bits.end := op1.bits.start + compute_rows op1.bits.wraps_around := op1.bits.start.add_with_overflow(compute_rows)._2 + + op1.bits.i_len := mats } op2.valid := funct_is_compute || funct === STORE_CMD op2.bits.start := cmd.rs2.asTypeOf(local_addr_t) when (funct_is_compute) { - val compute_rows = cmd.rs2(48 + log2Up(block_rows + 1) - 1, 48) - op2.bits.end := op2.bits.start + compute_rows - op2.bits.wraps_around := op2.bits.start.add_with_overflow(compute_rows)._2 + /* + val rows = cmd.rs2(48 + mvin_cols_bits - 1, 48) + val j = gemmini_cmd.max_j + + val mats = rows / block_rows.U + (rows % block_rows.U =/= 0.U) + val total_rows = ((mats - 1.U) * j * block_rows.U) + Mux(rows % block_rows.U === 0.U, block_rows.U, rows % block_rows.U) + + op2.bits.end := op2.bits.start + total_rows + op2.bits.wraps_around := op2.bits.start.add_with_overflow(total_rows)._2 + */ + + op2.bits.end := GARBAGE_ADDR.asTypeOf(local_addr_t) + op2.bits.wraps_around := false.B }.elsewhen (pooling_is_enabled) { // If pooling is enabled, then we assume that this command simply mvouts everything in this accumulator bank from // start to the end of the bank // TODO this won't work when acc_banks =/= 2 @@ -227,14 
+319,25 @@ class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConf op2.bits.end := op2.bits.start + total_mvout_rows op2.bits.wraps_around := pooling_is_enabled || op2.bits.start.add_with_overflow(total_mvout_rows)._2 + + op2.bits.j_len := mvout_mats } dst.valid := funct === PRELOAD_CMD || funct === LOAD_CMD || funct === LOAD2_CMD || funct === LOAD3_CMD dst.bits.start := cmd.rs2(31, 0).asTypeOf(local_addr_t) when (funct === PRELOAD_CMD) { - val preload_rows = cmd.rs2(48 + log2Up(block_rows + 1) - 1, 48) * c_stride + val rows = cmd.rs2(48 + mvin_cols_bits - 1, 48) + val j = gemmini_cmd.max_j + + val mats = rows / block_rows.U + (rows % block_rows.U =/= 0.U) + val total_rows = ((mats - 1.U) * j * block_rows.U) + Mux(rows % block_rows.U === 0.U, block_rows.U, rows % block_rows.U) + + val preload_rows = total_rows * c_stride + dst.bits.end := dst.bits.start + preload_rows dst.bits.wraps_around := dst.bits.start.add_with_overflow(preload_rows)._2 + + dst.bits.i_len := mats }.otherwise { val id = MuxCase(0.U, Seq((new_entry.cmd.inst.funct === LOAD2_CMD) -> 1.U, (new_entry.cmd.inst.funct === LOAD3_CMD) -> 2.U)) @@ -248,6 +351,8 @@ class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConf dst.bits.end := dst.bits.start + total_mvin_rows dst.bits.wraps_around := dst.bits.start.add_with_overflow(total_mvin_rows)._2 + + dst.bits.k_len := mvin_mats } val is_load = funct === LOAD_CMD || funct === LOAD2_CMD || funct === LOAD3_CMD || (funct === CONFIG_CMD && config_cmd_type === CONFIG_LOAD) @@ -262,12 +367,14 @@ class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConf )) assert(is_load || is_store || is_ex) - // This can be RAW op1/op2 <- dst, or WAW dst <- dst - val opa_matches_opa = VecInit(entries.map { e => e.valid && e.bits.opa.valid && new_entry.opa.bits.overlaps(e.bits.opa.bits) }) + // This can be RAW op1/op2 <- dst + val opa_matches_opa = VecInit(entries.map { e => e.valid && e.bits.opa.valid && 
new_entry.opa.bits.overlaps(e.bits.opa.bits, compare_i_and_k=(funct_is_compute && e.bits.cmd.inst.funct === LOAD_CMD && new_entry.opa.bits.use_iterators)) }) + // This can be WAW dst <- dst + val opa_matches_opa_for_waws = VecInit(entries.map { e => e.valid && e.bits.opa.valid && new_entry.opa.bits.overlaps(e.bits.opa.bits, check_accumulates=true, compare_i_and_j=(funct === PRELOAD_CMD && e.bits.cmd.inst.funct === PRELOAD_CMD && new_entry.opa.bits.use_iterators)) }) // This can be WAR dst <- op1/op2 val opa_matches_opb = VecInit(entries.map { e => e.valid && e.bits.opb.valid && new_entry.opa.bits.overlaps(e.bits.opb.bits) }) // This can be RAW op2 <- dst - val opb_matches_opa = VecInit(entries.map { e => e.valid && e.bits.opa.valid && new_entry.opb.bits.overlaps(e.bits.opa.bits) }) + val opb_matches_opa = VecInit(entries.map { e => e.valid && e.bits.opa.valid && new_entry.opb.bits.overlaps(e.bits.opa.bits, compare_i_and_k=(funct_is_compute && e.bits.cmd.inst.funct === LOAD_CMD && new_entry.opa.bits.use_iterators)) }) val op1_matches_opa = VecInit((entries zip (opa_matches_opa zip opb_matches_opa)).map { case (e, (a, b)) => e.valid && op1.valid && Mux(dst.valid, b, a) @@ -278,33 +385,62 @@ class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConf val dst_matches_opa = VecInit((entries zip opa_matches_opa).map { case (e, a) => e.valid && dst.valid && a }) + val dst_matches_opa_for_waws = VecInit((entries zip opa_matches_opa_for_waws).map { case (e, a) => + e.valid && dst.valid && a + }) val dst_matches_opb = VecInit((entries zip opa_matches_opb).map { case (e, b) => e.valid && dst.valid && b }) + def compare_q(e: Entry, new_entry: Entry): Bool = { + // This function returns true if these entries are in different queues, or if they're in the + // same q, but "e" has not been issued yet. 
+ e.q =/= new_entry.q || (e.q === new_entry.q && !e.issued) + } + + val new_entry_is_ld_and_other_is_ex = entries.map { e => + e.valid && e.bits.q === exq && new_entry.q === ldq + } + val op1_raws_opa = VecInit((entries zip op1_matches_opa).map { case (e, m) => - m && op1.valid && e.bits.q =/= new_entry.q && e.bits.opa_is_dst + m && op1.valid && compare_q(e.bits, new_entry) && e.bits.opa_is_dst }) val op2_raws_opa = VecInit((entries zip op2_matches_opa).map { case (e, m) => - m && op2.valid && e.bits.q =/= new_entry.q && e.bits.opa_is_dst + m && op2.valid && compare_q(e.bits, new_entry) && e.bits.opa_is_dst }) val raws = VecInit((op1_raws_opa zip op2_raws_opa).map { case (a, b) => a || b }) val dst_wars_opa = VecInit((entries zip dst_matches_opa).map { case (e, m) => - m && dst.valid && e.bits.q =/= new_entry.q && !e.bits.opa_is_dst + m && dst.valid && compare_q(e.bits, new_entry) && !e.bits.opa_is_dst }) val dst_wars_opb = VecInit((entries zip dst_matches_opb).map { case (e, m) => - m && dst.valid && e.bits.q =/= new_entry.q + m && dst.valid && compare_q(e.bits, new_entry) }) - val wars = VecInit((dst_wars_opa zip dst_wars_opb).map { case (a, b) => a || b }) + val wars = VecInit((dst_wars_opa, dst_wars_opb, new_entry_is_ld_and_other_is_ex).zipped.map { case (a, b, c) => (a || b) && !c }) - val dst_waws_opa = VecInit((entries zip dst_matches_opa).map { case (e, m) => - m && dst.valid && (e.bits.q =/= new_entry.q || new_entry.q === ldq) && e.bits.opa_is_dst + val dst_waws_opa = VecInit((entries zip dst_matches_opa_for_waws).map { case (e, m) => + m && dst.valid && (compare_q(e.bits, new_entry) || new_entry.q === ldq) && e.bits.opa_is_dst }) val waws = dst_waws_opa - val older_in_same_q = VecInit(entries.map { e => - e.valid && e.bits.q === new_entry.q && !e.bits.issued + val older_in_same_q = VecInit(entries.zipWithIndex.map { case (e, i) => + + val ooo_q = (ld_ooo.B && new_entry.q === ldq) || (ex_ooo_is_enabled && new_entry.q === exq) || (st_ooo.B && new_entry.q 
=== stq) + + val is_last_preload = last_allocated_preload.valid && i.U === last_allocated_preload.bits + val is_last_garbage_preload = !lean_ooo_rob.B && last_allocated_garbage_preload.valid && i.U === last_allocated_garbage_preload.bits + + val new_entry_is_compute = new_entry.cmd.inst.funct === COMPUTE_AND_STAY_CMD || + new_entry.cmd.inst.funct === COMPUTE_AND_FLIP_CMD + val new_entry_is_preload = new_entry.cmd.inst.funct === PRELOAD_CMD + val preload_addr = new_entry.cmd.rs1(31, 0).asTypeOf(local_addr_t) // TODO magic number + val preload_garbage = !lean_ooo_rob.B && preload_addr.is_garbage() + + e.valid && e.bits.q === new_entry.q && !e.bits.issued && + (!ooo_q || e.bits.is_config || new_entry.is_config || + ((new_entry_is_compute && is_last_preload) || + (new_entry_is_preload && preload_garbage && is_last_preload) || + (new_entry_is_preload && !preload_garbage && is_last_garbage_preload))) }) val is_st_and_must_wait_for_prior_ex_config = VecInit(entries.map { e => @@ -329,6 +465,8 @@ class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConf new_entry.complete_on_issue := new_entry.is_config && new_entry.q =/= exq + new_entry.waiting_for_compute_inst := ex_ooo.B && funct === PRELOAD_CMD + val is_full = PopCount(Seq(dst.valid, op1.valid, op2.valid)) > 1.U val full_alloc_id = MuxCase((rob_full_entries-1).U, full_entries.zipWithIndex.map { case (e, i) => !e.valid -> i.U }) val partial_alloc_id = MuxCase((rob_partial_entries-1).U, partial_entries.zipWithIndex.map { case (e, i) => !e.valid -> i.U }) @@ -354,6 +492,7 @@ class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConf val set_only_strides = new_entry.cmd.rs1(7) // TODO magic numbers when (!set_only_strides) { a_transpose := new_entry.cmd.rs1(8) // TODO magic numbers + if (dataflow == Dataflow.BOTH) current_dataflow := new_entry.cmd.rs1(2) // TODO magic numbers } }.elsewhen(new_entry.is_config && new_entry.q === ldq) { val id = new_entry.cmd.rs1(4,3) // TODO magic 
numbers @@ -362,13 +501,38 @@ class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConf }.elsewhen(new_entry.is_config && new_entry.q === stq) { val pool_stride = new_entry.cmd.rs1(5, 4) // TODO magic numbers pooling_is_enabled := pool_stride =/= 0.U + }.elsewhen(new_entry.cmd.inst.funct === PRELOAD_CMD) { + last_allocated_preload.push(full_alloc_id) + last_allocated_preload_being_updated := true.B + + val preload_addr = new_entry.cmd.rs1(31, 0).asTypeOf(local_addr_t) // TODO magic number + when (preload_addr.is_garbage()) { + last_allocated_garbage_preload.push(full_alloc_id) + last_allocated_garbage_preload_being_updated := true.B + } + }.elsewhen(funct_is_compute) { + when (last_allocated_preload.valid) { + entries.zipWithIndex.foreach { case (e,i) => + when (i.U === last_allocated_preload.bits) { + e.bits.waiting_for_compute_inst := false.B + e.bits.deps := (e.bits.deps.asUInt() | new_entry.deps.asUInt()).asTypeOf(e.bits.deps) + assert(e.valid) + } + } + } } } } // Issue commands which are ready to be issued Seq((ldq, io.issue.ld), (stq, io.issue.st), (exq, io.issue.ex)).foreach { case (q, io) => - val issue_valids = entries.map(e => e.valid && e.bits.ready() && !e.bits.issued && e.bits.q === q) + val must_be_compute = q === exq && last_ex_issued_was_preload && ex_ooo.B + + val issue_valids = entries.map { e => + val is_compute = e.bits.cmd.inst.funct === COMPUTE_AND_FLIP_CMD || e.bits.cmd.inst.funct === COMPUTE_AND_STAY_CMD + e.valid && e.bits.ready() && !e.bits.issued && e.bits.q === q && (!must_be_compute || is_compute) && + (!ex_ooo.B || !e.bits.waiting_for_compute_inst) + } val issue_sel = PriorityEncoderOH(issue_valids) val issue_id = OHToUInt(issue_sel) val issue_entry = Mux1H(issue_sel, entries) @@ -394,9 +558,18 @@ class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConf e.valid := !e.bits.complete_on_issue } } + + val stall_limit = 5000.U // ALON: This magic number defines when a command will be 
considered "stalled" + when (issue_entry.bits.stall_cycles_before_issue > stall_limit && q === exq) { + printf(p"command stalled: (funct: ${issue_entry.bits.cmd.inst.funct}) (k_portion: ${issue_entry.bits.ex_k_portion})\n") + } } } + when (io.issue.ex.fire()) { + last_ex_issued_was_preload := io.issue.ex.cmd.inst.funct === PRELOAD_CMD + } + // Mark entries as completed once they've returned when (io.completed.fire()) { entries.foreach(_.bits.deps(io.completed.bits) := false.B) @@ -405,10 +578,30 @@ class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConf when (i.U === io.completed.bits) { e.valid := false.B assert(e.valid) + + when (last_allocated_preload.bits === i.U && !last_allocated_preload_being_updated) { + last_allocated_preload.pop() + } + + when (last_allocated_garbage_preload.bits === i.U && !last_allocated_garbage_preload_being_updated) { + last_allocated_garbage_preload.pop() + } } } } + // Increment stall counters + entries.foreach { e => + when (e.valid && !e.bits.issued) { + e.bits.stall_cycles_before_issue := e.bits.stall_cycles_before_issue + 1.U + } + } + + // Hardcode deps that point to the entry that owns the deps to 0 + entries.zipWithIndex.foreach { case (e,i) => + e.bits.deps(i) := false.B + } + // val utilization = PopCount(entries.map(e => e.valid)) val utilization_ld_q_unissued = PopCount(entries.map(e => e.valid && !e.bits.issued && e.bits.q === ldq)) val utilization_st_q_unissued = PopCount(entries.map(e => e.valid && !e.bits.issued && e.bits.q === stq)) @@ -421,6 +614,10 @@ class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConf io.st_utilization := utilization_st_q io.ex_utilization := utilization_ex_q + io.ex_k_portion_utilizations.zipWithIndex.foreach { case (io, k) => + io := PopCount(entries.map(e => e.valid && e.bits.q === exq && !e.bits.issued && e.bits.ex_k_portion === k.U)) + } + val valids = VecInit(entries.map(_.valid)) val functs = VecInit(entries.map(_.bits.cmd.inst.funct)) 
val issueds = VecInit(entries.map(_.bits.issued)) @@ -450,6 +647,16 @@ class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConf dontTouch(e.bits.allocated_at) } + val first_preload_allocated = RegInit(false.B) + when (io.alloc.fire() && io.alloc.bits.cmd.inst.funct === PRELOAD_CMD) { + first_preload_allocated := true.B + } + val cycles_that_ex_stalls_due_to_dependencies = RegInit(0.U(32.W)) + when (first_preload_allocated && utilization_ex_q > 0.U && !io.issue.ex.valid) { + cycles_that_ex_stalls_due_to_dependencies := cycles_that_ex_stalls_due_to_dependencies + 1.U + } + dontTouch(cycles_that_ex_stalls_due_to_dependencies) + val cntr = Counter(10000000) when (cntr.inc()) { printf(p"Utilization: $utilization\n") @@ -462,7 +669,13 @@ class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConf printf(p"Packed deps: $packed_deps\n") } + if (lean_ooo_rob) { + last_allocated_garbage_preload.pop() + } + when (reset.asBool()) { - entries.foreach(_.valid := false.B) + entries.foreach(_.pop()) + last_allocated_preload.pop() + last_allocated_garbage_preload.pop() } } diff --git a/src/main/scala/gemmini/StoreController.scala b/src/main/scala/gemmini/StoreController.scala index 98584bca..bbf2b898 100644 --- a/src/main/scala/gemmini/StoreController.scala +++ b/src/main/scala/gemmini/StoreController.scala @@ -118,7 +118,9 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm ((config.sp_banks * config.sp_bank_entries) max (config.acc_banks * config.acc_bank_entries)) - val cmd_tracker = Module(new DMACommandTracker(nCmds, cmd_tracker_max_rows, deps_t)) + val cmd_tracker = Module(new DMACommandTracker(nCmds, cmd_tracker_max_rows, deps_t, prng_seed = prng_seed, + proportion_of_slow_accesses_out_of_128 = if (delay_sts) proportion_of_slow_accesses_out_of_128 else 0, + stall_delay = stall_delay)) // DMA IO wiring io.dma.req.valid := (control_state === waiting_for_command && cmd.valid && DoStore && 
cmd_tracker.io.alloc.ready) || diff --git a/src/main/scala/gemmini/Util.scala b/src/main/scala/gemmini/Util.scala index 511cfee2..5bc152f6 100644 --- a/src/main/scala/gemmini/Util.scala +++ b/src/main/scala/gemmini/Util.scala @@ -35,12 +35,12 @@ object Util { Mux(u +& v > max, max, u + v) } - def floorAdd(u: UInt, n: UInt, max_plus_one: UInt, en: Bool = true.B): UInt = { + def floorAdd(u: UInt, n: UInt, max_plus_one: UInt, en: Bool = true.B, min: UInt = 0.U): UInt = { val max = max_plus_one - 1.U MuxCase(u + n, Seq( (!en) -> u, - ((u +& n) > max) -> 0.U + ((u +& n) > max) -> min )) } diff --git a/src/main/scala/gemmini/WeightedArbiter.scala b/src/main/scala/gemmini/WeightedArbiter.scala index 2264aeea..80cabe4d 100644 --- a/src/main/scala/gemmini/WeightedArbiter.scala +++ b/src/main/scala/gemmini/WeightedArbiter.scala @@ -4,7 +4,7 @@ import chisel3._ import chisel3.util._ import Util._ -class WeightedArbiter[T <: Data](t: T, maxWeightA: Int) extends Module { +class WeightedArbiter[T <: Data](t: T, maxWeightA: Int, staticWeightAEnabled: Boolean, onlyStaticWeightA: Boolean) extends Module { val io = IO(new Bundle { val inA = Flipped(Decoupled(t)) val inB = Flipped(Decoupled(t)) @@ -12,13 +12,22 @@ class WeightedArbiter[T <: Data](t: T, maxWeightA: Int) extends Module { val forceA = Input(Bool()) val forceB = Input(Bool()) val out = Decoupled(t) + + val inA_idle = Input(Bool()) + val inB_idle = Input(Bool()) + val inA_k = Input(UInt(16.W)) // TODO magic number + val inB_k = Input(UInt(16.W)) // TODO magic number + val inA_i = Input(UInt(16.W)) // TODO magic number + val inB_j = Input(UInt(16.W)) // TODO magic number }) val count = Reg(UInt(log2Up(maxWeightA+1).W)) val A_chosen = WireInit(false.B) val B_chosen = WireInit(false.B) - val weightA = io.weightA + val weightA = if (onlyStaticWeightA) { 0.U } else { io.weightA } + + val staticWeightA = weightA === 0.U && staticWeightAEnabled.B io.inA.ready := false.B io.inB.ready := false.B @@ -27,18 +36,30 @@ class 
WeightedArbiter[T <: Data](t: T, maxWeightA: Int) extends Module { io.out <> io.inA }.elsewhen(io.forceB) { io.out <> io.inB - }.elsewhen(io.inA.valid && io.inB.valid) { - when (count < weightA) { + }.elsewhen(!staticWeightA) { + when(io.inA.valid && io.inB.valid) { + when(count < weightA) { + io.out <> io.inA + A_chosen := true.B + }.otherwise { + io.out <> io.inB + B_chosen := true.B + } + }.elsewhen(io.inA.valid) { io.out <> io.inA - A_chosen := true.B }.otherwise { io.out <> io.inB - B_chosen := true.B } - }.elsewhen(io.inA.valid) { - io.out <> io.inA }.otherwise { - io.out <> io.inB + when (io.inA_idle) { + io.out <> io.inB + }.elsewhen(io.inB_idle) { + io.out <> io.inA + }.elsewhen(io.inA_k > io.inB_k || (io.inB_k === 0.U && io.inB_j === 0.U)) { + io.out <> io.inB + }.otherwise { + io.out <> io.inA + } } when (io.out.fire()) { @@ -51,5 +72,5 @@ class WeightedArbiter[T <: Data](t: T, maxWeightA: Int) extends Module { assert(!(io.forceA && io.forceB)) assert(!(A_chosen && B_chosen)) - assert((!io.inA.valid && !io.inB.valid) || weightA > 0.U) + assert((!io.inA.valid && !io.inB.valid) || (weightA > 0.U || staticWeightAEnabled.B)) }