diff --git a/.circleci/build-toolchains.sh b/.circleci/build-toolchains.sh index 39caa772..1a4e4b13 100755 --- a/.circleci/build-toolchains.sh +++ b/.circleci/build-toolchains.sh @@ -28,5 +28,5 @@ if [ ! -d "$HOME/$1-install" ]; then cd $HOME # init all submodules including the tools (doesn't use CI_MAKE_PROC due to mem. constraints) - CHIPYARD_DIR="$LOCAL_CHIPYARD_DIR" NPROC=$CI_MAKE_PROC $LOCAL_CHIPYARD_DIR/scripts/build-toolchains.sh esp-tools + CHIPYARD_DIR="$LOCAL_CHIPYARD_DIR" NPROC=$CI_MAKE_NPROC $LOCAL_CHIPYARD_DIR/scripts/build-toolchains.sh esp-tools fi diff --git a/.circleci/defaults.sh b/.circleci/defaults.sh index 6100774a..2d200104 100755 --- a/.circleci/defaults.sh +++ b/.circleci/defaults.sh @@ -14,7 +14,7 @@ ############# # make parallelism -CI_MAKE_NPROC=8 +CI_MAKE_NPROC=4 LOCAL_MAKE_NPROC=$CI_MAKE_NPROC # verilator version diff --git a/CHIPYARD.hash b/CHIPYARD.hash index 7c244581..70a1842b 100644 --- a/CHIPYARD.hash +++ b/CHIPYARD.hash @@ -1 +1 @@ -6b0d57d60690cc223013ea228b687b519b716c50 +1e2f778a6705033d67ccbcc932e66083e4646f15 diff --git a/SPIKE.hash b/SPIKE.hash index 1d05e3ee..7a511bbc 100644 --- a/SPIKE.hash +++ b/SPIKE.hash @@ -1 +1 @@ -8626fb144e019895767830d850deca7711773e5c +dbd3b0874dde4eead6b8d0c4195ee8b41dd113fc diff --git a/software/gemmini-rocc-tests b/software/gemmini-rocc-tests index 2802ca40..e32f5be3 160000 --- a/software/gemmini-rocc-tests +++ b/software/gemmini-rocc-tests @@ -1 +1 @@ -Subproject commit 2802ca406323ef511e1c2939387144a47c250638 +Subproject commit e32f5be388e851b7cee073b39f848e8173872700 diff --git a/src/main/scala/gemmini/AccumulatorMem.scala b/src/main/scala/gemmini/AccumulatorMem.scala index e218bd51..0fafb952 100644 --- a/src/main/scala/gemmini/AccumulatorMem.scala +++ b/src/main/scala/gemmini/AccumulatorMem.scala @@ -17,19 +17,21 @@ class AccumulatorReadReq[T <: Data](n: Int, shift_width: Int, scale_t: T) extend override def cloneType: this.type = new AccumulatorReadReq(n, shift_width, scale_t.cloneType).asInstanceOf[this.type] } -class AccumulatorReadResp[T <: Data: Arithmetic](rdataType: Vec[Vec[T]], fullDataType: Vec[Vec[T]]) extends Bundle { - val data = rdataType.cloneType - val full_data = fullDataType.cloneType +class AccumulatorReadResp[T <: Data: Arithmetic, U <: Data](fullDataType: Vec[Vec[T]], scale_t: U, shift_width: Int) extends Bundle { + val data = fullDataType.cloneType val fromDMA = Bool() - - override def cloneType: this.type = new AccumulatorReadResp(rdataType.cloneType, fullDataType.cloneType).asInstanceOf[this.type] + val scale = scale_t.cloneType + val relu6_shift = UInt(shift_width.W) + val act = UInt(2.W) + val acc_bank_id = UInt(2.W) // TODO don't hardcode + override def cloneType: this.type = new AccumulatorReadResp(fullDataType.cloneType, scale_t, shift_width).asInstanceOf[this.type] } -class AccumulatorReadIO[T <: Data: Arithmetic, U <: Data](n: Int, shift_width: Int, rdataType: Vec[Vec[T]], fullDataType: Vec[Vec[T]], scale_t: U) extends Bundle { - val req = Decoupled(new AccumulatorReadReq(n, shift_width, scale_t)) - val resp = Flipped(Decoupled(new AccumulatorReadResp(rdataType.cloneType, fullDataType.cloneType))) +class AccumulatorReadIO[T <: Data: Arithmetic, U <: Data](n: Int, shift_width: Int, fullDataType: Vec[Vec[T]], scale_t: U) extends Bundle { + val req = Decoupled(new AccumulatorReadReq[U](n, shift_width, scale_t)) + val resp = Flipped(Decoupled(new AccumulatorReadResp[T, U](fullDataType, scale_t, shift_width))) - override def cloneType: this.type = new AccumulatorReadIO(n, shift_width, rdataType.cloneType, fullDataType.cloneType, scale_t.cloneType).asInstanceOf[this.type] + override def cloneType: this.type = new AccumulatorReadIO(n, shift_width, fullDataType.cloneType, scale_t.cloneType).asInstanceOf[this.type] } class AccumulatorWriteReq[T <: Data: Arithmetic](n: Int, t: Vec[Vec[T]]) extends Bundle { @@ -42,16 +44,19 @@ class AccumulatorWriteReq[T <: Data: Arithmetic](n: Int, t: Vec[Vec[T]]) extends override def cloneType: this.type = new AccumulatorWriteReq(n, t).asInstanceOf[this.type] } -class AccumulatorMemIO [T <: Data: Arithmetic, U <: Data](n: Int, t: Vec[Vec[T]], rdata: Vec[Vec[T]], scale_t: U) extends Bundle { - val read = Flipped(new AccumulatorReadIO(n, log2Ceil(t.head.head.getWidth), rdata, t, scale_t)) +class AccumulatorMemIO [T <: Data: Arithmetic, U <: Data](n: Int, t: Vec[Vec[T]], scale_t: U) extends Bundle { + val read = Flipped(new AccumulatorReadIO(n, log2Ceil(t.head.head.getWidth), t, scale_t)) // val write = Flipped(new AccumulatorWriteIO(n, t)) val write = Flipped(Decoupled(new AccumulatorWriteReq(n, t))) - override def cloneType: this.type = new AccumulatorMemIO(n, t, rdata, scale_t).asInstanceOf[this.type] + override def cloneType: this.type = new AccumulatorMemIO(n, t, scale_t).asInstanceOf[this.type] } -class AccumulatorMem[T <: Data, U <: Data](n: Int, t: Vec[Vec[T]], rdataType: Vec[Vec[T]], mem_pipeline: Int, scale_args: ScaleArguments[T, U], read_small_data: Boolean, read_full_data: Boolean) - (implicit ev: Arithmetic[T]) extends Module { +class AccumulatorMem[T <: Data, U <: Data]( + n: Int, t: Vec[Vec[T]], scale_args: ScaleArguments[T, U], + acc_singleported: Boolean, num_acc_sub_banks: Int +) + (implicit ev: Arithmetic[T]) extends Module { // TODO Do writes in this module work with matrices of size 2? If we try to read from an address right after writing // to it, then we might not get the written data. We might need some kind of cooldown counter after addresses in the // accumulator have been written to for configurations with such small matrices @@ -64,9 +69,8 @@ class AccumulatorMem[T <: Data, U <: Data](n: Int, t: Vec[Vec[T]], rdataType: Ve import ev._ // TODO unify this with TwoPortSyncMemIO - val io = IO(new AccumulatorMemIO(n, t, rdataType, scale_args.multiplicand_t)) + val io = IO(new AccumulatorMemIO(n, t, scale_args.multiplicand_t)) - val mem = TwoPortSyncMem(n, t, t.getWidth / 8) // TODO We assume byte-alignment here. Use aligned_to instead // For any write operation, we spend 2 cycles reading the existing address out, buffering it in a register, and then // accumulating on top of it (if necessary) @@ -75,55 +79,144 @@ class AccumulatorMem[T <: Data, U <: Data](n: Int, t: Vec[Vec[T]], rdataType: Ve val acc_buf = ShiftRegister(io.write.bits.acc, 2) val mask_buf = ShiftRegister(io.write.bits.mask, 2) val w_buf_valid = ShiftRegister(io.write.fire(), 2) - - val w_sum = VecInit((RegNext(mem.io.rdata) zip wdata_buf).map { case (rv, wv) => + val acc_rdata = Wire(t) + acc_rdata := DontCare + val read_rdata = Wire(t) + read_rdata := DontCare + val block_read_req = WireInit(false.B) + val w_sum = VecInit((RegNext(acc_rdata) zip wdata_buf).map { case (rv, wv) => VecInit((rv zip wv).map(t => t._1 + t._2)) }) - mem.io.waddr := waddr_buf - mem.io.wen := w_buf_valid - mem.io.wdata := Mux(acc_buf, w_sum, wdata_buf) - mem.io.mask := mask_buf - - mem.io.raddr := Mux(io.write.fire() && io.write.bits.acc, io.write.bits.addr, io.read.req.bits.addr) - mem.io.ren := io.read.req.fire() || (io.write.fire() && io.write.bits.acc) - - class PipelinedRdataAndActT extends Bundle { - val data = mem.io.rdata.cloneType - val full_data = mem.io.rdata.cloneType - val scale = io.read.req.bits.scale.cloneType - val relu6_shift = io.read.req.bits.relu6_shift.cloneType - val act = io.read.req.bits.act.cloneType - val fromDMA = io.read.req.bits.fromDMA.cloneType + if (!acc_singleported) { + val mem = TwoPortSyncMem(n, t, t.getWidth / 8) // TODO We assume byte-alignment here. Use aligned_to instead + mem.io.waddr := waddr_buf + mem.io.wen := w_buf_valid + mem.io.wdata := Mux(acc_buf, w_sum, wdata_buf) + mem.io.mask := mask_buf + acc_rdata := mem.io.rdata + read_rdata := mem.io.rdata + mem.io.raddr := Mux(io.write.fire() && io.write.bits.acc, io.write.bits.addr, io.read.req.bits.addr) + mem.io.ren := io.read.req.fire() || (io.write.fire() && io.write.bits.acc) + } else { + val mask_len = t.getWidth / 8 + val mask_elem = UInt((t.getWidth / mask_len).W) + val reads = Wire(Vec(2, Decoupled(UInt()))) + reads(0).valid := io.write.valid && io.write.bits.acc + reads(0).bits := io.write.bits.addr + reads(0).ready := true.B + reads(1).valid := io.read.req.valid + reads(1).bits := io.read.req.bits.addr + reads(1).ready := true.B + block_read_req := !reads(1).ready + for (i <- 0 until num_acc_sub_banks) { + def isThisBank(addr: UInt) = addr(log2Ceil(num_acc_sub_banks)-1,0) === i.U + def getBankIdx(addr: UInt) = addr >> log2Ceil(num_acc_sub_banks) + val mem = SyncReadMem(n / num_acc_sub_banks, Vec(mask_len, mask_elem)) + + val ren = WireInit(false.B) + val raddr = WireInit(getBankIdx(reads(0).bits)) + val nEntries = 3 + // Writes coming 2 cycles after read leads to bad bank behavior + // Add another buffer here + class W_Q_Entry[T <: Data](mask_len: Int, mask_elem: T) extends Bundle { + val valid = Bool() + val data = Vec(mask_len, mask_elem) + val mask = Vec(mask_len, Bool()) + val addr = UInt(log2Ceil(n/num_acc_sub_banks).W) + override def cloneType: this.type = new W_Q_Entry(mask_len, mask_elem).asInstanceOf[this.type] + } + val w_q = Reg(Vec(nEntries, new W_Q_Entry(mask_len, mask_elem))) + for (e <- w_q) { + when (e.valid) { + assert(!( + io.write.valid && io.write.bits.acc && + isThisBank(io.write.bits.addr) && getBankIdx(io.write.bits.addr) === e.addr && + ((io.write.bits.mask.asUInt & e.mask.asUInt) =/= 0.U) + )) + when (io.read.req.valid && isThisBank(io.read.req.bits.addr) && getBankIdx(io.read.req.bits.addr) === e.addr) { + reads(1).ready := false.B + } + } + } + val w_q_head = RegInit(1.U(nEntries.W)) + val w_q_tail = RegInit(1.U(nEntries.W)) + when (reset.asBool) { + w_q.foreach(_.valid := false.B) + } + val wen = WireInit(false.B) + val wdata = Mux1H(w_q_head.asBools, w_q.map(_.data)) + val wmask = Mux1H(w_q_head.asBools, w_q.map(_.mask)) + val waddr = Mux1H(w_q_head.asBools, w_q.map(_.addr)) + when (wen) { + w_q_head := w_q_head << 1 | w_q_head(nEntries-1) + for (i <- 0 until nEntries) { + when (w_q_head(i)) { + w_q(i).valid := false.B + } + } + } + + when (w_buf_valid && isThisBank(waddr_buf)) { + assert(!((w_q_tail.asBools zip w_q.map(_.valid)).map({ case (h,v) => h && v }).reduce(_||_))) + w_q_tail := w_q_tail << 1 | w_q_tail(nEntries-1) + for (i <- 0 until nEntries) { + when (w_q_tail(i)) { + w_q(i).valid := true.B + w_q(i).data := Mux(acc_buf, w_sum, wdata_buf).asTypeOf(Vec(mask_len, mask_elem)) + w_q(i).mask := mask_buf + w_q(i).addr := getBankIdx(waddr_buf) + } + } + + } + val bank_rdata = mem.read(raddr, ren && !wen).asTypeOf(t) + when (RegNext(ren && reads(0).valid && isThisBank(reads(0).bits))) { + acc_rdata := bank_rdata + } .elsewhen (RegNext(ren)) { + read_rdata := bank_rdata + } + when (wen) { + mem.write(waddr, wdata, wmask) + } + // Three requestors, 1 slot + // Priority is incoming reads for RMW > writes from RMW > incoming reads + when (reads(0).valid && isThisBank(reads(0).bits)) { + ren := true.B + when (isThisBank(reads(1).bits)) { + reads(1).ready := false.B + } + } .elsewhen ((w_q_head.asBools zip w_q.map(_.valid)).map({ case (h,v) => h && v }).reduce(_||_)) { + wen := true.B + when (isThisBank(reads(1).bits)) { + reads(1).ready := false.B + } + } .otherwise { + ren := isThisBank(reads(1).bits) + raddr := getBankIdx(reads(1).bits) + } + } } - val q = Module(new Queue(new PipelinedRdataAndActT, 1, true, true)) - q.io.enq.bits.data := mem.io.rdata - q.io.enq.bits.full_data := mem.io.rdata + val q = Module(new Queue(new AccumulatorReadResp(t, scale_args.multiplicand_t, log2Ceil(t.head.head.getWidth)), 1, true, true)) + q.io.enq.bits.data := read_rdata q.io.enq.bits.scale := RegNext(io.read.req.bits.scale) q.io.enq.bits.relu6_shift := RegNext(io.read.req.bits.relu6_shift) q.io.enq.bits.act := RegNext(io.read.req.bits.act) q.io.enq.bits.fromDMA := RegNext(io.read.req.bits.fromDMA) + q.io.enq.bits.acc_bank_id := DontCare q.io.enq.valid := RegNext(io.read.req.fire()) - val p = Pipeline(q.io.deq, mem_pipeline, Seq.fill(mem_pipeline)((x: PipelinedRdataAndActT) => x) :+ { - x: PipelinedRdataAndActT => - val activated_rdata = VecInit(x.data.map(v => VecInit(v.map { e => - // val e_scaled = e >> x.shift - val e_scaled = scale_args.scale_func(e, x.scale) - val e_clipped = e_scaled.clippedToWidthOf(rdataType.head.head) - val e_act = MuxCase(e_clipped, Seq( - (x.act === Activation.RELU) -> e_clipped.relu, - (x.act === Activation.RELU6) -> e_clipped.relu6(x.relu6_shift))) - - e_act - }))) + val p = q.io.deq - val result = WireInit(x) - result.data := activated_rdata - - result - }) + io.read.resp.bits.data := p.bits.data + io.read.resp.bits.fromDMA := p.bits.fromDMA + io.read.resp.bits.relu6_shift := p.bits.relu6_shift + io.read.resp.bits.act := p.bits.act + io.read.resp.bits.scale := p.bits.scale + io.read.resp.bits.acc_bank_id := DontCare // This is set in Scratchpad + io.read.resp.valid := p.valid + p.ready := io.read.resp.ready val q_will_be_empty = (q.io.count +& q.io.enq.fire()) - q.io.deq.fire() === 0.U io.read.req.ready := q_will_be_empty && ( @@ -131,27 +224,13 @@ class AccumulatorMem[T <: Data, U <: Data](n: Int, t: Vec[Vec[T]], rdataType: Ve !(io.write.fire() && io.write.bits.acc) && // Make sure we aren't reading something that is still being written !(RegNext(io.write.fire()) && RegNext(io.write.bits.addr) === io.read.req.bits.addr) && - !(w_buf_valid && waddr_buf === io.read.req.bits.addr) - ) - io.read.resp.bits.data := p.bits.data - io.read.resp.bits.full_data := p.bits.full_data - io.read.resp.bits.fromDMA := p.bits.fromDMA - io.read.resp.valid := p.valid - p.ready := io.read.resp.ready - - if (read_small_data) - io.read.resp.bits.data := p.bits.data - else - io.read.resp.bits.data := 0.U.asTypeOf(p.bits.data) // TODO make this DontCare instead - - if (read_full_data) - io.read.resp.bits.full_data := p.bits.full_data - else - io.read.resp.bits.full_data := 0.U.asTypeOf(q.io.enq.bits.full_data) // TODO make this DontCare instead + !(w_buf_valid && waddr_buf === io.read.req.bits.addr) && + !block_read_req + ) // io.write.current_waddr.valid := mem.io.wen // io.write.current_waddr.bits := mem.io.waddr - io.write.ready := !io.write.bits.acc || (!(io.write.bits.addr === mem.io.waddr && mem.io.wen) && + io.write.ready := !io.write.bits.acc || (!(io.write.bits.addr === waddr_buf && w_buf_valid) && !(io.write.bits.addr === RegNext(io.write.bits.addr) && RegNext(io.write.fire()))) // assert(!(io.read.req.valid && io.write.en && io.write.acc), "reading and accumulating simultaneously is not supported") diff --git a/src/main/scala/gemmini/AccumulatorScale.scala b/src/main/scala/gemmini/AccumulatorScale.scala new file mode 100644 index 00000000..304126fe --- /dev/null +++ b/src/main/scala/gemmini/AccumulatorScale.scala @@ -0,0 +1,216 @@ +package gemmini + +import chisel3._ +import chisel3.util._ + +import Util._ + +class AccumulatorReadRespWithFullData[T <: Data: Arithmetic, U <: Data](fullDataType: Vec[Vec[T]], scale_t: U, shift_width: Int) extends Bundle { + val resp = new AccumulatorReadResp(fullDataType, scale_t, shift_width) + val full_data = fullDataType.cloneType + override def cloneType: this.type = new AccumulatorReadRespWithFullData(fullDataType.cloneType, scale_t, shift_width).asInstanceOf[this.type] +} + + +class AccumulatorScaleResp[T <: Data: Arithmetic](fullDataType: Vec[Vec[T]], rDataType: Vec[Vec[T]]) extends Bundle { + val full_data = fullDataType.cloneType + val data = rDataType.cloneType + val acc_bank_id = UInt(2.W) + val fromDMA = Bool() + override def cloneType: this.type = new AccumulatorScaleResp(fullDataType, rDataType).asInstanceOf[this.type] +} + +class AccumulatorScaleIO[T <: Data: Arithmetic, U <: Data]( + fullDataType: Vec[Vec[T]], scale_t: U, shift_width: Int, + rDataType: Vec[Vec[T]] +) extends Bundle { + val in = Flipped(Decoupled(new AccumulatorReadResp[T,U](fullDataType, scale_t, shift_width))) + val out = Decoupled(new AccumulatorScaleResp[T](fullDataType, rDataType)) + override def cloneType: this.type = new AccumulatorScaleIO(fullDataType, scale_t, + shift_width, rDataType).asInstanceOf[this.type] +} + +class AccScaleDataWithIndex[T <: Data: Arithmetic, U <: Data](t: T, u: U, scale_args: ScaleArguments[T, U]) extends Bundle { + val shift_width = log2Ceil(t.getWidth) + + val scale = u.cloneType + val act = UInt(2.W) + val relu6_shift = UInt(shift_width.W) + val data = t.cloneType + val full_data = t.cloneType + val id = UInt(2.W) // TODO hardcoded + val index = UInt() + override def cloneType: this.type = new AccScaleDataWithIndex(t, u, scale_args: ScaleArguments[T, U]).asInstanceOf[this.type] +} + +class AccScalePipe[T <: Data : Arithmetic, U <: Data](t: T, rDataType: Vec[Vec[T]], scale_args: ScaleArguments[T, U])(implicit ev: Arithmetic[T]) extends Module { + val u = scale_args.multiplicand_t + val io = IO(new Bundle { + val in = Input(Valid(new AccScaleDataWithIndex(t, u, scale_args)(ev))) + val out = Output(Valid(new AccScaleDataWithIndex(t, u, scale_args)(ev))) + }) + import ev._ + val latency = scale_args.latency + val out = WireInit(io.in) + + val e_scaled = scale_args.scale_func(io.in.bits.data, io.in.bits.scale) + val e_clipped = e_scaled.clippedToWidthOf(rDataType.head.head) + val e_act = MuxCase(e_clipped, Seq( + (io.in.bits.act === Activation.RELU) -> e_clipped.relu, + (io.in.bits.act === Activation.RELU6) -> e_clipped.relu6(io.in.bits.relu6_shift))) + + out.bits.data := e_act + io.out := Pipe(out, latency) +} + + +class AccumulatorScale[T <: Data: Arithmetic, U <: Data]( + fullDataType: Vec[Vec[T]], rDataType: Vec[Vec[T]], + scale_t: U, shift_width: Int, + read_small_data: Boolean, read_full_data: Boolean, + scale_args: ScaleArguments[T, U])(implicit ev: Arithmetic[T]) extends Module { + + import ev._ + val io = IO(new AccumulatorScaleIO[T,U]( + fullDataType, scale_t, shift_width, rDataType + )(ev)) + val t = io.in.bits.data(0)(0).cloneType + val out = Wire(Decoupled(new AccumulatorScaleResp[T]( + fullDataType, rDataType)(ev))) + + val num_scale_units = scale_args.num_scale_units + val acc_scale_latency = scale_args.latency + + if (num_scale_units == -1) { + val in = Wire(Decoupled(new AccumulatorReadRespWithFullData(fullDataType, scale_t, shift_width)(ev))) + in.valid := io.in.valid + io.in.ready := in.ready + in.bits.resp := io.in.bits + in.bits.full_data := io.in.bits.data + + val pipe_out = Pipeline(in, acc_scale_latency, Seq.fill(acc_scale_latency)((x: AccumulatorReadRespWithFullData[T,U]) => x) :+ { + x: AccumulatorReadRespWithFullData[T,U] => + val activated_rdata = VecInit(x.resp.data.map(v => VecInit(v.map { e => + // val e_scaled = e >> x.shiftls + val e_scaled = scale_args.scale_func(e, x.resp.scale) + val e_clipped = e_scaled.clippedToWidthOf(rDataType.head.head) + val e_act = MuxCase(e_clipped, Seq( + (x.resp.act === Activation.RELU) -> e_clipped.relu, + (x.resp.act === Activation.RELU6) -> e_clipped.relu6(x.resp.relu6_shift))) + + e_act + }))) + val result = WireInit(x) + result.resp.data := activated_rdata + result + }) + out.valid := pipe_out.valid + pipe_out.ready := out.ready + out.bits.full_data := pipe_out.bits.full_data + out.bits.data := pipe_out.bits.resp.data + out.bits.fromDMA := pipe_out.bits.resp.fromDMA + out.bits.acc_bank_id := pipe_out.bits.resp.acc_bank_id + } else { + val width = io.in.bits.data.size * io.in.bits.data(0).size + val nEntries = 3 + val regs = Reg(Vec(nEntries, Valid(new AccumulatorReadResp[T,U]( + fullDataType, scale_t, shift_width)(ev)))) + val out_regs = Reg(Vec(nEntries, new AccumulatorScaleResp[T]( + fullDataType, rDataType)(ev))) + + val fired_masks = Reg(Vec(nEntries, Vec(width, Bool()))) + val completed_masks = Reg(Vec(nEntries, Vec(width, Bool()))) + val head_oh = RegInit(1.U(nEntries.W)) + val tail_oh = RegInit(1.U(nEntries.W)) + out.valid := Mux1H(head_oh.asBools, (regs zip completed_masks).map({case (r, c) => r.valid && c.reduce(_&&_)})) + out.bits := Mux1H(head_oh.asBools, out_regs) + when (out.fire()) { + for (i <- 0 until nEntries) { + when (head_oh(i)) { + regs(i).valid := false.B + } + } + head_oh := (head_oh << 1) | head_oh(nEntries-1) + } + + io.in.ready := !Mux1H(tail_oh.asBools, regs.map(_.valid)) || (tail_oh === head_oh && out.fire()) + when (io.in.fire()) { + for (i <- 0 until nEntries) { + when (tail_oh(i)) { + regs(i).valid := true.B + regs(i).bits := io.in.bits + out_regs(i).fromDMA := io.in.bits.fromDMA + out_regs(i).acc_bank_id := io.in.bits.acc_bank_id + fired_masks(i).foreach(_ := false.B) + completed_masks(i).foreach(_ := false.B) + } + } + tail_oh := (tail_oh << 1) | tail_oh(nEntries-1) + } + + val inputs = Seq.fill(width*nEntries) { Wire(Decoupled(new AccScaleDataWithIndex(t, scale_t, scale_args)(ev))) } + + for (i <- 0 until nEntries) { + for (w <- 0 until width) { + val input = inputs(i*width+w) + input.valid := regs(i).valid && !fired_masks(i)(w) + input.bits.data := regs(i).bits.data(w / io.in.bits.data(0).size)(w % io.in.bits.data(0).size) + input.bits.full_data := regs(i).bits.data(w / io.in.bits.data(0).size)(w % io.in.bits.data(0).size) + input.bits.scale := regs(i).bits.scale + input.bits.act := regs(i).bits.act + input.bits.relu6_shift := regs(i).bits.relu6_shift + input.bits.id := i.U + input.bits.index := w.U + when (input.fire()) { + fired_masks(i)(w) := true.B + } + } + } + for (i <- 0 until num_scale_units) { + val arbIn = inputs.zipWithIndex.filter({ case (_, w) => w % num_scale_units == i }).map(_._1) + val arb = Module(new RRArbiter(new AccScaleDataWithIndex(t, scale_t, scale_args)(ev), arbIn.length)) + arb.io.in <> arbIn + arb.io.out.ready := true.B + val arbOut = Reg(Valid(new AccScaleDataWithIndex(t, scale_t, scale_args)(ev))) + arbOut.valid := arb.io.out.valid + arbOut.bits := arb.io.out.bits + when (reset.asBool) { + arbOut.valid := false.B + } + val pipe = Module(new AccScalePipe(t, rDataType, scale_args)(ev, ev)) + pipe.io.in := arbOut + val pipe_out = pipe.io.out + + for (j <- 0 until nEntries) { + for (w <- 0 until width) { + if ((j*width+w) % num_scale_units == i) { + val id0 = w % io.in.bits.data(0).size + val id1 = w / io.in.bits.data(0).size + when (pipe_out.fire() && pipe_out.bits.id === j.U && pipe_out.bits.index === w.U) { + out_regs(j).data (id1)(id0) := pipe_out.bits.data + out_regs(j).full_data(id1)(id0) := pipe_out.bits.full_data + completed_masks(j)(w) := true.B + } + } + } + } + } + when (reset.asBool) { + regs.foreach(_.valid := false.B) + } + } + + io.out <> out + + if (read_small_data) + io.out.bits.data := out.bits.data + else + io.out.bits.data := DontCare + + if (read_full_data) + io.out.bits.full_data := out.bits.full_data + else + io.out.bits.full_data := DontCare + +} + diff --git a/src/main/scala/gemmini/Arithmetic.scala b/src/main/scala/gemmini/Arithmetic.scala index 0fcac90e..9170b834 100644 --- a/src/main/scala/gemmini/Arithmetic.scala +++ b/src/main/scala/gemmini/Arithmetic.scala @@ -17,6 +17,7 @@ abstract class ArithmeticOps[T <: Data](self: T) { def +(t: T): T def >>(u: UInt): T // This is a rounding shift! Rounds away from 0 def >(t: T): Bool + def identity: T def withWidthOf(t: T): T def clippedToWidthOf(t: T): T // Like "withWidthOf", except that it saturates def relu: T @@ -62,6 +63,7 @@ object Arithmetic { } override def zero: UInt = 0.U + override def identity: UInt = 1.U } } @@ -111,6 +113,7 @@ object Arithmetic { } override def zero: SInt = 0.S + override def identity: SInt = 1.S } } @@ -271,7 +274,25 @@ object Arithmetic { */ } - override def >(t: Float): Bool = true.B // TODO + override def >(t: Float): Bool = { + // Recode all operands + val t_rec = recFNFromFN(t.expWidth, t.sigWidth, t.bits) + val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits) + + // Resize t to self's width + val t_resizer = Module(new RecFNToRecFN(t.expWidth, t.sigWidth, self.expWidth, self.sigWidth)) + t_resizer.io.in := t_rec + t_resizer.io.roundingMode := consts.round_near_even + t_resizer.io.detectTininess := consts.tininess_afterRounding + val t_rec_resized = t_resizer.io.out + + val comparator = Module(new CompareRecFN(self.expWidth, self.sigWidth)) + comparator.io.a := self_rec + comparator.io.b := t_rec_resized + comparator.io.signaling := false.B + + comparator.io.gt + } override def withWidthOf(t: Float): Float = { val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits) @@ -375,6 +396,7 @@ object Arithmetic { } override def zero: Float = 0.U.asTypeOf(self) + override def identity: Float = Cat(0.U(2.W), ~(0.U((self.expWidth-1).W)), 0.U((self.sigWidth-1).W)).asTypeOf(self) } } } diff --git a/src/main/scala/gemmini/BeatMerger.scala b/src/main/scala/gemmini/BeatMerger.scala index cac08aac..c3922f15 100644 --- a/src/main/scala/gemmini/BeatMerger.scala +++ b/src/main/scala/gemmini/BeatMerger.scala @@ -31,7 +31,6 @@ class BeatMerger[U <: Data](beatBits: Int, maxShift: Int, spadWidth: Int, accWid val io = IO(new Bundle { val req = Flipped(Decoupled(new XactTrackerEntry(maxShift, spadWidth, accWidth, spadRows, accRows, maxReqBytes, mvin_scale_t_bits, nCmds))) val in = Flipped(Decoupled(UInt(beatBits.W))) - // val in = Flipped(Decoupled(new BeatPackerIn(beatBits))) val out = Decoupled(new BeatMergerOut(spadWidth, accWidth, spadRows, accRows, alignedTo)) }) @@ -72,7 +71,7 @@ class BeatMerger[U <: Data](beatBits: Int, maxShift: Int, spadWidth: Int, accWid i.U >= spad_row_offset && i.U < spad_row_offset +& (req.bits.bytes_to_read - bytesSent) }) - io.out.bits.addr := req.bits.addr + meshRows.U * { + io.out.bits.addr := req.bits.addr + req.bits.block_stride * { val total_bytes_sent = req.bits.spad_row_offset + bytesSent Mux(req.bits.has_acc_bitwidth, // We only add "if" statements here to satisfy the Verilator linter. The code would be cleaner without the diff --git a/src/main/scala/gemmini/Configs.scala b/src/main/scala/gemmini/Configs.scala index 5e6b0336..3c9ac682 100644 --- a/src/main/scala/gemmini/Configs.scala +++ b/src/main/scala/gemmini/Configs.scala @@ -35,58 +35,44 @@ class WithMultiRoCC extends Config((site, here, up) => { // ----------------------- object GemminiConfigs { - // import Arithmetic.FloatArithmetic._ - val defaultConfig = GemminiArrayConfig[SInt, Float, Float]( - // val defaultConfig = GemminiArrayConfig[Float, Float]( + opcodes = OpcodeSet.custom3, + tileRows = 1, tileColumns = 1, - // meshRows = 4, - // meshColumns = 4, meshRows = 16, meshColumns = 16, + ld_queue_length = 8, st_queue_length = 2, ex_queue_length = 8, - rob_entries = 16, + + rob_full_entries = 16, + rob_partial_entries = 8, + + hasIm2col = false, //declare im2col block + sp_banks = 4, + sp_singleported = true, acc_banks = 2, + acc_singleported = false, + num_acc_sub_banks = -1, sp_capacity = CapacityInKilobytes(256), shifter_banks = 1, // TODO add separate parameters for left and up shifter banks dataflow = Dataflow.BOTH, acc_capacity = CapacityInKilobytes(64), - mem_pipeline = 1, - hasIm2col = true, //declare im2col block + mem_pipeline = 4, dma_maxbytes = 64, // TODO get this from cacheblockbytes dma_buswidth = 128, // TODO get this from SystemBusKey aligned_to = 1, + tlb_size = 4, + use_tlb_register_filter = true, + max_in_flight_reqs = 16, + use_dedicated_tl_port = false, inputType = SInt(8.W), outputType = SInt(20.W), accType = SInt(32.W), - // inputType = Float(8, 24), - // outputType = Float(8, 24), - // accType = Float(8, 24), - - // mvin_scale_args = Some(MvinScaleArguments((t: SInt, u: SInt) => t * u, 0, SInt(8.W))), - // mvin_scale_acc_args = Some(MvinScaleArguments((t: SInt, u: SInt) => t * u, 0, SInt(8.W))), - // mvin_scale_args = None, - -// mvin_scale_args = Some(ScaleArguments( -// (t: SInt, s: SInt) => { -// // The equation we use can be found here: https://riscv.github.io/documents/riscv-v-spec/#_vector_fixed_point_rounding_mode_register_vxrm -// -// // TODO Do we need to explicitly handle the cases where "u" is a small number (like 0)? What is the default behavior here? -// val u = s.asUInt() -// val point_five = Mux(u === 0.U, 0.U, t(u - 1.U)) -// val zeros = Mux(u <= 1.U, 0.U, t.asUInt() & ((1.U << (u - 1.U)).asUInt() - 1.U)) =/= 0.U -// val ones_digit = t(u) -// -// val r = (point_five & (zeros | ones_digit)).asBool() -// -// Mux(s >= 0.S, ((t >> u).asSInt() + Mux(r, 1.S, 0.S)).asSInt(), (t << (0.S-s).asUInt()).asSInt()) -// }, -// 0, SInt(8.W), "0")), mvin_scale_args = Some(ScaleArguments( (t: SInt, f: Float) => { @@ -122,13 +108,11 @@ object GemminiConfigs { Mux(overflow, sat, rec_fn_to_in.io.out.asTypeOf(t)) }, - 0, Float(8, 24), + 4, Float(8, 24), 4, identity = "1.0", c_str = "({float y = ROUND_NEAR_EVEN((x) * (scale)); y > INT8_MAX ? INT8_MAX : (y < INT8_MIN ? INT8_MIN : (elem_t)y);})" )), - mvin_scale_acc_args = None, - mvin_scale_shared = false, acc_scale_args = ScaleArguments( @@ -165,20 +149,34 @@ object GemminiConfigs { Mux(overflow, sat, rec_fn_to_in.io.out.asTypeOf(t)) }, - 0, Float(8, 24), + 1, Float(8, 24), -1, // TODO pipelining should be 5 identity = "1.0", c_str = "({float y = ROUND_NEAR_EVEN((x) * (scale)); y > INT8_MAX ? INT8_MAX : (y < INT8_MIN ? INT8_MIN : (acc_t)y);})" ), acc_read_full_width = true, acc_read_small_width = true, - use_dedicated_tl_port = false, + pe_latency = 0, - tlb_size = 4, - use_tlb_register_filter = true, - max_in_flight_reqs = 16, + ex_read_from_spad = true, + ex_read_from_acc = true, + ex_write_to_spad = true, + ex_write_to_acc = true + ) + + val chipConfig = defaultConfig.copy(sp_capacity=CapacityInKilobytes(64), acc_capacity=CapacityInKilobytes(32), dataflow=Dataflow.WS, + acc_scale_args=defaultConfig.acc_scale_args.copy(latency=4), + acc_singleported=true, + num_acc_sub_banks=2, + ex_read_from_acc=false, + ex_write_to_spad=false + ) + val largeChipConfig = chipConfig.copy(sp_capacity=CapacityInKilobytes(128), acc_capacity=CapacityInKilobytes(64), + meshRows=32, meshColumns=32 ) + + val highPerfConfig = defaultConfig.copy(dataflow=Dataflow.WS, acc_read_full_width = false, ex_read_from_acc = false, ex_write_to_spad = false, max_in_flight_reqs = 64) } /** @@ -186,18 +184,19 @@ object GemminiConfigs { Also sets the system bus width to 128 bits (instead of the deafult 64 bits) to allow for the default 16x16 8-bit systolic array to be attached. */ -class DefaultGemminiConfig extends Config((site, here, up) => { +class DefaultGemminiConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( + gemminiConfig: GemminiArrayConfig[T,U,V] = GemminiConfigs.defaultConfig +) extends Config((site, here, up) => { case BuildRoCC => up(BuildRoCC) ++ Seq( - (p: Parameters) => { - implicit val q = p - val gemmini = LazyModule(new Gemmini(OpcodeSet.custom3, GemminiConfigs.defaultConfig)) - gemmini + (p: Parameters) => { + implicit val q = p + val gemmini = LazyModule(new Gemmini(gemminiConfig)) + gemmini } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) }) - /** * Mixin which configures a smaller host processor for the systolic array. This mixin **replaces** the default host rocket (assuming a single core config). @@ -231,7 +230,7 @@ class GemminiHostMiniCore extends Config((site, here, up) => { (up(RocketTilesKey, site).length - 1 -> Seq((p: Parameters) => { implicit val q = p - val gemmini = LazyModule(new Gemmini(OpcodeSet.custom3, GemminiConfigs.defaultConfig)) + val gemmini = LazyModule(new Gemmini(GemminiConfigs.defaultConfig)) gemmini })) }) @@ -270,7 +269,7 @@ class WithGemminiHostMiniCore extends Config((site, here, up) => { (up(RocketTilesKey, site).length -> Seq((p: Parameters) => { implicit val q = p - val gemmini = LazyModule(new Gemmini(OpcodeSet.custom3, GemminiConfigs.defaultConfig)) + val gemmini = LazyModule(new Gemmini(GemminiConfigs.defaultConfig)) gemmini })) }) @@ -316,5 +315,3 @@ class GemminiAcceleratorDeviceConfig extends Config( new WithoutTLMonitors ++ new freechips.rocketchip.system.DefaultConfig ) - - diff --git a/src/main/scala/gemmini/ConfigsFP.scala b/src/main/scala/gemmini/ConfigsFP.scala new file mode 100644 index 00000000..a7c065f3 --- /dev/null +++ b/src/main/scala/gemmini/ConfigsFP.scala @@ -0,0 +1,168 @@ +package gemmini + +import chisel3._ +import freechips.rocketchip.config.{Config, Parameters} +import freechips.rocketchip.diplomacy.{LazyModule, ValName} +import freechips.rocketchip.subsystem._ +import freechips.rocketchip.tile.{BuildRoCC, OpcodeSet} + +// ----------------------------- +// Floating Point Config Mixins +// ----------------------------- + + +object GemminiFPConfigs { + import Arithmetic.FloatArithmetic._ + val defaultFPConfig = GemminiArrayConfig[Float, Float, Float]( + opcodes = OpcodeSet.custom3, + tileRows = 1, + tileColumns = 1, + meshRows = 4, + meshColumns = 4, + + ld_queue_length = 8, + st_queue_length = 2, + ex_queue_length = 8, + + rob_full_entries = 16, + rob_partial_entries = 8, + + hasIm2col = false, + + sp_banks = 4, + sp_singleported = true, + acc_banks = 1, + acc_singleported = false, + num_acc_sub_banks = -1, + sp_capacity = CapacityInKilobytes(256), + shifter_banks = 1, // TODO add separate parameters for left and up shifter banks + dataflow = Dataflow.BOTH, + acc_capacity = CapacityInKilobytes(64), + mem_pipeline = 1, + + dma_maxbytes = 64, // TODO get this from cacheblockbytes + dma_buswidth = 128, // TODO get this from SystemBusKey + aligned_to = 1, + tlb_size = 4, + use_tlb_register_filter = true, + max_in_flight_reqs = 16, + use_dedicated_tl_port = false, + + inputType = Float(8, 24), + outputType = Float(8, 24), + accType = Float(8, 24), + + mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")), + mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")), + mvin_scale_shared = false, + + acc_scale_args = ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", + c_str = "((x) * (scale))" + ), + acc_read_full_width = true, + acc_read_small_width = true, + + pe_latency = 1, + + ex_read_from_spad = true, + ex_read_from_acc = true, + ex_write_to_spad = true, + ex_write_to_acc = true, + ) + + //FP32 Single Precision Configuration + val FP32DefaultConfig = defaultFPConfig.copy(inputType = Float(8, 24), outputType = Float(8, 24), accType = Float(8, 24), + pe_latency = 2, + mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")), + mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")), + ) + + //FP16 Half Precision Configuration + val FP16DefaultConfig = defaultFPConfig.copy(inputType = Float(5, 11), outputType = Float(5, 11), accType = Float(8, 24), + pe_latency = 2, + mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(5, 11), -1, identity = "1.0", c_str="((x) * (scale))")), + mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(5, 11), -1, identity = "1.0", c_str="((x) * (scale))")), + ) + + //Bfloat16 Brain-half Precision Configuration + val BF16DefaultConfig = defaultFPConfig.copy(inputType = Float(8, 8), outputType = Float(8, 8), accType = Float(8, 24), + pe_latency = 2, + mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")), + mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")), + ) + + //Bfloat16 Brain-half Precision Configuration 8x8 array + val BF16Default8Config = defaultFPConfig.copy(inputType = Float(8, 8), outputType = Float(8, 8), accType = Float(8, 24), + meshRows = 8, meshColumns = 8, + pe_latency = 2, + mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")), + mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")), + ) + +} + + +//===========FP32 Default Config========= +class GemminiFP32DefaultConfig extends Config((site, here, up) => { + case BuildRoCC => Seq( + (p: Parameters) => { + implicit val q = p + implicit val v = implicitly[ValName] + LazyModule(new Gemmini(GemminiFPConfigs.FP32DefaultConfig)) + } + ) + case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) +}) + + +//===========FP16 Default Config========= +class GemminiFP16DefaultConfig extends Config((site, here, up) => { + case BuildRoCC => Seq( + (p: Parameters) => { + implicit val q = p + implicit val v = implicitly[ValName] + LazyModule(new Gemmini(GemminiFPConfigs.FP16DefaultConfig)) + } + ) + case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) +}) + +//===========BFLOAT16 Default Config========= +class GemminiBF16DefaultConfig extends Config((site, here, up) => { + case BuildRoCC => Seq( + (p: Parameters) => { + implicit val q = p + implicit val v = implicitly[ValName] + LazyModule(new Gemmini(GemminiFPConfigs.BF16DefaultConfig)) + } + ) + case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) +}) + +class GemminiBF16DefaultHighPerfConfig extends Config((site, here, up) => { + case BuildRoCC => Seq( + (p: Parameters) => { + implicit val q = p + implicit val v = implicitly[ValName] + val gemmini = LazyModule(new Gemmini(GemminiFPConfigs.BF16DefaultConfig.copy( + ex_read_from_acc = false, + ex_write_to_spad = false, + ))) + gemmini + } + ) + case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) +}) + +//===========BFLOAT16 Default Config 8x8========= +class GemminiBF16Default8Config extends Config((site, here, up) => { + case BuildRoCC => Seq( + (p: Parameters) => { + implicit val q = p + implicit val v = implicitly[ValName] + LazyModule(new Gemmini(GemminiFPConfigs.BF16Default8Config)) + } + ) + case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) +}) + diff --git a/src/main/scala/gemmini/Controller.scala b/src/main/scala/gemmini/Controller.scala index a3d11789..98d779b8 100644 --- a/src/main/scala/gemmini/Controller.scala +++ b/src/main/scala/gemmini/Controller.scala @@ -9,7 +9,7 @@ import chisel3.util._ import freechips.rocketchip.config._ import freechips.rocketchip.diplomacy._ import freechips.rocketchip.tile._ -import freechips.rocketchip.tilelink.{TLIdentityNode} +import freechips.rocketchip.tilelink.TLIdentityNode import GemminiISA._ import Util._ @@ -20,84 +20,10 @@ class GemminiCmd(rob_entries: Int)(implicit p: Parameters) extends Bundle { override def cloneType: this.type = new GemminiCmd(rob_entries).asInstanceOf[this.type] } - -class LocalAddr(sp_banks: Int, sp_bank_entries: Int, acc_banks: Int, acc_bank_entries: Int) extends Bundle { - private val localAddrBits = 32 // TODO magic number - - private val spAddrBits = log2Ceil(sp_banks * sp_bank_entries) - private val accAddrBits = log2Ceil(acc_banks * acc_bank_entries) - private val maxAddrBits = spAddrBits max accAddrBits - - private val spBankBits = log2Up(sp_banks) - private val spBankRowBits = log2Up(sp_bank_entries) - - private val accBankBits = log2Up(acc_banks) - private val accBankRowBits = log2Up(acc_bank_entries) - - val is_acc_addr = Bool() - val accumulate = Bool() - val read_full_acc_row = Bool() - val garbage = UInt(((localAddrBits - maxAddrBits - 4) max 0).W) - val garbage_bit = if (localAddrBits - maxAddrBits >= 4) UInt(1.W) else UInt(0.W) - val data = UInt(maxAddrBits.W) - - def sp_bank(dummy: Int = 0) = if (spAddrBits == spBankRowBits) 0.U else data(spAddrBits - 1, spBankRowBits) - def sp_row(dummy: Int = 0) = data(spBankRowBits - 1, 0) - def acc_bank(dummy: Int = 0) = if (accAddrBits == accBankRowBits) 0.U else data(accAddrBits - 1, accBankRowBits) - def acc_row(dummy: Int = 0) = data(accBankRowBits - 1, 0) - - def full_sp_addr(dummy: Int = 0) = data(spAddrBits - 1, 0) - def full_acc_addr(dummy: Int = 0) = data(accAddrBits - 1, 0) - - def is_same_address(other: LocalAddr): Bool = is_acc_addr === other.is_acc_addr && data === other.data - def is_same_address(other: UInt): Bool = is_same_address(other.asTypeOf(this)) - def is_garbage(dummy: Int = 0) = is_acc_addr && accumulate && read_full_acc_row && data.andR() && - (if (garbage_bit.getWidth > 0) garbage_bit.asBool() else true.B) - - def +(other: UInt) = { - require(isPow2(sp_bank_entries)) // TODO remove this requirement - require(isPow2(acc_bank_entries)) // TODO remove this requirement - - val result = WireInit(this) - result.data := data + other - result - } - - def <=(other: LocalAddr) = - is_acc_addr === other.is_acc_addr && - Mux(is_acc_addr, full_acc_addr() <= other.full_acc_addr(), full_sp_addr() <= other.full_sp_addr()) - - def >(other: LocalAddr) = - is_acc_addr === other.is_acc_addr && - Mux(is_acc_addr, full_acc_addr() > other.full_acc_addr(), full_sp_addr() > other.full_sp_addr()) - - def add_with_overflow(other: UInt): Tuple2[LocalAddr, Bool] = { - require(isPow2(sp_bank_entries)) // TODO remove this requirement - require(isPow2(acc_bank_entries)) // TODO remove this requirement - - val sum = data +& other - - val result = WireInit(this) - result.data := sum(data.getWidth-1, 0) - - (result, sum(data.getWidth)) - } - - def make_this_garbage(dummy: Int = 0): Unit = { - is_acc_addr := true.B - accumulate := true.B - read_full_acc_row := true.B - garbage_bit := 1.U - data := ~(0.U(maxAddrBits.W)) - } - - override def cloneType: LocalAddr.this.type = new LocalAddr(sp_banks, sp_bank_entries, acc_banks, acc_bank_entries).asInstanceOf[this.type] -} - -class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](opcodes: OpcodeSet, val config: GemminiArrayConfig[T, U, V]) +class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](val config: GemminiArrayConfig[T, U, V]) (implicit p: Parameters) extends LazyRoCC ( - opcodes = OpcodeSet.custom3, + opcodes = config.opcodes, nPTWPorts = 1) { Files.write(Paths.get(config.headerFilePath), config.generateHeader().getBytes(StandardCharsets.UTF_8)) @@ -179,15 +105,40 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data] */ // Incoming commands and ROB - val rob = Module(new ROB(new RoCCCommand, rob_entries, local_addr_t, meshRows*tileRows, meshColumns*tileColumns)) + val rob = Module(new ROB(outer.config, new RoCCCommand)) val raw_cmd = Queue(io.cmd) + // Loop Loader (load A or B) + // ToDo: collaborate with loopconv fsm (currently, only Loop matmul fsm) + val pause_monitor = spad.module.io.pause_out + val (loop_ld_cmd, loop_ld_unroller_busy, loop_ld_latency, loop_ld_alert, loop_ld_pause_turn) = LoopLoader(raw_cmd, pause_monitor, meshRows*tileRows, coreMaxAddrBits, sp_banks * sp_bank_entries, + inputType.getWidth, dma_maxbytes) + loop_ld_cmd.ready := false.B + spad.module.io.latency_in := loop_ld_latency + spad.module.io.alert_cycles_in := loop_ld_alert + spad.module.io.pause_turn_in := loop_ld_pause_turn + + + val max_lds = rob_partial_entries + val max_exs = rob_full_entries + val max_sts = rob_partial_entries / 2 + + // TODO replace 4,12,2 with parameters based on ROB size + val (conv_cmd, loop_conv_unroller_busy) = LoopConv(loop_ld_cmd, rob.io.ld_utilization, rob.io.st_utilization, rob.io.ex_utilization, + meshRows*tileRows, coreMaxAddrBits, rob_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries, + inputType.getWidth, accType.getWidth, dma_maxbytes) + // val (compressed_cmd, compressor_busy) = InstCompressor(unrolled_cmd) // compressed_cmd.ready := false.B - val (unrolled_cmd, loop_unroller_busy) = LoopMatmul(raw_cmd, rob.io.ld_utilization, rob.io.st_utilization, rob.io.ex_utilization, - meshRows*tileRows, coreMaxAddrBits, rob_entries, 4, 12, 2, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries, + + // val (unrolled_cmd, loop_matmul_unroller_busy) = LoopMatmul(unrolled_cmd_after_conv, rob.io.ld_utilization, rob.io.st_utilization, rob.io.ex_utilization, + + val (loop_cmd, loop_matmul_unroller_busy) = LoopMatmul(conv_cmd, rob.io.ld_utilization, rob.io.st_utilization, rob.io.ex_utilization, + meshRows*tileRows, coreMaxAddrBits, rob_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries, inputType.getWidth, accType.getWidth, dma_maxbytes) + + val unrolled_cmd = Queue(loop_cmd) unrolled_cmd.ready := false.B // val cmd_decompressor = Module(new InstDecompressor(rob_entries)) @@ -286,7 +237,8 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data] spad.module.io.dma.write <> store_controller.io.dma ex_controller.io.srams.read <> spad.module.io.srams.read ex_controller.io.srams.write <> spad.module.io.srams.write - ex_controller.io.acc.read <> spad.module.io.acc.read + spad.module.io.acc.read_req <> ex_controller.io.acc.read_req + ex_controller.io.acc.read_resp <> spad.module.io.acc.read_resp ex_controller.io.acc.write <> spad.module.io.acc.write // Im2Col unit @@ -361,12 +313,12 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data] rob_completed_arb.io.out.ready := true.B // Wire up global RoCC signals - io.busy := raw_cmd.valid || loop_unroller_busy || rob.io.busy || spad.module.io.busy + io.busy := raw_cmd.valid || loop_conv_unroller_busy || loop_matmul_unroller_busy || rob.io.busy || spad.module.io.busy || unrolled_cmd.valid || loop_cmd.valid || conv_cmd.valid io.interrupt := tlb.io.exp.interrupt rob.io.solitary_preload := ex_controller.io.solitary_preload - assert(!io.interrupt, "Interrupt handlers have not been written yet") + // assert(!io.interrupt, "Interrupt handlers have not been written yet") // Cycle counters val ld_cycles = RegInit(0.U(34.W)) diff --git a/src/main/scala/gemmini/DMA.scala b/src/main/scala/gemmini/DMA.scala index d79295b3..2383f4e5 100644 --- a/src/main/scala/gemmini/DMA.scala +++ b/src/main/scala/gemmini/DMA.scala @@ -7,7 +7,7 @@ import chisel3.experimental.DataMirror import freechips.rocketchip.config.Parameters import freechips.rocketchip.diplomacy.{IdRange, LazyModule, LazyModuleImp} import freechips.rocketchip.tile.{CoreBundle, HasCoreParameters} -import freechips.rocketchip.tilelink.{TLBundleA} +import freechips.rocketchip.tilelink.TLBundleA import testchipip.TLHelper import freechips.rocketchip.rocket.MStatus import freechips.rocketchip.rocket.constants.MemoryOpConstants @@ -24,8 +24,17 @@ class StreamReadRequest[U <: Data](spad_rows: Int, acc_rows: Int, mvin_scale_t_b val status = new MStatus val len = UInt(16.W) // TODO magic number val repeats = UInt(16.W) // TODO magic number + val block_stride = UInt(16.W) // TODO magic number val cmd_id = UInt(8.W) // TODO magic number + // for conflict monitoring + val monitor_conflict = Bool() + val monitor_conflict_start = Bool() + val monitor_conflict_end = Bool() + val profile_conflict = Bool() + val profile_conflict_start = Bool() + val profile_conflict_end = Bool() + override def cloneType: StreamReadRequest.this.type = new StreamReadRequest(spad_rows, acc_rows, mvin_scale_t_bits).asInstanceOf[this.type] } @@ -38,7 +47,7 @@ class StreamReadResponse[U <: Data](spadWidth: Int, accWidth: Int, spad_rows: In val accumulate = Bool() val has_acc_bitwidth = Bool() val scale = UInt(mvin_scale_t_bits.W) - val rows = UInt(16.W) // TODO magic number + val repeats = UInt(16.W) // TODO magic number val last = Bool() val bytes_read = UInt(8.W) // TODO magic number val cmd_id = UInt(8.W) // TODO magic number @@ -60,7 +69,25 @@ class StreamReader[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T val tlb = new FrontendTLBIO val busy = Output(Bool()) val flush = Input(Bool()) + + //for monitoring conflicts, latency + val latency_in = Input(UInt(16.W)) + val alert_cycles_in = Input(UInt(6.W)) + val latency_out = Output(UInt(16.W)) + val alert_cycles_out = Output(UInt(6.W)) + val pause_turn_in = Input(UInt(3.W)) + val pause_turn_out = Output(UInt(3.W)) + + //for pausing monitoring + val pause_out = Output(Bool()) }) + io.latency_out := io.latency_in + io.alert_cycles_out := io.alert_cycles_in + io.pause_turn_out := io.pause_turn_in + core.module.io.latency := io.latency_out + core.module.io.alert_cycles := io.alert_cycles_out + io.pause_out := core.module.io.pause + core.module.io.pause_turn := io.pause_turn_out val nCmds = (nXacts / meshRows) + 1 @@ -93,7 +120,7 @@ class StreamReader[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T io.resp.bits.accumulate := beatPacker.io.out.bits.accumulate io.resp.bits.has_acc_bitwidth := beatPacker.io.out.bits.has_acc_bitwidth io.resp.bits.scale := RegEnable(xactTracker.io.peek.entry.scale, beatPacker.io.req.fire()) - io.resp.bits.rows := RegEnable(xactTracker.io.peek.entry.rows, beatPacker.io.req.fire()) + io.resp.bits.repeats := RegEnable(xactTracker.io.peek.entry.repeats, beatPacker.io.req.fire()) io.resp.bits.cmd_id := RegEnable(xactTracker.io.peek.entry.cmd_id, beatPacker.io.req.fire()) io.resp.bits.bytes_read := RegEnable(xactTracker.io.peek.entry.bytes_to_read, beatPacker.io.req.fire()) io.resp.bits.last := beatPacker.io.out.bits.last @@ -134,6 +161,12 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf val beatData = Decoupled(new StreamReadBeat(nXacts, beatBits, maxBytes)) val tlb = new FrontendTLBIO val flush = Input(Bool()) + + //for monitoring conflicts, latency + val latency = Input(UInt(16.W)) + val alert_cycles = Input(UInt(6.W)) + val pause_turn = Input(UInt(3.W)) + val pause = Output(Bool()) }) val s_idle :: s_req_new_block :: Nil = Enum(2) @@ -156,6 +189,14 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf val shift = UInt(log2Up(maxBytes).W) //val paddr = UInt(paddrBits.W) val vaddr = UInt(vaddrBits.W) + + //for bank conflict monitoring + val monitor_conflict = Bool() + val monitor_conflict_start = Bool() + val monitor_conflict_end = Bool() + val profile_conflict = Bool() + val profile_conflict_start = Bool() + val profile_conflict_end = Bool() } // TODO Can we filter out the larger read_sizes here if the systolic array is small, in the same way that we do so @@ -175,6 +216,14 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf packet.shift := vaddr_offset packet.vaddr := vaddr_aligned_to_size + //for bank conflict monitoring + packet.monitor_conflict := req.monitor_conflict + packet.monitor_conflict_start := req.monitor_conflict_start + packet.monitor_conflict_end := req.monitor_conflict_end + packet.profile_conflict := req.profile_conflict + packet.profile_conflict_end := req.profile_conflict_end + packet.profile_conflict_start := req.profile_conflict_start + packet } val read_packet = read_packets.reduce { (acc, p) => @@ -184,6 +233,13 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf val read_lg_size = read_packet.lg_size val read_bytes_read = read_packet.bytes_read val read_shift = read_packet.shift + //for bank conflict monitoring + val read_monitor = read_packet.monitor_conflict + val read_monitor_start = read_packet.monitor_conflict_start + val read_monitor_end = read_packet.monitor_conflict_end + val profile = read_packet.profile_conflict + val profile_start = read_packet.profile_conflict_start + val profile_end = read_packet.profile_conflict_end // Firing off TileLink read requests and allocating space inside the reservation buffer for them val get = edge.Get( @@ -196,6 +252,15 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf val tl_a = DataMirror.internal.chiselTypeClone[TLBundleA](tl.a.bits) val vaddr = Output(UInt(vaddrBits.W)) val status = Output(new MStatus) + + //for bank conflict monitoring + val monitor_conflict = Output(Bool()) + val monitor_conflict_start = Output(Bool()) + val monitor_conflict_end = Output(Bool()) + + val profile_conflict = Output(Bool()) + val profile_conflict_start = Output(Bool()) + val profile_conflict_end = Output(Bool()) } val untranslated_a = Wire(Decoupled(new TLBundleAWithInfo)) @@ -203,6 +268,13 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf untranslated_a.bits.tl_a := get untranslated_a.bits.vaddr := read_vaddr untranslated_a.bits.status := req.status + //for bank conflict monitoring + untranslated_a.bits.monitor_conflict := read_monitor + untranslated_a.bits.monitor_conflict_start := read_monitor_start + untranslated_a.bits.monitor_conflict_end := read_monitor_end + untranslated_a.bits.profile_conflict := profile + untranslated_a.bits.profile_conflict_end := profile_end + untranslated_a.bits.profile_conflict_start := profile_start // 0 goes to retries, 1 goes to state machine val retry_a = Wire(Decoupled(new TLBundleAWithInfo)) @@ -217,22 +289,224 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf io.tlb.req.bits.tlb_req.vaddr := tlb_q.io.deq.bits.vaddr io.tlb.req.bits.tlb_req.passthrough := false.B io.tlb.req.bits.tlb_req.size := 0.U // send_size - io.tlb.req.bits.tlb_req.cmd := M_XWR + io.tlb.req.bits.tlb_req.cmd := M_XRD io.tlb.req.bits.status := tlb_q.io.deq.bits.status - val translate_q = Module(new Queue(new TLBundleAWithInfo, 1, pipe=true)) translate_q.io.enq <> tlb_q.io.deq translate_q.io.deq.ready := true.B - retry_a.valid := translate_q.io.deq.valid && (io.tlb.resp.miss || !tl.a.ready) + //retry_a.valid := translate_q.io.deq.valid && (io.tlb.resp.miss || !tl.a.ready) + //retry_a.bits := translate_q.io.deq.bits + //assert(retry_a.ready) + + ///////////////////////////////////////////////////////////////////////////////////////// + val conflict_detected = RegInit(false.B) + val pause_turn = RegInit(io.pause_turn) + retry_a.valid := translate_q.io.deq.valid && (io.tlb.resp.miss || !tl.a.ready || conflict_detected) retry_a.bits := translate_q.io.deq.bits assert(retry_a.ready) + when(reset.toBool()){ + translate_q.io.enq.bits.profile_conflict := false.B + translate_q.io.enq.bits.profile_conflict_start := false.B + translate_q.io.enq.bits.profile_conflict_end := false.B + translate_q.io.enq.bits.monitor_conflict := false.B + translate_q.io.enq.bits.monitor_conflict_end := false.B + translate_q.io.enq.bits.monitor_conflict_start := false.B + retry_a.bits.profile_conflict := false.B + retry_a.bits.profile_conflict_start := false.B + retry_a.bits.profile_conflict_end := false.B + retry_a.bits.monitor_conflict_end := false.B + retry_a.bits.monitor_conflict_start := false.B + retry_a.bits.monitor_conflict := false.B + untranslated_a.bits.profile_conflict := false.B + untranslated_a.bits.profile_conflict_start := false.B + untranslated_a.bits.profile_conflict_end := false.B + untranslated_a.bits.monitor_conflict_end := false.B + untranslated_a.bits.monitor_conflict_start := false.B + untranslated_a.bits.monitor_conflict := false.B + } + + val tl_miss = tl.a.valid && !tl.a.ready + val tl_profile = translate_q.io.deq.bits.profile_conflict + val tl_profile_start = translate_q.io.deq.bits.profile_conflict_start && translate_q.io.deq.valid + val tl_profile_end = translate_q.io.deq.bits.profile_conflict_end + val (p_reset :: p_profile_start :: Nil) = Enum(2) + val profile_miss_counter = RegInit(0.U(7.W)) + val p_state = RegInit(p_reset) + val profile_number = RegInit(0.U(9.W)) + val profile_total = RegInit(0.U(12.W)) + val profile_max = RegInit(0.U(7.W)) + val profile_average = RegInit(0.U(7.W)) + // either average or max + val profile_cycle = RegInit(profile_max) + val profile_detected = RegInit(false.B) + when(p_state === p_reset){ + when(tl_profile_start){ + p_state := p_profile_start + } + } + when(p_state === p_profile_start){ + when(tl_miss && tl_profile){ // and here? + when(profile_miss_counter === 5.U){ //only count those that are over 5 cycles (to avoid false detection) + profile_number := profile_number + 1.U + profile_detected := true.B + } + profile_miss_counter := profile_miss_counter + 1.U //which counter to use? + }.otherwise{ + when(profile_detected){ + profile_total := profile_total + profile_miss_counter + profile_max := Mux(profile_max < profile_miss_counter && !translate_q.io.deq.bits.profile_conflict_start, profile_miss_counter, profile_max)//update to max value + profile_detected := false.B + when(profile_number === 64.U){ + profile_average := profile_total / 64.U + } + } + profile_miss_counter := 0.U + } + when(tl_profile_end){ + profile_detected := false.B + p_state := p_reset + profile_miss_counter := 0.U + when(profile_number === 64.U){ + profile_average := profile_total / 64.U + } + //profile_average := profile_total / profile_number // ToDo: need to change (don't use division) + profile_cycle := Mux(io.pause_turn === 1.U, profile_average * 2.U, profile_max + 1.U) //parameterize what to select + } + } + dontTouch(profile_miss_counter) + dontTouch(profile_cycle) + dontTouch(profile_average) + dontTouch(profile_max) + dontTouch(profile_detected) + dontTouch(profile_total) + dontTouch(profile_number) + + val tl_counter_trigger = tl_miss && translate_q.io.deq.bits.monitor_conflict && !translate_q.io.deq.bits.monitor_conflict_end + val tl_miss_counter = RegInit(0.U(6.W)) + val alert_cycles = RegInit(profile_cycle) + val latency = RegInit(io.latency) + val enable_bubble = RegInit(false.B) + //val max_block_len = (maxBytes / (meshRows * spadWidth / 8)) max 1 + val expected_tl_req = (spad_rows / (2*2*4)).asUInt() + //val latency = Mux(translate_q.io.deq.bits.monitor_conflict && !translate_q.io.deq.bits.profile_conflict, + // Mux(io.latency === 0.U, expected_tl_req * profile_average, io.latency), 0.U) //if latency not give, use profiled one + //tl_miss_counter := satAdd(tl_miss_counter, 1.U, alert_cycles + 2.U, tl_counter_trigger) + /* + when(tl_miss_counter >= alert_cycles){ //reached limit + conflict_detected := true.B + }.elsewhen(!tl_counter_trigger){ + tl_miss_counter := 0.U + } + + */ + // pause monitoring detecting logic + val (s_reset :: s_monitor_start :: s_conflict_detected :: Nil) = Enum(3) + val m_state = RegInit(s_reset) + val pause_detect = RegInit(false.B) + val pause_count = RegInit(0.U(2.W)) //Todo: parameterize it? + val tl_miss_timer = RegInit(0.U(16.W)) + /* + tl_miss_timer := floorAdd(tl_miss_timer, 1.U, latency + 1.U, conflict_detected) + when(tl_miss_timer === latency){ //resolve miss counter temporary + tl_miss_counter := 0.U //reset miss counter + conflict_detected := false.B + } + */ + //val pause_monitor_start = RegInit(0.U(6.W)) + io.pause := pause_detect + when(translate_q.io.deq.bits.monitor_conflict && !translate_q.io.deq.bits.monitor_conflict_end){ + when(m_state === s_reset) { + tl_miss_counter := 0.U + tl_miss_timer := 0.U + when(translate_q.io.deq.bits.monitor_conflict_start && translate_q.io.deq.valid){ // to avoid false detection + m_state := s_monitor_start + alert_cycles := Mux(io.alert_cycles === 0.U, profile_cycle, io.alert_cycles) + when(io.latency === 0.U){ + latency := expected_tl_req * profile_average // use profiled one + enable_bubble := true.B + }.elsewhen(io.latency === 1.U){ + latency := 0.U + enable_bubble := false.B + }.otherwise{ + enable_bubble := true.B + latency := io.latency + } + pause_turn := io.pause_turn + } + }.elsewhen(m_state === s_monitor_start){ + tl_miss_counter := satAdd(tl_miss_counter, 1.U, alert_cycles + 2.U, tl_miss) + when(tl_miss_counter >= alert_cycles){ + m_state := s_conflict_detected + when(enable_bubble){ + conflict_detected := true.B + }.otherwise{ + tl_miss_counter := 0.U //resolve miss counter immediately (no bubbbles) + conflict_detected := false.B + } + }.elsewhen(!tl_miss){ + tl_miss_counter := 0.U + } + //pause_monitor_start := 0.U + }.elsewhen(m_state === s_conflict_detected){ + tl_miss_counter := satAdd(tl_miss_counter, 1.U, alert_cycles + 2.U, tl_miss) + tl_miss_timer := floorAdd(tl_miss_timer, 1.U, latency + 1.U, conflict_detected) + when(tl_miss_counter >= alert_cycles){ + when(enable_bubble){ + conflict_detected := true.B + }.otherwise{ + tl_miss_counter := 0.U //resolve miss counter immediately (no bubbbles) + conflict_detected := false.B + } + }.elsewhen(!tl_miss){ + tl_miss_counter := 0.U + } + when(tl_miss_timer === latency){ //resolve miss counter temporary + tl_miss_counter := 0.U //reset miss counter + conflict_detected := false.B + } + } + }.elsewhen(translate_q.io.deq.bits.monitor_conflict_end) { + when(m_state === s_conflict_detected) { + m_state := s_reset + pause_count := 0.U + pause_detect := false.B + }.elsewhen(m_state === s_monitor_start) { // no detection during time window + when(pause_count === pause_turn) { // pause monitoring + pause_detect := true.B + m_state := s_reset + pause_count := 0.U // reset pause counter + }.otherwise { + pause_count := pause_count + 1.U + m_state := s_reset + } + } + //pause_monitor_start := 0.U + } //ToDo: how to restart monitoring after pausing - tl.a.valid := translate_q.io.deq.valid && !io.tlb.resp.miss + + tl.a.valid := translate_q.io.deq.valid && !io.tlb.resp.miss && !(conflict_detected) tl.a.bits := translate_q.io.deq.bits.tl_a tl.a.bits.address := io.tlb.resp.paddr + /* + val cycles = freechips.rocketchip.util.WideCounter(32) + when(tl.a.fire()){ + printf("GEMMINI_MEM %x %x %x %x\n", cycles.value, p(freechips.rocketchip.tile.TileKey).hartId.U, tl.a.bits.address, tl.a.bits.size) + //printf(midas.targetutils.SynthesizePrintf("GEMMINI_MEM: %x %x %x\n", p(freechips.rocketchip.tile.TileKey).hartId.U, tl.a.bits.address, tl.a.bits.size)) + } + when(tl_miss){ + printf("GEMMINI_BLOCK %x %x\n", cycles.value, p(freechips.rocketchip.tile.TileKey).hartId.U) + //printf(midas.targetutils.SynthesizePrintf("GEMMINI_BLOCK: %x %x \n", p(freechips.rocketchip.tile.TileKey).hartId.U, tl.a.bits.address)) + } + + */ + ///////////////////////////////////////////////////////////////////////////////////////// + + //tl.a.valid := translate_q.io.deq.valid && !io.tlb.resp.miss + //tl.a.bits := translate_q.io.deq.bits.tl_a + //tl.a.bits.address := io.tlb.resp.paddr io.reserve.valid := state === s_req_new_block && untranslated_a.ready // TODO decouple "reserve.valid" from "tl.a.ready" io.reserve.entry.shift := read_shift @@ -240,18 +514,20 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf io.reserve.entry.accumulate := req.accumulate io.reserve.entry.has_acc_bitwidth := req.has_acc_bitwidth io.reserve.entry.scale := req.scale - io.reserve.entry.rows := req.repeats + io.reserve.entry.repeats := req.repeats + io.reserve.entry.block_stride := req.block_stride io.reserve.entry.lg_len_req := DontCare // TODO just remove this from the IO completely io.reserve.entry.bytes_to_read := read_bytes_read io.reserve.entry.cmd_id := req.cmd_id - io.reserve.entry.addr := req.spaddr + meshRows.U * + io.reserve.entry.addr := req.spaddr + req.block_stride * Mux(req.has_acc_bitwidth, // We only add "if" statements here to satisfy the Verilator linter. The code would be cleaner without the // "if" condition and the "else" clause if (bytesRequested.getWidth >= log2Up(accWidthBytes+1)) bytesRequested / accWidthBytes.U else 0.U, if (bytesRequested.getWidth >= log2Up(spadWidthBytes+1)) bytesRequested / spadWidthBytes.U else 0.U) io.reserve.entry.spad_row_offset := Mux(req.has_acc_bitwidth, bytesRequested % accWidthBytes.U, bytesRequested % spadWidthBytes.U) + when (untranslated_a.fire()) { val next_vaddr = req.vaddr + read_bytes_read // send_size val new_page = next_vaddr(pgIdxBits-1, 0) === 0.U @@ -286,10 +562,11 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf } } -class StreamWriteRequest(val dataWidth: Int)(implicit p: Parameters) extends CoreBundle { +class StreamWriteRequest(val dataWidth: Int, val maxBytes: Int)(implicit p: Parameters) extends CoreBundle { val vaddr = UInt(coreMaxAddrBits.W) val data = UInt(dataWidth.W) - val len = UInt(16.W) // The number of bytes to write // TODO magic number + val len = UInt(log2Up((dataWidth/8 max maxBytes)+1).W) // The number of bytes to write + val block = UInt(8.W) // TODO magic number val status = new MStatus // Pooling variables @@ -311,11 +588,13 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: val beatBytes = beatBits / 8 val lgBeatBytes = log2Ceil(beatBytes) val maxBeatsPerReq = maxBytes / beatBytes + val inputTypeRowBytes = block_cols * inputType.getWidth / 8 + val maxBlocks = maxBytes / inputTypeRowBytes require(beatBytes > 0) val io = IO(new Bundle { - val req = Flipped(Decoupled(new StreamWriteRequest(dataWidth))) + val req = Flipped(Decoupled(new StreamWriteRequest(dataWidth, maxBytes))) val tlb = new FrontendTLBIO val busy = Output(Bool()) val flush = Input(Bool()) @@ -324,9 +603,14 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: val (s_idle :: s_writing_new_block :: s_writing_beats :: Nil) = Enum(3) val state = RegInit(s_idle) - val req = Reg(new StreamWriteRequest(dataWidth)) + val req = Reg(new StreamWriteRequest(dataWidth, maxBytes)) - val bytesSent = Reg(UInt(log2Ceil(dataBytes).W)) // TODO this only needs to count up to (dataBytes/aligned_to), right? + // TODO use the same register to hold data_blocks and data_single_block, so that this Mux here is not necessary + val data_blocks = Reg(Vec(maxBlocks, UInt((inputTypeRowBytes * 8).W))) + val data_single_block = Reg(UInt(dataWidth.W)) // For data that's just one-block-wide + val data = Mux(req.block === 0.U, data_single_block, data_blocks.asUInt()) + + val bytesSent = Reg(UInt(log2Ceil((dataBytes max maxBytes)+1).W)) // TODO this only needs to count up to (dataBytes/aligned_to), right? val bytesLeft = req.len - bytesSent val xactBusy = RegInit(0.U(nXacts.W)) @@ -346,34 +630,32 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: // Select the size and mask of the TileLink request class Packet extends Bundle { - val size = UInt(log2Ceil(maxBytes).W) - val lg_size = UInt(log2Ceil(log2Ceil(maxBytes)).W) + val size = UInt(log2Ceil(maxBytes+1).W) + val lg_size = UInt(log2Ceil(log2Ceil(maxBytes+1)+1).W) val mask = Vec(maxBeatsPerReq, Vec(beatBytes, Bool())) val vaddr = UInt(vaddrBits.W) val is_full = Bool() - def bytes_written(dummy: Int = 0) = PopCount(mask.flatten) + val bytes_written = UInt(log2Up(maxBytes+1).W) + val bytes_written_per_beat = Vec(maxBeatsPerReq, UInt(log2Up(beatBytes+1).W)) + def total_beats(dummy: Int = 0) = Mux(size < beatBytes.U, 1.U, size / beatBytes.U) } val smallest_write_size = aligned_to max beatBytes val write_sizes = (smallest_write_size to maxBytes by aligned_to). filter(s => isPow2(s)). - filter(s => s % beatBytes == 0). - filter(s => s <= dataBytes*2 || s == smallest_write_size) + filter(s => s % beatBytes == 0) /*. + filter(s => s <= dataBytes*2 || s == smallest_write_size)*/ val write_packets = write_sizes.map { s => val lg_s = log2Ceil(s) val vaddr_aligned_to_size = if (s == 1) vaddr else Cat(vaddr(vaddrBits-1, lg_s), 0.U(lg_s.W)) + val vaddr_offset = if (s > 1) vaddr(lg_s - 1, 0) else 0.U - val mask = (0 until maxBytes).map { i => - if (s > 1) { - val vaddr_offset = vaddr(lg_s - 1, 0) + val mask = (0 until maxBytes).map { i => i.U >= vaddr_offset && i.U < vaddr_offset +& bytesLeft && (i < s).B } - i.U >= vaddr_offset && - i.U < vaddr_offset +& bytesLeft - } else { - true.B - } && (i < s).B + val bytes_written = { + Mux(vaddr_offset +& bytesLeft > s.U, s.U - vaddr_offset, bytesLeft) } val packet = Wire(new Packet()) @@ -383,13 +665,35 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: packet.vaddr := vaddr_aligned_to_size packet.is_full := mask.take(s).reduce(_ && _) + packet.bytes_written := bytes_written + packet.bytes_written_per_beat.zipWithIndex.foreach { case (b, i) => + val start_of_beat = i * beatBytes + val end_of_beat = (i+1) * beatBytes + + val left_shift = Mux(vaddr_offset >= start_of_beat.U && vaddr_offset < end_of_beat.U, + vaddr_offset - start_of_beat.U, + 0.U) + + val right_shift = Mux(vaddr_offset +& bytesLeft >= start_of_beat.U && vaddr_offset +& bytesLeft < end_of_beat.U, + end_of_beat.U - (vaddr_offset +& bytesLeft), + 0.U) + + val too_early = vaddr_offset >= end_of_beat.U + val too_late = vaddr_offset +& bytesLeft <= start_of_beat.U + + b := Mux(too_early || too_late, 0.U, beatBytes.U - (left_shift +& right_shift)) + } + packet } val best_write_packet = write_packets.reduce { (acc, p) => - Mux(p.bytes_written() > acc.bytes_written(), p, acc) + Mux(p.bytes_written > acc.bytes_written, p, acc) } val write_packet = RegEnableThru(best_write_packet, state === s_writing_new_block) + for (wp <- write_packets) + dontTouch(wp) + val write_size = write_packet.size val lg_write_size = write_packet.lg_size val write_beats = write_packet.total_beats() @@ -402,21 +706,21 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: val write_mask = write_packet.mask(beatsSent) val write_shift = PriorityEncoder(write_mask) - val bytes_written_this_beat = PopCount(write_mask) + val bytes_written_this_beat = write_packet.bytes_written_per_beat(beatsSent) // Firing off TileLink write requests val putFull = edge.Put( fromSource = RegEnableThru(xactId, state === s_writing_new_block), toAddress = 0.U, lgSize = lg_write_size, - data = (req.data >> (bytesSent * 8.U)).asUInt() + data = (data >> (bytesSent * 8.U)).asUInt() )._2 val putPartial = edge.Put( fromSource = RegEnableThru(xactId, state === s_writing_new_block), toAddress = 0.U, lgSize = lg_write_size, - data = ((req.data >> (bytesSent * 8.U)) << (write_shift * 8.U)).asUInt(), + data = ((data >> (bytesSent * 8.U)) << (write_shift * 8.U)).asUInt(), mask = write_mask.asUInt() )._2 @@ -427,7 +731,7 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: } val untranslated_a = Wire(Decoupled(new TLBundleAWithInfo)) - xactBusy_fire := untranslated_a.fire() + xactBusy_fire := untranslated_a.fire() && state === s_writing_new_block untranslated_a.valid := (state === s_writing_new_block || state === s_writing_beats) && !xactBusy.andR() untranslated_a.bits.tl_a := Mux(write_full, putFull, putPartial) untranslated_a.bits.vaddr := write_vaddr @@ -435,14 +739,18 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: // 0 goes to retries, 1 goes to state machine val retry_a = Wire(Decoupled(new TLBundleAWithInfo)) - val tlb_arb = Module(new Arbiter(new TLBundleAWithInfo, 2)) + val shadow_retry_a = Module(new Queue(new TLBundleAWithInfo, 1)) + shadow_retry_a.io.enq.valid := false.B + shadow_retry_a.io.enq.bits := DontCare + val tlb_arb = Module(new Arbiter(new TLBundleAWithInfo, 3)) tlb_arb.io.in(0) <> retry_a - tlb_arb.io.in(1) <> untranslated_a + tlb_arb.io.in(1) <> shadow_retry_a.io.deq + tlb_arb.io.in(2) <> untranslated_a val tlb_q = Module(new Queue(new TLBundleAWithInfo, 1, pipe=true)) tlb_q.io.enq <> tlb_arb.io.out - io.tlb.req.valid := tlb_q.io.deq.valid + io.tlb.req.valid := tlb_q.io.deq.fire() io.tlb.req.bits.tlb_req.vaddr := tlb_q.io.deq.bits.vaddr io.tlb.req.bits.tlb_req.passthrough := false.B io.tlb.req.bits.tlb_req.size := 0.U // send_size @@ -451,15 +759,20 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: val translate_q = Module(new Queue(new TLBundleAWithInfo, 1, pipe=true)) translate_q.io.enq <> tlb_q.io.deq - translate_q.io.deq.ready := true.B + when (retry_a.valid) { + translate_q.io.enq.valid := false.B + shadow_retry_a.io.enq.valid := tlb_q.io.deq.valid + shadow_retry_a.io.enq.bits := tlb_q.io.deq.bits + } + translate_q.io.deq.ready := tl.a.ready || io.tlb.resp.miss - retry_a.valid := translate_q.io.deq.valid && (io.tlb.resp.miss || !tl.a.ready) + retry_a.valid := translate_q.io.deq.valid && io.tlb.resp.miss retry_a.bits := translate_q.io.deq.bits - assert(retry_a.ready) + assert(!(retry_a.valid && !retry_a.ready)) - tl.a.valid := translate_q.io.deq.valid && !io.tlb.resp.miss - tl.a.bits := translate_q.io.deq.bits.tl_a - tl.a.bits.address := io.tlb.resp.paddr + tl.a.valid := translate_q.io.deq.valid && !io.tlb.resp.miss + tl.a.bits := translate_q.io.deq.bits.tl_a + tl.a.bits.address := RegEnableThru(io.tlb.resp.paddr, RegNext(io.tlb.req.fire())) tl.d.ready := xactBusy.orR() @@ -467,7 +780,7 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: when (state === s_writing_new_block) { beatsLeft := write_beats - 1.U - val next_vaddr = req.vaddr + bytes_written_this_beat + val next_vaddr = req.vaddr + write_packet.bytes_written req.vaddr := next_vaddr bytesSent := bytesSent + bytes_written_this_beat @@ -485,9 +798,9 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: beatsLeft := beatsLeft - 1.U bytesSent := bytesSent + bytes_written_this_beat - when (beatsLeft === 0.U) { - val new_page = req.vaddr(pgIdxBits-1, 0) === 0.U + assert(beatsLeft > 0.U) + when (beatsLeft === 1.U) { when (bytes_written_this_beat >= bytesLeft) { // We're done with this request at this point state_machine_ready_for_req := true.B @@ -504,17 +817,23 @@ class StreamWriter[T <: Data: Arithmetic](nXacts: Int, beatBits: Int, maxBytes: val pooled = { val cols = dataWidth / inputType.getWidth val v1 = io.req.bits.data.asTypeOf(Vec(cols, inputType)) - val v2 = req.data.asTypeOf(Vec(cols, inputType)) + val v2 = data_single_block.asTypeOf(Vec(cols, inputType)) val m = v1.zip(v2) VecInit(m.zipWithIndex.map{case ((x, y), i) => if (i < block_cols) maxOf(x, y) else y}).asUInt() } req := io.req.bits - req.data := Mux(io.req.bits.pool_en, pooled, io.req.bits.data) + req.len := io.req.bits.block * inputTypeRowBytes.U + io.req.bits.len + + data_single_block := Mux(io.req.bits.pool_en, pooled, io.req.bits.data) + data_blocks(io.req.bits.block) := io.req.bits.data bytesSent := 0.U state := Mux(io.req.bits.store_en, s_writing_new_block, s_idle) + + assert(io.req.bits.len <= (block_cols * inputType.getWidth / 8).U || io.req.bits.block === 0.U, "DMA can't write multiple blocks to main memory when writing full accumulator output") + assert(!io.req.bits.pool_en || io.req.bits.block === 0.U, "Can't pool with block-mvout") } } } diff --git a/src/main/scala/gemmini/DMAReadCommandTracker.scala b/src/main/scala/gemmini/DMACommandTracker.scala similarity index 86% rename from src/main/scala/gemmini/DMAReadCommandTracker.scala rename to src/main/scala/gemmini/DMACommandTracker.scala index 386bf52e..2632f753 100644 --- a/src/main/scala/gemmini/DMAReadCommandTracker.scala +++ b/src/main/scala/gemmini/DMACommandTracker.scala @@ -6,7 +6,7 @@ import chisel3.util._ // This module is meant to go inside the Load controller, where it can track which commands are currently // in flight and which are completed -class DMAReadCommandTracker[T <: Data](val nCmds: Int, val maxBytes: Int, tag_t: => T) extends Module { +class DMACommandTracker[T <: Data](val nCmds: Int, val maxBytes: Int, tag_t: => T) extends Module { def cmd_id_t = UInt((log2Ceil(nCmds) max 1).W) val io = IO(new Bundle { @@ -24,12 +24,6 @@ class DMAReadCommandTracker[T <: Data](val nCmds: Int, val maxBytes: Int, tag_t: override def cloneType: this.type = new BitsT(tag_t.cloneType, cmd_id_t.cloneType).asInstanceOf[this.type] } - /*val bits = new Bundle { - val tag = Input(tag_t) - val bytes_to_read = Input(UInt(log2Up(maxBytes+1).W)) - val cmd_id = Output(cmd_id_t) - }*/ - val bits = new BitsT(tag_t.cloneType, cmd_id_t.cloneType) def fire(dummy: Int = 0) = valid && ready @@ -43,11 +37,6 @@ class DMAReadCommandTracker[T <: Data](val nCmds: Int, val maxBytes: Int, tag_t: override def cloneType: this.type = new RequestReturnedT(cmd_id_t.cloneType).asInstanceOf[this.type] } - /*val request_returned = Flipped(Valid(new Bundle { - val bytes_read = UInt(log2Up(maxBytes+1).W) - val cmd_id = cmd_id_t - }))*/ - val request_returned = Flipped(Valid(new RequestReturnedT(cmd_id_t.cloneType))) class CmdCompletedT(cmd_id_t: UInt, tag_t: T) extends Bundle { @@ -57,11 +46,6 @@ class DMAReadCommandTracker[T <: Data](val nCmds: Int, val maxBytes: Int, tag_t: override def cloneType: this.type = new CmdCompletedT(cmd_id_t.cloneType, tag_t.cloneType).asInstanceOf[this.type] } - /*val cmd_completed = Decoupled(new Bundle { - val cmd_id = cmd_id_t - val tag = tag_t - })*/ - val cmd_completed = Decoupled(new CmdCompletedT(cmd_id_t.cloneType, tag_t.cloneType)) val busy = Output(Bool()) diff --git a/src/main/scala/gemmini/DMAWriteCommandTracker.scala b/src/main/scala/gemmini/DMAWriteCommandTracker.scala deleted file mode 100644 index 30765a44..00000000 --- a/src/main/scala/gemmini/DMAWriteCommandTracker.scala +++ /dev/null @@ -1,9 +0,0 @@ -package gemmini - -import chisel3._ -import chisel3.util._ - -object DMAWriteCommandTracker { - def apply[T <: Data](nCmds: Int, nRows: Int, tag_t: => T) = Module(new DMAReadCommandTracker(nCmds = nCmds, - maxBytes = nRows, tag_t = tag_t)) -} diff --git a/src/main/scala/gemmini/DSEConfigs.scala b/src/main/scala/gemmini/DSEConfigs.scala index 51bdc192..540fdac6 100644 --- a/src/main/scala/gemmini/DSEConfigs.scala +++ b/src/main/scala/gemmini/DSEConfigs.scala @@ -13,6 +13,7 @@ import freechips.rocketchip.tile.{BuildRoCC, OpcodeSet} object DSEBaseConfig { val baseConfig = GemminiArrayConfig[SInt, Bool, UInt]( + opcodes = OpcodeSet.custom3, tileRows = 1, tileColumns = 1, meshRows = 16, @@ -20,10 +21,15 @@ object DSEBaseConfig { ld_queue_length = 4, st_queue_length = 2, ex_queue_length = 8, - rob_entries = 8, + rob_full_entries = 8, + rob_partial_entries = 1, + sp_banks = 4, // TODO support one-bank designs acc_banks = 1, + acc_singleported = false, + num_acc_sub_banks = -1, sp_capacity = CapacityInKilobytes(64), + sp_singleported = false, shifter_banks = 1, // TODO add separate parameters for left and up shifter banks dataflow = Dataflow.OS, acc_capacity = CapacityInKilobytes(16), @@ -50,12 +56,17 @@ object DSEBaseConfig { val r = (point_five & (zeros | ones_digit)).asBool() (t >> u).asSInt() + Mux(r, 1.S, 0.S) - }, 0, UInt(8.W)), + }, 0, UInt(8.W), -1), acc_read_full_width = true, acc_read_small_width = true, use_dedicated_tl_port = false, pe_latency = 0, + ex_read_from_spad = true, + ex_read_from_acc = true, + ex_write_to_spad = true, + ex_write_to_acc = true, + tlb_size = 4, use_tlb_register_filter = true, max_in_flight_reqs = 16, @@ -90,7 +101,7 @@ class GemminiParamsDSE1 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.baseConfig)) + LazyModule(new Gemmini(DSEConfigs.baseConfig)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) @@ -102,7 +113,7 @@ class GemminiParamsDSE2 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.wsOnlyConfig)) + LazyModule(new Gemmini(DSEConfigs.wsOnlyConfig)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) @@ -114,7 +125,7 @@ class GemminiParamsDSE3 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.bothDataflowsConfig)) + LazyModule(new Gemmini(DSEConfigs.bothDataflowsConfig)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) @@ -126,7 +137,7 @@ class GemminiParamsDSE4 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.highBitwidthConfig)) + LazyModule(new Gemmini(DSEConfigs.highBitwidthConfig)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) @@ -138,7 +149,7 @@ class GemminiParamsDSE5 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.largerDimConfig)) + LazyModule(new Gemmini(DSEConfigs.largerDimConfig)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) @@ -150,7 +161,7 @@ class GemminiParamsDSE6 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.fullyCombinationalConfig)) + LazyModule(new Gemmini(DSEConfigs.fullyCombinationalConfig)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) @@ -162,7 +173,7 @@ class GemminiParamsDSE7 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.moreMemoryConfig)) + LazyModule(new Gemmini(DSEConfigs.moreMemoryConfig)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) @@ -174,7 +185,7 @@ class GemminiParamsDSE8 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.moreBanksConfig)) + LazyModule(new Gemmini(DSEConfigs.moreBanksConfig)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) @@ -186,7 +197,7 @@ class GemminiParamsDSE10 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.narrowerBusConfig)) + LazyModule(new Gemmini(DSEConfigs.narrowerBusConfig)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 8) @@ -198,7 +209,7 @@ class GemminiParamsPnR16 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.pnr16Config)) + LazyModule(new Gemmini(DSEConfigs.pnr16Config)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) @@ -210,7 +221,7 @@ class GemminiParamsPnR32 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.pnr32Config)) + LazyModule(new Gemmini(DSEConfigs.pnr32Config)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) @@ -222,7 +233,7 @@ class GemminiParamsDSE11 extends Config((site, here, up) => { (p: Parameters) => { implicit val q = p implicit val v = implicitly[ValName] - LazyModule(new Gemmini(OpcodeSet.custom3, DSEConfigs.baseConfig)) + LazyModule(new Gemmini(DSEConfigs.baseConfig)) } ) case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16) diff --git a/src/main/scala/gemmini/ExecuteController.scala b/src/main/scala/gemmini/ExecuteController.scala index c3601d50..d4e2089c 100644 --- a/src/main/scala/gemmini/ExecuteController.scala +++ b/src/main/scala/gemmini/ExecuteController.scala @@ -27,7 +27,15 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In } val acc = new Bundle { - val read = Vec(acc_banks, new AccumulatorReadIO(acc_bank_entries, log2Up(accType.getWidth), Vec(meshColumns, Vec(tileColumns, inputType)), Vec(meshColumns, Vec(tileColumns, accType)), acc_scale_args.multiplicand_t)) + val read_req = Vec(acc_banks, Decoupled(new AccumulatorReadReq( + acc_bank_entries, log2Up(accType.getWidth), acc_scale_args.multiplicand_t + ))) + + val read_resp = Flipped(Vec(acc_banks, Decoupled(new AccumulatorScaleResp( + Vec(meshColumns, Vec(tileColumns, inputType)), + Vec(meshColumns, Vec(tileColumns, accType)) + )))) + // val write = Vec(acc_banks, new AccumulatorWriteIO(acc_bank_entries, Vec(meshColumns, Vec(tileColumns, accType)))) val write = Vec(acc_banks, Decoupled(new AccumulatorWriteReq(acc_bank_entries, Vec(meshColumns, Vec(tileColumns, accType))))) } @@ -131,7 +139,6 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In //val row_turn_counter = RegInit(row_turn) im2col_en := Mux(weight_stride === 0.U, false.B, true.B) - // SRAM addresses of matmul operands val a_address_rs1 = rs1s(a_address_place).asTypeOf(local_addr_t) val b_address_rs2 = rs2s(b_address_place).asTypeOf(local_addr_t) @@ -180,19 +187,22 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In mesh.io.a.valid := false.B mesh.io.b.valid := false.B mesh.io.d.valid := false.B - mesh.io.tag_in.valid := false.B - mesh.io.flush.valid := control_state === flush && !cntl_valid // We want to make sure that the mesh has absorbed all inputs before flushing + mesh.io.req.valid := control_state === flush mesh.io.a.bits := DontCare mesh.io.b.bits := DontCare mesh.io.d.bits := DontCare - mesh.io.tag_in.bits := DontCare - mesh.io.pe_control.propagate := Mux(control_state === flush, in_prop_flush, cntl.prop) - mesh.io.pe_control.dataflow := cntl.dataflow - mesh.io.pe_control.shift := cntl.shift - mesh.io.a_transpose := a_transpose - mesh.io.bd_transpose := bd_transpose - mesh.io.flush.bits := 0.U + mesh.io.req.bits.tag := DontCare + mesh.io.req.bits.tag.cols := cntl.c_cols + mesh.io.req.bits.tag.rows := cntl.c_rows + mesh.io.req.bits.total_rows := block_size.U + mesh.io.req.bits.pe_control.propagate := Mux(control_state === flush, in_prop_flush, cntl.prop) + mesh.io.req.bits.pe_control.dataflow := cntl.dataflow + mesh.io.req.bits.pe_control.shift := cntl.shift + mesh.io.req.bits.a_transpose := cntl.a_transpose + mesh.io.req.bits.bd_transpose := cntl.bd_transpose + mesh.io.req.bits.tag.rob_id := cntl.rob_id + mesh.io.req.bits.flush := Mux(control_state === flush && !cntl_valid, 1.U, 0.U) // We want to make sure that the mesh has absorbed all inputs before flushing // Hazards val raw_hazard_pre = mesh.io.tags_in_progress.map { t => @@ -211,6 +221,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In !is_garbage && (mul_raw_haz || pre_raw_haz) }.reduce(_ || _) + val raw_hazards_are_impossible = !ex_read_from_acc && !ex_write_to_spad // Special case where RAW hazards are impossible + val matmul_in_progress = mesh.io.tags_in_progress.map(_.rob_id.valid).reduce(_ || _) io.busy := cmd.valid(0) || matmul_in_progress @@ -221,7 +233,6 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In val b_fire_counter = Reg(UInt(log2Up(block_size).W)) val d_fire_counter = Reg(UInt(log2Up(block_size).W)) - // These "*_fire_started" variables are only needed for 2x2 systolic arrays val a_fire_started = RegInit(false.B) val d_fire_started = RegInit(false.B) val b_fire_started = RegInit(false.B) @@ -242,15 +253,19 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In val dataBBankAcc = b_address.acc_bank() val dataDBankAcc = d_address.acc_bank() - val a_read_from_acc = a_address_rs1.is_acc_addr - val b_read_from_acc = b_address_rs2.is_acc_addr - val d_read_from_acc = d_address_rs1.is_acc_addr + val a_read_from_acc = ex_read_from_acc.B && a_address_rs1.is_acc_addr + val b_read_from_acc = ex_read_from_acc.B && b_address_rs2.is_acc_addr + val d_read_from_acc = ex_read_from_acc.B && d_address_rs1.is_acc_addr val start_inputting_a = WireInit(false.B) val start_inputting_b = WireInit(false.B) val start_inputting_d = WireInit(false.B) val start_array_outputting = WireInit(false.B) + val a_garbage = a_address_rs1.is_garbage() || !start_inputting_a + val b_garbage = b_address_rs2.is_garbage() || !start_inputting_b + val d_garbage = d_address_rs1.is_garbage() || !start_inputting_d + // TODO merge these into one enum val perform_single_preload = RegInit(false.B) val perform_single_mul = RegInit(false.B) @@ -260,6 +275,25 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In val performing_single_mul = WireInit(perform_single_mul && control_state === compute) val performing_mul_pre = WireInit(perform_mul_pre && control_state === compute) + val total_rows = WireInit(block_size.U) // The total number of rows of A, B, and D to feed into the mesh + + // TODO Also reduce the number of rows when "perform_single_preload === true.B" + when (current_dataflow === Dataflow.WS.id.U && d_garbage && + !a_should_be_fed_into_transposer && !b_should_be_fed_into_transposer && !d_should_be_fed_into_transposer) { + val rows_a = Mux(a_garbage, 1.U, a_rows) + val rows_b = Mux(b_garbage, 1.U, b_rows) + + /* We can only retire one ROB instruction per cycle (max), but if total_rows == 1, then we would be trying to retire + 2 ROB instructions per cycle (one for the preload, and one for the compute). Therefore, to prevent ROB + instructions from being lost, we set a minimum floor for total_rows of 2. + + Furthermore, two writes to the same accumulator address must occur at least 4 cycles apart to allow the write to + fully propagate through. Therefore, we raise the minimum floor for total_rows to 4. + TODO: add a WAW check to the ROB so that we can lower the floor back to 2 + */ + total_rows := maxOf(maxOf(rows_a, rows_b), 4.U) + } + //added for mul_pre sync val mul_pre_counter_sub = RegInit(0.U(3.W)) val mul_pre_counter_count = RegInit(0.U(3.W)) @@ -301,16 +335,9 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In val others = operands.filter(_.priority != priority) val same_banks = others.map(o => same_bank(addr, o.addr, is_garbage, o.is_garbage, start_inputting, o.start_inputting, can_be_im2colled || o.can_be_im2colled)) - val same_counter = others.map(o => counter === o.counter) - - val one_ahead = { - if (block_size > 2) - others.map(o => counter === wrappingAdd(o.counter, 1.U, block_size)) - else { - others.map(o => (started && !o.started && counter === 1.U && o.counter === 0.U) || - (started && o.started && counter === 0.U && o.counter === 1.U)) - } - } + val same_counter = others.map(o => started === o.started && counter === o.counter) + + val one_ahead = others.map(o => started && counter === wrappingAdd(o.counter, 1.U, total_rows)) val higher_priorities = others.map(o => (o.priority < priority).B) @@ -322,9 +349,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In !must_wait_for.reduce(_ || _) } - val a_fire = a_valid && a_ready - dontTouch(a_fire) val b_fire = b_valid && b_ready val d_fire = d_valid && d_ready @@ -334,26 +359,25 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In a_fire_counter := 0.U a_addr_offset := 0.U }.elsewhen (firing && a_fire && cntl_ready) { - a_fire_counter := wrappingAdd(a_fire_counter, 1.U, block_size) - a_addr_offset := Mux(a_fire_counter === (block_size-1).U, 0.U, a_addr_offset + a_addr_stride) + a_fire_counter := wrappingAdd(a_fire_counter, 1.U, total_rows) + a_addr_offset := Mux(a_fire_counter === (total_rows-1.U), 0.U, a_addr_offset + a_addr_stride) a_fire_started := true.B } when (!firing) { b_fire_counter := 0.U }.elsewhen (firing && b_fire && cntl_ready) { - b_fire_counter := wrappingAdd(b_fire_counter, 1.U, block_size) + b_fire_counter := wrappingAdd(b_fire_counter, 1.U, total_rows) b_fire_started := true.B } when (!firing) { d_fire_counter := 0.U }.elsewhen (firing && d_fire && cntl_ready) { - d_fire_counter := wrappingAdd(d_fire_counter, 1.U, block_size) + d_fire_counter := wrappingAdd(d_fire_counter, 1.U, total_rows) d_fire_started := true.B } - when(performing_mul_pre && !cntl_ready && !mul_pre_counter_lock){ mul_pre_counter_count := d_fire_counter //store 2 }.elsewhen(!performing_mul_pre){ @@ -371,18 +395,16 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In // The last line in this (long) Boolean is just to make sure that we don't think we're done as soon as we begin firing // TODO change when square requirement lifted - val about_to_fire_all_rows = ((a_fire_counter === (block_size-1).U && a_valid) || a_fire_counter === 0.U) && - ((b_fire_counter === (block_size-1).U && b_valid) || b_fire_counter === 0.U) && - ((d_fire_counter === (block_size-1).U && d_valid) || d_fire_counter === 0.U) && - (a_fire_counter =/= 0.U || b_fire_counter =/= 0.U || d_fire_counter =/= 0.U) && + val about_to_fire_all_rows = ((a_fire_counter === (total_rows-1.U) && a_fire) || a_fire_counter === 0.U) && + ((b_fire_counter === (total_rows-1.U) && b_fire) || b_fire_counter === 0.U) && + ((d_fire_counter === (total_rows-1.U) && d_fire) || d_fire_counter === 0.U) && + (a_fire_started || b_fire_started || d_fire_started) && cntl_ready - if (block_size == 2) { - when (about_to_fire_all_rows) { - a_fire_started := false.B - b_fire_started := false.B - d_fire_started := false.B - } + when (about_to_fire_all_rows) { + a_fire_started := false.B + b_fire_started := false.B + d_fire_started := false.B } val d_fire_counter_mulpre = WireInit(b_fire_counter) @@ -396,26 +418,32 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In val read_b = b_valid && !b_read_from_acc && dataBbank === i.U && start_inputting_b && !accumulate_zeros && b_row_is_not_all_zeros //&& !im2col_wire val read_d = d_valid && !d_read_from_acc && dataDbank === i.U && start_inputting_d && !preload_zeros && d_row_is_not_all_zeros //&& !im2col_wire - Seq((read_a, a_ready), (read_b, b_ready), (read_d, d_ready)).foreach { case (rd, r) => when (rd && !io.srams.read(i).req.ready) { r := false.B } } - io.srams.read(i).req.valid := read_a || read_b || read_d - io.srams.read(i).req.bits.fromDMA := false.B - io.srams.read(i).req.bits.addr := MuxCase(a_address_rs1.sp_row() + a_fire_counter, - Seq(read_b -> (b_address_rs2.sp_row() + b_fire_counter), - read_d -> (d_address_rs1.sp_row() + block_size.U - 1.U - d_fire_counter_mulpre))) - - when(im2col_en === false.B){ - io.srams.read(i).req.bits.addr := MuxCase(a_address.sp_row(), - Seq(read_b -> b_address.sp_row(), - read_d -> d_address.sp_row())) + if (ex_read_from_spad) { + io.srams.read(i).req.valid := (read_a || read_b || read_d) && cntl_ready + io.srams.read(i).req.bits.fromDMA := false.B + io.srams.read(i).req.bits.addr := MuxCase(a_address_rs1.sp_row() + a_fire_counter, + Seq(read_b -> (b_address_rs2.sp_row() + b_fire_counter), + read_d -> (d_address_rs1.sp_row() + block_size.U - 1.U - d_fire_counter_mulpre))) + + // TODO this just overrides the previous line. Should we erase the previous line? + when(im2col_en === false.B) { + io.srams.read(i).req.bits.addr := MuxCase(a_address.sp_row(), + Seq(read_b -> b_address.sp_row(), + read_d -> d_address.sp_row())) + } + } else { + io.srams.read(i).req.valid := false.B + io.srams.read(i).req.bits.fromDMA := false.B + io.srams.read(i).req.bits.addr := DontCare } - io.srams.read(i).resp.ready := true.B + io.srams.read(i).resp.ready := false.B } // Accumulator read @@ -425,45 +453,39 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In val read_d_from_acc = d_valid && d_read_from_acc && dataDBankAcc === i.U && start_inputting_d && !preload_zeros && d_row_is_not_all_zeros //&& !im2col_wire Seq((read_a_from_acc, a_ready), (read_b_from_acc, b_ready), (read_d_from_acc, d_ready)).foreach { case (rd, r) => - when(rd && !io.acc.read(i).req.ready) { + when(rd && !io.acc.read_req(i).ready) { r := false.B } } - /* - io.acc.read(i).req.valid := read_a_from_acc || read_b_from_acc || read_d_from_acc - io.acc.read(i).req.bits.scale := acc_scale - io.acc.read(i).req.bits.full := false.B - io.acc.read(i).req.bits.relu6_shift := relu6_shift - io.acc.read(i).req.bits.act := activation - io.acc.read(i).req.bits.fromDMA := false.B - io.acc.read(i).req.bits.addr := MuxCase(a_address_rs1.acc_row() + a_fire_counter, - Seq(read_b_from_acc -> (b_address_rs2.acc_row() + b_fire_counter), - read_d_from_acc -> (d_address_rs1.acc_row() + block_size.U - 1.U - d_fire_counter))) - - when(im2col_en === false.B){ - io.acc.read(i).req.bits.addr := MuxCase(a_address.acc_row(), - Seq(read_b_from_acc -> b_address.acc_row(), - read_d_from_acc -> d_address.acc_row())) - } - */ - - // TODO Remove the ability to read into Mesh from AccumulatorMem completely - io.acc.read(i).req.valid := false.B - io.acc.read(i).req.bits.scale := acc_scale - io.acc.read(i).req.bits.full := false.B - io.acc.read(i).req.bits.relu6_shift := relu6_shift - io.acc.read(i).req.bits.act := activation - io.acc.read(i).req.bits.fromDMA := false.B - io.acc.read(i).req.bits.addr := DontCare - - when(im2col_en === false.B){ - io.acc.read(i).req.bits.addr := MuxCase(a_address.acc_row(), - Seq(read_b_from_acc -> b_address.acc_row(), - read_d_from_acc -> d_address.acc_row())) + if (ex_read_from_acc) { + io.acc.read_req(i).valid := read_a_from_acc || read_b_from_acc || read_d_from_acc + io.acc.read_req(i).bits.scale := acc_scale + io.acc.read_req(i).bits.full := false.B + io.acc.read_req(i).bits.relu6_shift := relu6_shift + io.acc.read_req(i).bits.act := activation + io.acc.read_req(i).bits.fromDMA := false.B + io.acc.read_req(i).bits.addr := MuxCase(a_address_rs1.acc_row() + a_fire_counter, + Seq(read_b_from_acc -> (b_address_rs2.acc_row() + b_fire_counter), + read_d_from_acc -> (d_address_rs1.acc_row() + block_size.U - 1.U - d_fire_counter))) + + // TODO this just overrides the previous line. Should we erase the previous line? + when(im2col_en === false.B){ + io.acc.read_req(i).bits.addr := MuxCase(a_address.acc_row(), + Seq(read_b_from_acc -> b_address.acc_row(), + read_d_from_acc -> d_address.acc_row())) + } + } else { + io.acc.read_req(i).valid := false.B + io.acc.read_req(i).bits.scale := acc_scale + io.acc.read_req(i).bits.full := false.B + io.acc.read_req(i).bits.relu6_shift := relu6_shift + io.acc.read_req(i).bits.act := activation + io.acc.read_req(i).bits.fromDMA := false.B + io.acc.read_req(i).bits.addr := DontCare } - io.acc.read(i).resp.ready := true.B + io.acc.read_resp(i).ready := false.B } // Im2Col reads @@ -495,7 +517,6 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In io.im2col.resp.ready := mesh.io.a.ready } - // FSM logic switch (control_state) { is(waiting_for_cmd) { @@ -514,7 +535,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In in_shift := rs2s(0)(31, 0) // TODO magic number acc_scale := rs1s(0)(xLen - 1, 32).asTypeOf(acc_scale_args.multiplicand_t) // TODO magic number relu6_shift := rs2s(0)(xLen - 1, 32) // TODO magic number - a_addr_stride := rs1s(0)(31, 16) // TODO magic number + a_addr_stride := rs1s(0)(31, 16) // TODO magic number // TODO this needs to be kept in sync with ROB.scala a_transpose := rs1s(0)(8) bd_transpose := rs1s(0)(9) @@ -538,55 +559,51 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In io.completed := cmd.bits(0).rob_id cmd.pop := 1.U - } - // Preload - .elsewhen(DoPreloads(0) && cmd.valid(1) && !raw_hazard_pre) { - perform_single_preload := true.B - performing_single_preload := true.B + // Preload + .elsewhen(DoPreloads(0) && cmd.valid(1) && (raw_hazards_are_impossible.B || !raw_hazard_pre)) { + perform_single_preload := true.B + performing_single_preload := true.B - //start_inputting_a := current_dataflow === Dataflow.OS.id.U - //start_inputting_d := true.B + //start_inputting_a := current_dataflow === Dataflow.OS.id.U + //start_inputting_d := true.B - start_inputting_a := a_should_be_fed_into_transposer - start_inputting_b := b_should_be_fed_into_transposer - start_inputting_d := true.B - - control_state := compute - } + start_inputting_a := a_should_be_fed_into_transposer + start_inputting_b := b_should_be_fed_into_transposer + start_inputting_d := true.B - // Overlap compute and preload - .elsewhen(DoComputes(0) && cmd.valid(1) && DoPreloads(1) && cmd.valid(2) && !raw_hazard_mulpre) { - perform_mul_pre := true.B - performing_mul_pre := true.B + control_state := compute + } - start_inputting_a := true.B - start_inputting_b := true.B - start_inputting_d := true.B + // Overlap compute and preload + .elsewhen(DoComputes(0) && cmd.valid(1) && DoPreloads(1) && (raw_hazards_are_impossible.B || (cmd.valid(2) && !raw_hazard_mulpre))) { + perform_mul_pre := true.B + performing_mul_pre := true.B - control_state := compute - } + start_inputting_a := true.B + start_inputting_b := true.B + start_inputting_d := true.B - // Single mul - .elsewhen(DoComputes(0)) { - perform_single_mul := true.B - performing_single_mul := true.B + control_state := compute + } - //start_inputting_a := current_dataflow === Dataflow.WS.id.U - //start_inputting_b := true.B + // Single mul + .elsewhen(DoComputes(0)) { + perform_single_mul := true.B + performing_single_mul := true.B - start_inputting_a := !a_should_be_fed_into_transposer - start_inputting_b := !b_should_be_fed_into_transposer - start_inputting_b := true.B + start_inputting_a := !a_should_be_fed_into_transposer + start_inputting_b := !b_should_be_fed_into_transposer + start_inputting_b := true.B - control_state := compute - } + control_state := compute + } - // Flush - .elsewhen(matmul_in_progress) { - control_state := flush - } - }.elsewhen(matmul_in_progress) { + // Flush + .elsewhen(matmul_in_progress && (current_dataflow === Dataflow.OS.id.U || DoConfig)) { + control_state := flush + } + }.elsewhen(matmul_in_progress && current_dataflow === Dataflow.OS.id.U) { // TODO code duplication control_state := flush } @@ -610,51 +627,49 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In } } } - // Overlapping - .elsewhen(perform_mul_pre) { - start_inputting_a := true.B - start_inputting_b := true.B - start_inputting_d := true.B + // Overlapping + .elsewhen(perform_mul_pre) { + start_inputting_a := true.B + start_inputting_b := true.B + start_inputting_d := true.B - when(about_to_fire_all_rows) { - cmd.pop := 2.U - control_state := waiting_for_cmd + when(about_to_fire_all_rows) { + cmd.pop := 2.U + control_state := waiting_for_cmd - pending_completed_rob_ids(0) := cmd.bits(0).rob_id - pending_completed_rob_ids(1).valid := cmd.bits(1).rob_id.valid && c_address_rs2.is_garbage() - pending_completed_rob_ids(1).bits := cmd.bits(1).rob_id.bits + pending_completed_rob_ids(0) := cmd.bits(0).rob_id + pending_completed_rob_ids(1).valid := cmd.bits(1).rob_id.valid && c_address_rs2.is_garbage() + pending_completed_rob_ids(1).bits := cmd.bits(1).rob_id.bits - when(current_dataflow === Dataflow.OS.id.U) { - in_prop_flush := !rs2s(1).asTypeOf(local_addr_t).is_garbage() - } + when(current_dataflow === Dataflow.OS.id.U) { + in_prop_flush := !rs2s(1).asTypeOf(local_addr_t).is_garbage() } } - // Only compute - .elsewhen(perform_single_mul) { - start_inputting_a := !a_should_be_fed_into_transposer - start_inputting_b := !b_should_be_fed_into_transposer - - when(about_to_fire_all_rows) { - cmd.pop := 1.U - control_state := waiting_for_cmd + } + // Only compute + .elsewhen(perform_single_mul) { + start_inputting_a := !a_should_be_fed_into_transposer + start_inputting_b := !b_should_be_fed_into_transposer - pending_completed_rob_ids(0) := cmd.bits(0).rob_id - } - } - } - is(flush) { - when(mesh.io.flush.fire()) { - control_state := flushing - } - } - is(flushing) { - when(mesh.io.flush.ready) { - // TODO we waste a cycle here if it was better to continue with the flush - control_state := waiting_for_cmd - } - } + when(about_to_fire_all_rows) { + cmd.pop := 1.U + control_state := waiting_for_cmd + pending_completed_rob_ids(0) := cmd.bits(0).rob_id } - + } + } + is(flush) { + when(mesh.io.req.fire()) { + control_state := flushing + } + } + is(flushing) { + when(mesh.io.req.ready) { + // TODO we waste a cycle here if it was better to continue with the flush + control_state := waiting_for_cmd + } + } + } // Computing logic val computing = performing_mul_pre || performing_single_mul || performing_single_preload @@ -695,6 +710,11 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In val c_rows = UInt(log2Up(block_size + 1).W) val c_cols = UInt(log2Up(block_size + 1).W) + val a_transpose = Bool() + val bd_transpose = Bool() + + val total_rows = UInt(log2Up(block_size + 1).W) + val rob_id = UDValid(UInt(log2Up(rob_entries).W)) val dataflow = UInt(1.W) @@ -702,6 +722,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In val shift = UInt(log2Up(accType.getWidth).W) val im2colling = Bool() + + val first = Bool() } mesh_cntl_signals_q.io.enq.valid := computing @@ -718,9 +740,9 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In mesh_cntl_signals_q.io.enq.bits.b_bank_acc := dataBBankAcc mesh_cntl_signals_q.io.enq.bits.d_bank_acc := dataDBankAcc - mesh_cntl_signals_q.io.enq.bits.a_garbage := a_address_rs1.is_garbage() || !start_inputting_a - mesh_cntl_signals_q.io.enq.bits.b_garbage := b_address_rs2.is_garbage() || !start_inputting_b - mesh_cntl_signals_q.io.enq.bits.d_garbage := d_address_rs1.is_garbage() || !start_inputting_d + mesh_cntl_signals_q.io.enq.bits.a_garbage := a_garbage + mesh_cntl_signals_q.io.enq.bits.b_garbage := b_garbage + mesh_cntl_signals_q.io.enq.bits.d_garbage := d_garbage mesh_cntl_signals_q.io.enq.bits.a_read_from_acc := a_read_from_acc mesh_cntl_signals_q.io.enq.bits.b_read_from_acc := b_read_from_acc @@ -733,6 +755,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In mesh_cntl_signals_q.io.enq.bits.b_unpadded_cols := Mux(b_row_is_not_all_zeros, b_cols, 0.U) mesh_cntl_signals_q.io.enq.bits.d_unpadded_cols := Mux(d_row_is_not_all_zeros, d_cols, 0.U) + mesh_cntl_signals_q.io.enq.bits.total_rows := total_rows + mesh_cntl_signals_q.io.enq.bits.a_fire := a_fire mesh_cntl_signals_q.io.enq.bits.b_fire := b_fire mesh_cntl_signals_q.io.enq.bits.d_fire := d_fire @@ -741,6 +765,9 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In mesh_cntl_signals_q.io.enq.bits.c_rows := c_rows mesh_cntl_signals_q.io.enq.bits.c_cols := c_cols + mesh_cntl_signals_q.io.enq.bits.a_transpose := a_transpose + mesh_cntl_signals_q.io.enq.bits.bd_transpose := bd_transpose + mesh_cntl_signals_q.io.enq.bits.rob_id.valid := !performing_single_mul && !c_address_rs2.is_garbage() mesh_cntl_signals_q.io.enq.bits.rob_id.bits := cmd.bits(preload_cmd_place).rob_id.bits @@ -750,17 +777,20 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In mesh_cntl_signals_q.io.enq.bits.im2colling := im2col_wire && im2col_en //im2col_wire + mesh_cntl_signals_q.io.enq.bits.first := !a_fire_started && !b_fire_started && !d_fire_started + val readData = VecInit(io.srams.read.map(_.resp.bits.data)) - val accReadData = readData // VecInit(io.acc.read.map(_.resp.bits.data.asUInt())) // TODO remove ability to read from AccumulatorMem + val accReadData = if (ex_read_from_acc) VecInit(io.acc.read_resp.map(_.bits.data.asUInt())) else readData val im2ColData = io.im2col.resp.bits.a_im2col.asUInt() - val readValid = VecInit(io.srams.read.map(bank => bank.resp.valid && !bank.resp.bits.fromDMA)) - val accReadValid = false.B // VecInit(io.acc.read.map(bank => bank.resp.valid && !bank.resp.bits.fromDMA)) // TODO remove ability to read from AccumulatorMem + val readValid = VecInit(io.srams.read.map(bank => ex_read_from_spad.B && bank.resp.valid && !bank.resp.bits.fromDMA)) + val accReadValid = VecInit(io.acc.read_resp.map(bank => ex_read_from_acc.B && bank.valid && !bank.bits.fromDMA)) val im2ColValid = io.im2col.resp.valid mesh_cntl_signals_q.io.deq.ready := (!cntl.a_fire || mesh.io.a.fire() || !mesh.io.a.ready) && (!cntl.b_fire || mesh.io.b.fire() || !mesh.io.b.ready) && - (!cntl.d_fire || mesh.io.d.fire() || !mesh.io.d.ready) + (!cntl.d_fire || mesh.io.d.fire() || !mesh.io.d.ready) && + (!cntl.first || mesh.io.req.ready) val dataA_valid = cntl.a_garbage || cntl.a_unpadded_cols === 0.U || Mux(cntl.im2colling, im2ColValid, Mux(cntl.a_read_from_acc, accReadValid(cntl.a_bank_acc), readValid(cntl.a_bank))) @@ -786,59 +816,92 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In val dataB = VecInit(dataB_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.b_unpadded_cols, d, inputType.zero)}) val dataD = VecInit(dataD_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.d_unpadded_cols, d, inputType.zero)}) + // Pop responses off the scratchpad io ports + when (mesh_cntl_signals_q.io.deq.fire()) { + when (cntl.a_fire && mesh.io.a.fire() && !cntl.a_garbage && cntl.a_unpadded_cols > 0.U && !cntl.im2colling) { + when (cntl.a_read_from_acc) { + io.acc.read_resp(cntl.a_bank_acc).ready := !io.acc.read_resp(cntl.a_bank_acc).bits.fromDMA + }.otherwise { + io.srams.read(cntl.a_bank).resp.ready := !io.srams.read(cntl.a_bank).resp.bits.fromDMA + } + } + + when (cntl.b_fire && mesh.io.b.fire() && !cntl.b_garbage && !cntl.accumulate_zeros && cntl.b_unpadded_cols > 0.U) { + when (cntl.b_read_from_acc) { + io.acc.read_resp(cntl.b_bank_acc).ready := !io.acc.read_resp(cntl.b_bank_acc).bits.fromDMA + }.otherwise { + io.srams.read(cntl.b_bank).resp.ready := !io.srams.read(cntl.b_bank).resp.bits.fromDMA + } + } + + when (cntl.d_fire && mesh.io.d.fire() && !cntl.d_garbage && !cntl.preload_zeros && cntl.d_unpadded_cols > 0.U) { + when (cntl.d_read_from_acc) { + io.acc.read_resp(cntl.d_bank_acc).ready := !io.acc.read_resp(cntl.d_bank_acc).bits.fromDMA + }.otherwise { + io.srams.read(cntl.d_bank).resp.ready := !io.srams.read(cntl.d_bank).resp.bits.fromDMA + } + } + } + + if (!ex_read_from_acc) { + for (acc_r <- io.acc.read_resp) { + acc_r.ready := true.B + } + } + when (cntl_valid) { // Default inputs mesh.io.a.valid := cntl.a_fire && dataA_valid - mesh.io.b.valid := (cntl.b_fire && dataB_valid) - mesh.io.d.valid := (cntl.d_fire && dataD_valid) - mesh.io.tag_in.valid := true.B + mesh.io.b.valid := cntl.b_fire && dataB_valid + mesh.io.d.valid := cntl.d_fire && dataD_valid mesh.io.a.bits := dataA.asTypeOf(Vec(meshRows, Vec(tileRows, inputType))) mesh.io.b.bits := dataB.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType))) mesh.io.d.bits := dataD.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType))) - mesh.io.tag_in.bits.rob_id := cntl.rob_id - mesh.io.tag_in.bits.addr := cntl.c_addr - mesh.io.tag_in.bits.cols := cntl.c_cols - mesh.io.tag_in.bits.rows := cntl.c_rows + mesh.io.req.valid := mesh_cntl_signals_q.io.deq.fire() && (cntl.a_fire || cntl.b_fire || cntl.d_fire) + + mesh.io.req.bits.tag.addr := cntl.c_addr + + mesh.io.req.bits.total_rows := cntl.total_rows } when (cntl_valid && cntl.perform_single_preload) { - // mesh.io.a.bits := Mux(cntl.dataflow === Dataflow.WS.id.U, 0.U, dataA.asUInt).asTypeOf(Vec(meshRows, Vec(tileRows, inputType))) mesh.io.a.bits := Mux(a_should_be_fed_into_transposer, dataA.asUInt, 0.U).asTypeOf(Vec(meshRows, Vec(tileRows, inputType))) - // mesh.io.b.bits := 0.U.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType))) mesh.io.b.bits := Mux(b_should_be_fed_into_transposer, dataB.asUInt, 0.U).asTypeOf(Vec(meshRows, Vec(tileRows, inputType))) } when (cntl_valid && cntl.perform_single_mul) { - // mesh.io.a.bits := Mux(cntl.dataflow === Dataflow.OS.id.U, 0.U, dataA.asUInt).asTypeOf(Vec(meshRows, Vec(tileRows, inputType))) mesh.io.a.bits := Mux(a_should_be_fed_into_transposer, 0.U, dataA.asUInt).asTypeOf(Vec(meshRows, Vec(tileRows, inputType))) mesh.io.b.bits := Mux(b_should_be_fed_into_transposer, 0.U, dataB.asUInt).asTypeOf(Vec(meshRows, Vec(tileRows, inputType))) - mesh.io.tag_in.bits.addr.make_this_garbage() + mesh.io.req.bits.tag.addr.make_this_garbage() } // Scratchpad writes - val output_counter = new Counter(block_size) + // val output_counter = new Counter(block_size) + val output_counter = RegInit(0.U(log2Up(block_size).W)) - val w_address = Mux(current_dataflow === Dataflow.WS.id.U, mesh.io.tag_out.addr + output_counter.value, - mesh.io.tag_out.addr + (block_size.U - 1.U - output_counter.value)) + val w_total_output_rows = mesh.io.resp.bits.total_rows + + val w_address = Mux(current_dataflow === Dataflow.WS.id.U, mesh.io.resp.bits.tag.addr + output_counter, + mesh.io.resp.bits.tag.addr + (w_total_output_rows - 1.U - output_counter)) val write_to_acc = w_address.is_acc_addr val w_bank = Mux(write_to_acc, w_address.acc_bank(), w_address.sp_bank()) val w_row = Mux(write_to_acc, w_address.acc_row(), w_address.sp_row()) - val is_garbage_addr = mesh.io.tag_out.addr.is_garbage() + val is_garbage_addr = mesh.io.resp.bits.tag.addr.is_garbage() - val w_matrix_rows = mesh.io.tag_out.rows - val w_matrix_cols = mesh.io.tag_out.cols + val w_matrix_rows = mesh.io.resp.bits.tag.rows + val w_matrix_cols = mesh.io.resp.bits.tag.cols - val write_this_row = Mux(current_dataflow === Dataflow.WS.id.U, output_counter.value < w_matrix_rows, - block_size.U - 1.U - output_counter.value < w_matrix_rows) + val write_this_row = Mux(current_dataflow === Dataflow.WS.id.U, output_counter < w_matrix_rows, + w_total_output_rows - 1.U - output_counter < w_matrix_rows) val w_mask = (0 until block_size).map(_.U < w_matrix_cols) // This is an element-wise mask, rather than a byte-wise mask // Write to normal scratchpad for(i <- 0 until sp_banks) { - val activated_wdata = VecInit(mesh.io.out.bits.map(v => VecInit(v.map { e => + val activated_wdata = VecInit(mesh.io.resp.bits.data.map(v => VecInit(v.map { e => val e_clipped = e.clippedToWidthOf(inputType) val e_act = MuxCase(e_clipped, Seq( (activation === Activation.RELU) -> e_clipped.relu, @@ -847,20 +910,34 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In e_act }))) - io.srams.write(i).en := start_array_outputting && w_bank === i.U && !write_to_acc && !is_garbage_addr && write_this_row - io.srams.write(i).addr := w_row - io.srams.write(i).data := activated_wdata.asUInt() - // io.srams.write(i).mask := VecInit(Seq.fill(io.srams.write(0).mask.length)(true.B)) - io.srams.write(i).mask := w_mask.flatMap(b => Seq.fill(inputType.getWidth / (aligned_to * 8))(b)) + if (ex_write_to_spad) { + io.srams.write(i).en := start_array_outputting && w_bank === i.U && !write_to_acc && !is_garbage_addr && write_this_row + io.srams.write(i).addr := w_row + io.srams.write(i).data := activated_wdata.asUInt() + io.srams.write(i).mask := w_mask.flatMap(b => Seq.fill(inputType.getWidth / (aligned_to * 8))(b)) + } else { + io.srams.write(i).en := false.B + io.srams.write(i).addr := DontCare + io.srams.write(i).data := DontCare + io.srams.write(i).mask := DontCare + } } // Write to accumulator for (i <- 0 until acc_banks) { - io.acc.write(i).valid := start_array_outputting && w_bank === i.U && write_to_acc && !is_garbage_addr && write_this_row - io.acc.write(i).bits.addr := w_row - io.acc.write(i).bits.data := VecInit(mesh.io.out.bits.map(v => VecInit(v.map(e => e.withWidthOf(accType))))) - io.acc.write(i).bits.acc := w_address.accumulate - io.acc.write(i).bits.mask := w_mask.flatMap(b => Seq.fill(accType.getWidth / (aligned_to * 8))(b)) + if (ex_write_to_acc) { + io.acc.write(i).valid := start_array_outputting && w_bank === i.U && write_to_acc && !is_garbage_addr && write_this_row + io.acc.write(i).bits.addr := w_row + io.acc.write(i).bits.data := VecInit(mesh.io.resp.bits.data.map(v => VecInit(v.map(e => e.withWidthOf(accType))))) + io.acc.write(i).bits.acc := w_address.accumulate + io.acc.write(i).bits.mask := w_mask.flatMap(b => Seq.fill(accType.getWidth / (aligned_to * 8))(b)) + } else { + io.acc.write(i).valid := false.B + io.acc.write(i).bits.addr := DontCare + io.acc.write(i).bits.data := DontCare + io.acc.write(i).bits.acc := DontCare + io.acc.write(i).bits.mask := DontCare + } assert(!(io.acc.write(i).valid && !io.acc.write(i).ready), "Execute controller write to AccumulatorMem was skipped") } @@ -870,13 +947,14 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In //val complete_lock = RegInit(false.B) //Seah: added for WS accumulator - when(mesh.io.out.fire() && mesh.io.tag_out.rob_id.valid) { - //when(current_dataflow === Dataflow.WS.id.U) { - when(output_counter.inc()) { + when(mesh.io.resp.fire() && mesh.io.resp.bits.tag.rob_id.valid) { + output_counter := wrappingAdd(output_counter, 1.U, w_total_output_rows) + val last = mesh.io.resp.bits.last + + when(last) { mesh_completed_rob_id_fire := true.B io.completed.valid := true.B - io.completed.bits := mesh.io.tag_out.rob_id.bits - + io.completed.bits := mesh.io.resp.bits.tag.rob_id.bits } start_array_outputting := !is_garbage_addr } @@ -891,7 +969,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In } } val complete_bits_count = RegInit(0.U(15.W)) - when(io.completed.valid){ + when(io.completed.valid) { complete_bits_count := complete_bits_count + 1.U } dontTouch(complete_bits_count) diff --git a/src/main/scala/gemmini/FrontendTLB.scala b/src/main/scala/gemmini/FrontendTLB.scala index 0980d811..73416816 100644 --- a/src/main/scala/gemmini/FrontendTLB.scala +++ b/src/main/scala/gemmini/FrontendTLB.scala @@ -96,7 +96,6 @@ class FrontendTLB(nClients: Int, entries: Int, maxSize: Int) val l0_tlb_hit = last_translated_valid && ((client.req.bits.tlb_req.vaddr >> pgIdxBits) === (last_translated_vpn >> pgIdxBits)) val l0_tlb_paddr = Cat(last_translated_ppn >> pgIdxBits, client.req.bits.tlb_req.vaddr(pgIdxBits-1,0)) - when (req.fire() && !tlb.io.resp.miss) { last_translated_valid := true.B last_translated_vpn := req.bits.tlb_req.vaddr diff --git a/src/main/scala/gemmini/GemminiConfigs.scala b/src/main/scala/gemmini/GemminiConfigs.scala index fb9026d4..6ee0c168 100644 --- a/src/main/scala/gemmini/GemminiConfigs.scala +++ b/src/main/scala/gemmini/GemminiConfigs.scala @@ -4,15 +4,18 @@ package gemmini import scala.math.{pow,sqrt} import chisel3._ import chisel3.util._ +import freechips.rocketchip.tile._ sealed abstract trait GemminiMemCapacity case class CapacityInKilobytes(kilobytes: Int) extends GemminiMemCapacity case class CapacityInMatrices(matrices: Int) extends GemminiMemCapacity case class ScaleArguments[T <: Data, U <: Data](scale_func: (T, U) => T, latency: Int, multiplicand_t: U, + num_scale_units: Int, identity: String="0", c_str: String="ROUNDING_RIGHT_SHIFT(x, scale)") case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( + opcodes: OpcodeSet, tileRows: Int, tileColumns: Int, meshRows: Int, @@ -20,10 +23,14 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( ld_queue_length: Int, st_queue_length: Int, ex_queue_length: Int, - rob_entries: Int, + rob_full_entries: Int, + rob_partial_entries: Int, sp_banks: Int, // TODO support one-bank designs + sp_singleported: Boolean, sp_capacity: GemminiMemCapacity, acc_banks: Int, + acc_singleported: Boolean, + num_acc_sub_banks: Int, acc_capacity: GemminiMemCapacity, shifter_banks: Int, dataflow: Dataflow.Value, @@ -50,6 +57,11 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( use_tlb_register_filter: Boolean, max_in_flight_reqs: Int, + ex_read_from_spad: Boolean, + ex_read_from_acc: Boolean, + ex_write_to_spad: Boolean, + ex_write_to_acc: Boolean, + headerFileName: String = "gemmini_params.h" ) { val sp_width = meshColumns * tileColumns * inputType.getWidth @@ -61,16 +73,17 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( case CapacityInKilobytes(kb) => kb * 1024 * 8 / (acc_banks * meshColumns * tileColumns * accType.getWidth) case CapacityInMatrices(ms) => ms * meshRows * tileRows / acc_banks } + require (!acc_singleported || (num_acc_sub_banks <= 4 && isPow2(num_acc_sub_banks))) val local_addr_t = new LocalAddr(sp_banks, sp_bank_entries, acc_banks, acc_bank_entries) val mvin_scale_t = mvin_scale_args match { - case Some(ScaleArguments(_, _, t, _, _)) => t + case Some(ScaleArguments(_, _, t, _, _, _)) => t case None => Bool() // TODO replace this with UInt(0.W) } val mvin_scale_acc_t = mvin_scale_acc_args match { - case Some(ScaleArguments(_, _, t, _, _)) => t + case Some(ScaleArguments(_, _, t, _, _, _)) => t case None => Bool() // TODO replace this with UInt(0.W) } @@ -81,13 +94,14 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( val acc_scale_t_bits = acc_scale_t.getWidth - // val max_in_flight_reqs = 16 // TODO calculate this somehow - - val mvin_len_bits = log2Up(((dma_maxbytes / (inputType.getWidth / 8)) max (meshColumns * tileColumns)) + 1) - val mvin_rows_bits = 16 // log2Up(meshRows * tileRows + 1) - val mvout_len_bits = log2Up(meshColumns * tileColumns + 1) + val mvin_cols_bits = log2Up(((dma_maxbytes / (inputType.getWidth / 8)) max (meshColumns * tileColumns)) + 1) + val mvin_rows_bits = log2Up(meshRows * tileRows + 1) + val mvout_cols_bits = log2Up(((dma_maxbytes / (inputType.getWidth / 8)) max (meshColumns * tileColumns)) + 1) val mvout_rows_bits = log2Up(meshRows * tileRows + 1) + val load_states = 3 + val block_stride_bits = 16 + //========================================================================== // sanity check mesh size //========================================================================== @@ -103,6 +117,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( //========================================================================== // cisc-gemmini miscellaneous constants (some redundant with above) //========================================================================== + val rob_entries = rob_full_entries + rob_partial_entries val ROB_ENTRIES = rob_entries val LOG2_ROB_ENTRIES = log2Up(rob_entries) @@ -110,13 +125,15 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( // cisc-gemmini hardware-specific compile-time global constants //========================================================================== + val cisc_dim = (meshRows * tileRows) / 2 + val ITYPE_BITS = inputType.getWidth - val ITYPE_BYTES = (inputType.getWidth+7) / 8 + val ITYPE_BYTES = (inputType.getWidth+cisc_dim-1) / cisc_dim val LOG2_ITYPE_BYTES = if(ITYPE_BYTES <= 1) 0 else log2Up(ITYPE_BYTES) val OTYPE_BITS = accType.getWidth val LOG2_OTYPE_BITS = log2Up(OTYPE_BITS) - val OTYPE_BYTES = (accType.getWidth+7) / 8 + val OTYPE_BYTES = (accType.getWidth+cisc_dim-1) / cisc_dim val LOG2_OTYPE_BYTES = if(OTYPE_BYTES <= 1) 0 else log2Up(OTYPE_BYTES) val SP_BANKS = sp_banks @@ -133,12 +150,12 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( val LOG2_MNK_BYTES = log2Up(MNK_BYTES) val MNK_BYTES_PER_TILE_ROW = MNK_BYTES * DIM val LOG2_MNK_BYTES_PER_TILE_ROW = log2Up(MNK_BYTES_PER_TILE_ROW) - val TILE_IDX = MNK_BYTES / (DIM / 8) + val TILE_IDX = MNK_BYTES / (DIM / cisc_dim) val LOG2_TILE_IDX = log2Up(TILE_IDX) //-------------------------------------------------------------------------- - val I_TILE_BYTE_WIDTH = DIM * ((inputType.getWidth+7) / 8) - val O_TILE_BYTE_WIDTH = DIM * ((accType.getWidth+7) / 8) + val I_TILE_BYTE_WIDTH = DIM * ((inputType.getWidth+cisc_dim-1) / cisc_dim) + val O_TILE_BYTE_WIDTH = DIM * ((accType.getWidth+cisc_dim-1) / cisc_dim) val I_TILE_BYTE_WIDTH_LOG2 = log2Up(I_TILE_BYTE_WIDTH) val O_TILE_BYTE_WIDTH_LOG2 = log2Up(O_TILE_BYTE_WIDTH) require(pow(2,I_TILE_BYTE_WIDTH_LOG2) == I_TILE_BYTE_WIDTH, @@ -187,7 +204,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( (dt.expWidth, dt.sigWidth) match { case (8, 24) => (scala.Float.MinValue.toString, scala.Float.MaxValue.toString) case (11, 53) => (scala.Double.MinValue.toString, scala.Double.MaxValue.toString) - case _ => throw new IllegalArgumentException(s"Only single- and double-precision IEEE754 floating point types are currently supported") + case _ => (((Range(-1,-(dt.sigWidth),-1).map(-Math.pow(2, _)).foldLeft(-1.0)(_ + _)) * Math.pow(2, Math.pow(2, dt.expWidth - 1) - 1)).toString, ((Range(-1,-(dt.sigWidth),-1).map(Math.pow(2, _)).foldLeft(1.0)(_ + _)) * Math.pow(2, Math.pow(2, dt.expWidth - 1) - 1)).toString) } case _ => throw new IllegalArgumentException(s"Data type $dataType is unknown") } @@ -201,7 +218,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( (dt.expWidth, dt.sigWidth) match { case (8, 24) => "float" case (11, 53) => "double" - case _ => throw new IllegalArgumentException(s"Only single- and double-precision IEEE754 floating point types are currently supported") + case _ => s"uint" + (Math.pow(2, Math.ceil(Math.log(dt.expWidth + dt.sigWidth)/Math.log(2.0)))).toInt.toString + s"_t" } case _ => throw new IllegalArgumentException(s"Data type $dataType is unknown") } @@ -221,7 +238,6 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( // assert(Set(8, 16, 32, 64).contains(outputType.getWidth)) assert(Set(8, 16, 32, 64).contains(accType.getWidth)) - assert(acc_scale_args.latency == 0, "Accumulator's scale latency must be 0 cycles") val header = new StringBuilder() header ++= s"#ifndef $guard\n" @@ -230,6 +246,13 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( header ++= s"#include \n" header ++= s"#include \n\n" + val opcodeid = Seq( + OpcodeSet.custom0, OpcodeSet.custom1, OpcodeSet.custom2, OpcodeSet.custom3 + ).indexWhere(o => o.opcodes(0).litValue == opcodes.opcodes(0).litValue) + println(opcodeid, opcodes.opcodes) + require (opcodeid != -1 && opcodes.opcodes.size == 1) + header ++= s"#define XCUSTOM_ACC $opcodeid\n" + header ++= s"#define DIM ${tileColumns*meshColumns}\n" header ++= s"#define ADDR_LEN 32\n" header ++= s"#define BANK_NUM $sp_banks\n" @@ -254,8 +277,15 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( // Datatype of the systolic array val limits = limitsOfDataType(inputType) header ++= s"typedef ${c_type(inputType)} elem_t;\n" - header ++= s"static const elem_t elem_t_max = ${limits._2};\n" - header ++= s"static const elem_t elem_t_min = ${limits._1};\n" + if (inputType.isInstanceOf[Float] && !((inputType.asInstanceOf[Float].expWidth, inputType.asInstanceOf[Float].sigWidth) == (8, 24) || (inputType.asInstanceOf[Float].expWidth, inputType.asInstanceOf[Float].sigWidth) == (11, 53))) + { + header ++= "#define ELEM_T_IS_LOWPREC_FLOAT\n" + header ++= s"static const float elem_t_max = ${limits._2};\n" + header ++= s"static const float elem_t_min = ${limits._1};\n" + } else { + header ++= s"static const elem_t elem_t_max = ${limits._2};\n" + header ++= s"static const elem_t elem_t_min = ${limits._1};\n" + } header ++= s"typedef ${c_type(accType)} acc_t;\n" header ++= s"typedef ${full_c_type(inputType)} full_t;\n\n" @@ -296,7 +326,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( header ++= s"#define row_align_acc(blocks) __attribute__((aligned(blocks*DIM*sizeof(acc_t))))\n\n" val mvin_scale_identity = mvin_scale_args match { - case Some(ScaleArguments(_, _, _, identity, _)) => identity + case Some(ScaleArguments(_, _, _, _, identity, _)) => identity case None => "0" } header ++= s"#define MVIN_SCALE_IDENTITY $mvin_scale_identity\n\n" @@ -333,6 +363,13 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( |""".stripMargin header ++= "\n" + header ++= """// Rounding right shift equation: https://riscv.github.io/documents/riscv-v-spec/#_vector_fixed_point_rounding_mode_register_vxrm +#define ROUNDING_RIGHT_SHIFT_BITS(x, shift) \ +((shift) > 0 ? (((x) >> (shift)) + \ + (((shift) == 0 ? 0 : (((x) >> ((shift)-1)) & 1)) & \ + ((((shift) <= 1 ? 0 : ((x) & ((1 << ((shift)-1)) - 1))) != 0) | (((x) >> (shift)) & 1)))) : ((x) << (-(shift))))""" + header ++= "\n\n" + header ++= """#define ACC_SCALE(x, scale) \ """ header ++= s" ${acc_scale_args.c_str}" @@ -364,7 +401,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( header ++= s"#define ACC_READ_FULL_WIDTH\n" header ++= s"\n" - header ++= s"#endif // $guard" + header ++= s"#endif // $guard\n" header.toString() } diff --git a/src/main/scala/gemmini/GemminiISA.scala b/src/main/scala/gemmini/GemminiISA.scala index 91b30a21..6f181457 100644 --- a/src/main/scala/gemmini/GemminiISA.scala +++ b/src/main/scala/gemmini/GemminiISA.scala @@ -22,6 +22,20 @@ object GemminiISA { val LOAD3_CMD = 14.U + // TODO add orows and ocols to this as well + val LOOP_CONV_WS = 15.U // no_bias, no_pool + val LOOP_CONV_WS_CONFIG_1 = 16.U // batch_size, in_dim, in_channels, out_channels | out_dim, pool_out_dim, stride, padding + val LOOP_CONV_WS_CONFIG_2 = 17.U // kernel_dim, pool_size, pool_stride, pool_padding | batches, porows, pocols, pochs + val LOOP_CONV_WS_CONFIG_3 = 18.U // krows, kcols, kchs, lpad | rpad, upad, dpad, plpad + val LOOP_CONV_WS_CONFIG_4 = 19.U // prad, pupad, pdpad, orows | ocols + val LOOP_CONV_WS_CONFIG_5 = 20.U // *weights | *output + val LOOP_CONV_WS_CONFIG_6 = 21.U // *bias, *input + + val LOOP_LD_CONFIG_BOUNDS = 22.U + val LOOP_LD_CONFIG_ADDRS = 23.U + val LOOP_CONV_LD_CONFIG_BOUNDS = 24.U + val LOOP_CONV_LD_CONFIG_ADDRS = 25.U + // rs1[2:0] values val CONFIG_EX = 0.U val CONFIG_LOAD = 1.U diff --git a/src/main/scala/gemmini/Im2Col.scala b/src/main/scala/gemmini/Im2Col.scala index 52039d4d..f264ad32 100644 --- a/src/main/scala/gemmini/Im2Col.scala +++ b/src/main/scala/gemmini/Im2Col.scala @@ -415,7 +415,7 @@ class Im2Col[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, U, V when(i.U < channel){ im2col_data(i) := sram_req_output(i) }.otherwise{ - im2col_data(i) := 0.S //when channel < 16, pad with 0 + im2col_data(i) := 0.U.asTypeOf(inputType) //when channel < 16, pad with 0 } } } @@ -446,5 +446,6 @@ class Im2Col[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, U, V io.resp.valid := false.B io.req.ready := true.B io.sram_reads.foreach(_.req.valid := false.B) + io.sram_reads.foreach(_.resp.ready := false.B) } } diff --git a/src/main/scala/gemmini/LoadController.scala b/src/main/scala/gemmini/LoadController.scala index cf5f0c57..d3a07275 100644 --- a/src/main/scala/gemmini/LoadController.scala +++ b/src/main/scala/gemmini/LoadController.scala @@ -6,6 +6,7 @@ import GemminiISA._ import Util._ import freechips.rocketchip.config.Parameters +// TODO we need to check for WAW errors here // TODO deal with errors when reading scratchpad responses class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, U, V], coreMaxAddrBits: Int, local_addr_t: LocalAddr) (implicit p: Parameters) extends Module { @@ -24,9 +25,10 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig val waiting_for_command :: waiting_for_dma_req_ready :: sending_rows :: Nil = Enum(3) val control_state = RegInit(waiting_for_command) - val strides = Reg(Vec(3, UInt(coreMaxAddrBits.W))) - val scales = Reg(Vec(3, UInt(mvin_scale_t_bits.W))) - val shrinks = Reg(Vec(3, Bool())) // Shrink inputs to accumulator + val strides = Reg(Vec(load_states, UInt(coreMaxAddrBits.W))) + val scales = Reg(Vec(load_states, UInt(mvin_scale_t_bits.W))) + val shrinks = Reg(Vec(load_states, Bool())) // Shrink inputs to accumulator + val block_strides = Reg(Vec(load_states, UInt(block_stride_bits.W))) // Spad stride during block move-ins val block_rows = meshRows * tileRows val block_cols = meshColumns * tileColumns val row_counter = RegInit(0.U(log2Ceil(block_rows).W)) @@ -34,22 +36,36 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig val cmd = Queue(io.cmd, ld_queue_length) val vaddr = cmd.bits.cmd.rs1 val localaddr = cmd.bits.cmd.rs2.asTypeOf(local_addr_t) - val cols = cmd.bits.cmd.rs2(32 + mvin_len_bits - 1, 32) // TODO magic numbers - val rows = cmd.bits.cmd.rs2(48 + mvin_rows_bits - 1, 48) // TODO magic numbers + //val cols = cmd.bits.cmd.rs2(32 + mvin_cols_bits - 1, 32) // TODO magic numbers + val cols = cmd.bits.cmd.rs2(44, 32) + //val rows = cmd.bits.cmd.rs2(48 + mvin_rows_bits - 1, 48) // TODO magic numbers + val rows = cmd.bits.cmd.rs2(60, 48) // TODO magic numbers val config_stride = cmd.bits.cmd.rs2 val config_scale = cmd.bits.cmd.rs1(32 + mvin_scale_t_bits - 1, 32) // TODO magic numbers - val config_shrink = cmd.bits.cmd.rs1(2) + val config_shrink = cmd.bits.cmd.rs1(2) // TODO magic numbers + val config_block_stride = cmd.bits.cmd.rs1(31, 16) // TODO magic numbers + //monitor conflict using either A or B + val monitor_conflict = (cmd.bits.cmd.inst.funct === LOAD2_CMD || cmd.bits.cmd.inst.funct === LOAD_CMD) && cmd.bits.cmd.rs2(63) + val monitor_conflict_start = monitor_conflict && cmd.bits.cmd.rs2(61) + val monitor_conflict_end = monitor_conflict && cmd.bits.cmd.rs2(62) + //profiling + val profile_conflict = (cmd.bits.cmd.inst.funct === LOAD2_CMD || cmd.bits.cmd.inst.funct === LOAD_CMD) && cmd.bits.cmd.rs2(47) + val profile_conflict_start = profile_conflict && cmd.bits.cmd.rs2(45) + val profile_conflict_end = profile_conflict && cmd.bits.cmd.rs2(46) val mstatus = cmd.bits.cmd.status val load_state_id = MuxCase(0.U, Seq((cmd.bits.cmd.inst.funct === LOAD2_CMD) -> 1.U, (cmd.bits.cmd.inst.funct === LOAD3_CMD) -> 2.U)) - val config_state_id = cmd.bits.cmd.rs1(4,3) + val config_state_id = cmd.bits.cmd.rs1(4,3) // TODO magic numbers val state_id = Mux(cmd.bits.cmd.inst.funct === CONFIG_CMD, config_state_id, load_state_id) val stride = strides(state_id) val scale = scales(state_id) val shrink = shrinks(state_id) + val block_stride = block_strides(state_id) + + val all_zeros = vaddr === 0.U val localaddr_plus_row_counter = localaddr + row_counter @@ -71,7 +87,7 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig (block_cols * config.accType.getWidth / 8) val maxBytesInMatRequest = block_rows * maxBytesInRowRequest - val cmd_tracker = Module(new DMAReadCommandTracker(nCmds, maxBytesInMatRequest, deps_t)) + val cmd_tracker = Module(new DMACommandTracker(nCmds, maxBytesInMatRequest, deps_t)) io.busy := cmd.valid || cmd_tracker.io.busy @@ -81,11 +97,19 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig (control_state === sending_rows && row_counter =/= 0.U) io.dma.req.bits.vaddr := vaddr + row_counter * stride io.dma.req.bits.laddr := localaddr_plus_row_counter - io.dma.req.bits.len := cols - io.dma.req.bits.repeats := Mux(stride === 0.U, rows - 1.U, 0.U) + io.dma.req.bits.cols := cols + io.dma.req.bits.repeats := Mux(stride === 0.U && !all_zeros, rows - 1.U, 0.U) + io.dma.req.bits.block_stride := block_stride io.dma.req.bits.scale := scale io.dma.req.bits.has_acc_bitwidth := localaddr_plus_row_counter.is_acc_addr && !shrink + io.dma.req.bits.all_zeros := all_zeros io.dma.req.bits.status := mstatus + io.dma.req.bits.monitor_conflict := monitor_conflict + io.dma.req.bits.monitor_conflict_start := monitor_conflict_start + io.dma.req.bits.monitor_conflict_end := monitor_conflict_end + io.dma.req.bits.profile_conflict_start := profile_conflict_start + io.dma.req.bits.profile_conflict_end := profile_conflict_end + io.dma.req.bits.profile_conflict := profile_conflict // Command tracker IO cmd_tracker.io.alloc.valid := control_state === waiting_for_command && cmd.valid && DoLoad @@ -109,6 +133,8 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig // Row counter when (io.dma.req.fire()) { row_counter := wrappingAdd(row_counter, 1.U, actual_rows_read) + + assert(block_stride >= rows) } // Control logic @@ -120,6 +146,7 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig stride := config_stride scale := config_scale shrink := config_shrink + block_stride := config_block_stride cmd.ready := true.B } diff --git a/src/main/scala/gemmini/LocalAddr.scala b/src/main/scala/gemmini/LocalAddr.scala new file mode 100644 index 00000000..b003fd7b --- /dev/null +++ b/src/main/scala/gemmini/LocalAddr.scala @@ -0,0 +1,83 @@ +package gemmini + +import chisel3._ +import chisel3.util._ + +class LocalAddr(sp_banks: Int, sp_bank_entries: Int, acc_banks: Int, acc_bank_entries: Int) extends Bundle { + private val localAddrBits = 32 // TODO magic number + + private val spAddrBits = log2Ceil(sp_banks * sp_bank_entries) + private val accAddrBits = log2Ceil(acc_banks * acc_bank_entries) + private val maxAddrBits = spAddrBits max accAddrBits + + private val spBankBits = log2Up(sp_banks) + private val spBankRowBits = log2Up(sp_bank_entries) + + private val accBankBits = log2Up(acc_banks) + val accBankRowBits = log2Up(acc_bank_entries) + + val is_acc_addr = Bool() + val accumulate = Bool() + val read_full_acc_row = Bool() + val garbage = UInt(((localAddrBits - maxAddrBits - 4) max 0).W) + val garbage_bit = if (localAddrBits - maxAddrBits >= 4) UInt(1.W) else UInt(0.W) + val data = UInt(maxAddrBits.W) + + def sp_bank(dummy: Int = 0) = if (spAddrBits == spBankRowBits) 0.U else data(spAddrBits - 1, spBankRowBits) + def sp_row(dummy: Int = 0) = data(spBankRowBits - 1, 0) + def acc_bank(dummy: Int = 0) = if (accAddrBits == accBankRowBits) 0.U else data(accAddrBits - 1, accBankRowBits) + def acc_row(dummy: Int = 0) = data(accBankRowBits - 1, 0) + + def full_sp_addr(dummy: Int = 0) = data(spAddrBits - 1, 0) + def full_acc_addr(dummy: Int = 0) = data(accAddrBits - 1, 0) + + def is_same_address(other: LocalAddr): Bool = is_acc_addr === other.is_acc_addr && data === other.data + def is_same_address(other: UInt): Bool = is_same_address(other.asTypeOf(this)) + def is_garbage(dummy: Int = 0) = is_acc_addr && accumulate && read_full_acc_row && data.andR() && + (if (garbage_bit.getWidth > 0) garbage_bit.asBool() else true.B) + + def +(other: UInt) = { + require(isPow2(sp_bank_entries)) // TODO remove this requirement + require(isPow2(acc_bank_entries)) // TODO remove this requirement + + val result = WireInit(this) + result.data := data + other + result + } + + def <=(other: LocalAddr) = + is_acc_addr === other.is_acc_addr && + Mux(is_acc_addr, full_acc_addr() <= other.full_acc_addr(), full_sp_addr() <= other.full_sp_addr()) + + def <(other: LocalAddr) = + is_acc_addr === other.is_acc_addr && + Mux(is_acc_addr, full_acc_addr() < other.full_acc_addr(), full_sp_addr() < other.full_sp_addr()) + + def >(other: LocalAddr) = + is_acc_addr === other.is_acc_addr && + Mux(is_acc_addr, full_acc_addr() > other.full_acc_addr(), full_sp_addr() > other.full_sp_addr()) + + def add_with_overflow(other: UInt): Tuple2[LocalAddr, Bool] = { + require(isPow2(sp_bank_entries)) // TODO remove this requirement + require(isPow2(acc_bank_entries)) // TODO remove this requirement + + val sum = data +& other + + val overflow = Mux(is_acc_addr, sum(accAddrBits), sum(spAddrBits)) + + val result = WireInit(this) + result.data := sum(maxAddrBits - 1, 0) + + (result, overflow) + } + + def make_this_garbage(dummy: Int = 0): Unit = { + is_acc_addr := true.B + accumulate := true.B + read_full_acc_row := true.B + garbage_bit := 1.U + data := ~(0.U(maxAddrBits.W)) + } + + override def cloneType: LocalAddr.this.type = new LocalAddr(sp_banks, sp_bank_entries, acc_banks, acc_bank_entries).asInstanceOf[this.type] +} diff --git a/src/main/scala/gemmini/LoopConv.scala b/src/main/scala/gemmini/LoopConv.scala new file mode 100644 index 00000000..7cf24a47 --- /dev/null +++ b/src/main/scala/gemmini/LoopConv.scala @@ -0,0 +1,1120 @@ +package gemmini + +import chisel3._ +import chisel3.util._ +import chisel3.experimental._ +import freechips.rocketchip.tile.RoCCCommand +import freechips.rocketchip.config.Parameters +import GemminiISA._ +import Util._ + +class LoopConvOuterBounds(val large_iterator_bitwidth: Int, val small_iterator_bitwidth: Int, val tiny_iterator_bitwidth: Int) extends Bundle { + val batch_size = UInt(large_iterator_bitwidth.W) + val in_dim = UInt(small_iterator_bitwidth.W) + val in_channels = UInt(large_iterator_bitwidth.W) + val out_channels = UInt(large_iterator_bitwidth.W) + val out_stride = UInt(large_iterator_bitwidth.W) //stride for output activation + val in_stride = UInt(large_iterator_bitwidth.W) //stride for input activation + val weight_stride = UInt(large_iterator_bitwidth.W) //stride for weight + val out_dim = UInt(small_iterator_bitwidth.W) + val pool_out_dim = UInt(small_iterator_bitwidth.W) + val stride = UInt(tiny_iterator_bitwidth.W) + val padding = UInt(tiny_iterator_bitwidth.W) + val kernel_dim = UInt(tiny_iterator_bitwidth.W) + val pool_size = UInt(tiny_iterator_bitwidth.W) + val pool_stride = UInt(tiny_iterator_bitwidth.W) + val pool_padding = UInt(tiny_iterator_bitwidth.W) +} + +class LoopConvInnerBounds(val large_iterator_bitwidth: Int, val small_iterator_bitwidth: Int, val tiny_iterator_bitwidth: Int) extends Bundle { + val batches = UInt(large_iterator_bitwidth.W) + val porows = UInt(small_iterator_bitwidth.W) + val pocols = UInt(small_iterator_bitwidth.W) + val pochs = UInt(large_iterator_bitwidth.W) + val krows = UInt(tiny_iterator_bitwidth.W) + val kcols = UInt(tiny_iterator_bitwidth.W) + val kchs = UInt(large_iterator_bitwidth.W) + val lpad = UInt(tiny_iterator_bitwidth.W) + val rpad = UInt(tiny_iterator_bitwidth.W) + val upad = UInt(tiny_iterator_bitwidth.W) + val dpad = UInt(tiny_iterator_bitwidth.W) + val plpad = UInt(tiny_iterator_bitwidth.W) + val prad = UInt(tiny_iterator_bitwidth.W) + val pupad = UInt(tiny_iterator_bitwidth.W) + val pdpad = UInt(tiny_iterator_bitwidth.W) + val orows = UInt(small_iterator_bitwidth.W) + val ocols = UInt(small_iterator_bitwidth.W) +} + +class LoopConvDerivedParams(val large_iterator_bitwidth: Int, val small_iterator_bitwidth: Int, val tiny_iterator_bitwidth: Int) extends Bundle { + val ochs = UInt(large_iterator_bitwidth.W) + + val irows = UInt(small_iterator_bitwidth.W) + val icols = UInt(small_iterator_bitwidth.W) + val irows_unpadded = UInt(small_iterator_bitwidth.W) + val icols_unpadded = UInt(small_iterator_bitwidth.W) + val ichs = UInt(large_iterator_bitwidth.W) + + val out_channels_per_bank = UInt(small_iterator_bitwidth.W) // TODO this won't work for systolic arrays above 256 in size + + val bias_spad_stride = UInt(large_iterator_bitwidth.W) + val input_spad_stride = UInt(large_iterator_bitwidth.W) + val weight_spad_stride = UInt(large_iterator_bitwidth.W) + + val ex_overwrite = Bool() +} + +class LoopConvLdBiasReq(val coreMaxAddrBits: Int, val large_iterator_bitwidth: Int, val small_iterator_bitwidth: Int, val tiny_iterator_bitwidth: Int, val max_acc_addr: Int, val concurrent_loops: Int) extends Bundle { + val outer_bounds = new LoopConvOuterBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val inner_bounds = new LoopConvInnerBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val derived_params = new LoopConvDerivedParams(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val addr_start = UInt(log2Up(max_acc_addr).W) + val dram_addr = UInt(coreMaxAddrBits.W) + val no_bias = Bool() + val partial_sum_mvin = Bool() //for partial sum move-in + val loop_id = UInt(log2Up(concurrent_loops).W) +} + +class LoopConvLdBias(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth: Int, small_iterator_bitwidth: Int, tiny_iterator_bitwidth: Int, max_acc_addr: Int, acc_w: Int, + max_block_len_acc: Int, concurrent_loops: Int)(implicit p: Parameters) extends Module { + val MVIN_SCALE_IDENTITY = 0x3f800000.U // TODO get this from configs somehow + + val io = IO(new Bundle { + val req = Flipped(Decoupled(new LoopConvLdBiasReq(coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth: Int, max_acc_addr, concurrent_loops))) + val cmd = Decoupled(Output(new RoCCCommand)) + + val idle = Output(Bool()) + val rob_overloaded = Input(Bool()) + val wait_for_prev_loop = Input(Bool()) + + val loop_id = Output(UInt(log2Up(concurrent_loops).W)) + }) + + object State extends ChiselEnum { + val idle, config, ld = Value + } + import State._ + val state = RegInit(idle) + + val req = Reg(new LoopConvLdBiasReq(coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth: Int, max_acc_addr, concurrent_loops)) + import req.inner_bounds._ + import req.outer_bounds._ + import req.derived_params._ + + val acc_addr_start = (BigInt(1) << 31).U | req.addr_start + + // Derived parameters + val max_ochs_per_mvin = Mux(ochs < (max_block_len_acc * block_size).U, ochs, (max_block_len_acc * block_size).U) + + val skip = req.no_bias || (req.dram_addr === 0.U) + + // Iterators + val b = Reg(UInt(large_iterator_bitwidth.W)) + val orow = Reg(UInt(small_iterator_bitwidth.W)) + val ocol = Reg(UInt(small_iterator_bitwidth.W)) + val och = Reg(UInt(large_iterator_bitwidth.W)) + + // Addresses +// val dram_addr = req.dram_addr +& och * (acc_w/8).U + val dram_addr = Mux(req.partial_sum_mvin, req.dram_addr + ((b*out_dim*out_dim + orow*out_dim + ocol) * out_channels + och) * (acc_w/8).U, + req.dram_addr +& och * (acc_w/8).U) + //stride for partial sum: out_channels (not och_stride) + + val spad_addr = acc_addr_start +& (och / block_size.U) * batches * orows * ocols +& b * orows * ocols +& orow * ocols +& ocol + + // Sizes + val I = Mux(ocols - ocol > block_size.U, block_size.U, ocols - ocol) + val J = Mux(ochs - och > max_ochs_per_mvin, max_ochs_per_mvin, ochs - och) + + // Commands + val config_cmd = Wire(new RoCCCommand) + config_cmd := DontCare + config_cmd.inst.funct := CONFIG_CMD + config_cmd.rs1 := (MVIN_SCALE_IDENTITY << 32.U) | (req.derived_params.bias_spad_stride << 16.U) | (2.U << 3) | 1.U + config_cmd.rs2 := Mux(req.partial_sum_mvin, out_channels * (acc_w/8).U, 0.U) + // to move in partial sum, need stride + + val mvin_cmd = Wire(new RoCCCommand) + mvin_cmd := DontCare + mvin_cmd.inst.funct := LOAD3_CMD + mvin_cmd.rs1 := dram_addr + mvin_cmd.rs2 := (I << 48.U) | (J << 32.U) | spad_addr + + // Inputs and outputs + io.req.ready := state === idle + io.idle := state === idle + io.loop_id := req.loop_id + + io.cmd.valid := state =/= idle && !io.rob_overloaded && !io.wait_for_prev_loop && !skip + io.cmd.bits := Mux(state === config, config_cmd, mvin_cmd) + + // Sending outputs + when (skip) { + state := idle + }.elsewhen(io.cmd.fire()) { + when (state === config) { + state := ld + }.otherwise { + val next_och = floorAdd(och, max_ochs_per_mvin, ochs) + val next_ocol = floorAdd(ocol, block_size.U, ocols, next_och === 0.U) + val next_orow = floorAdd(orow, 1.U, orows, next_ocol === 0.U && next_och === 0.U) + val next_b = floorAdd(b, 1.U, batches, next_orow === 0.U && next_ocol === 0.U && next_och === 0.U) + + och := next_och + ocol := next_ocol + orow := next_orow + b := next_b + + state := Mux(next_b === 0.U && next_orow === 0.U && next_ocol === 0.U && next_och === 0.U, + idle, ld) + } + } + + // Accepting requests + when (io.req.fire()) { + req := io.req.bits + state := config + b := 0.U + orow := 0.U + ocol := 0.U + och := 0.U + } +} + +class LoopConvLdInputReq(val coreMaxAddrBits: Int, val large_iterator_bitwidth: Int, val small_iterator_bitwidth: Int, val tiny_iterator_bitwidth: Int, val max_acc_addr: Int, val concurrent_loops: Int) extends Bundle { + val outer_bounds = new LoopConvOuterBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val inner_bounds = new LoopConvInnerBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val derived_params = new LoopConvDerivedParams(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val addr_start = UInt(log2Up(max_acc_addr).W) + val dram_addr = UInt(coreMaxAddrBits.W) + val loop_id = UInt(log2Up(concurrent_loops).W) +} + +class LoopConvLdInput(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth: Int, small_iterator_bitwidth: Int, tiny_iterator_bitwidth: Int, max_addr: Int, input_w: Int, + max_block_len: Int, concurrent_loops: Int)(implicit p: Parameters) extends Module { + val MVIN_SCALE_IDENTITY = 0x3f800000.U // TODO get this from configs somehow + + val io = IO(new Bundle { + val req = Flipped(Decoupled(new LoopConvLdInputReq(coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, max_addr, concurrent_loops))) + val cmd = Decoupled(Output(new RoCCCommand)) + + val idle = Output(Bool()) + val rob_overloaded = Input(Bool()) + val wait_for_prev_loop = Input(Bool()) + + val loop_id = Output(UInt(log2Up(concurrent_loops).W)) + }) + + object State extends ChiselEnum { + val idle, config, ld = Value + } + import State._ + val state = RegInit(idle) + + val req = Reg(new LoopConvLdInputReq(coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, max_addr, concurrent_loops)) + import req.outer_bounds._ + import req.inner_bounds._ + import req.derived_params._ + + // Derived parameters + val max_ichs_per_mvin = Mux(ichs < (max_block_len * block_size).U, ichs, (max_block_len * block_size).U).zext() + + // Iterators + val b = Reg(SInt(large_iterator_bitwidth.W)) + val irow = Reg(SInt(small_iterator_bitwidth.W)) + val icol = Reg(SInt(small_iterator_bitwidth.W)) + val ich = Reg(SInt(large_iterator_bitwidth.W)) + + // Calculated params + val irow_padded = irow +& upad.zext() + val icol_padded = icol +& lpad.zext() + val is_zeros = irow < 0.S || irow >= irows_unpadded.zext() || icol < 0.S || icol >= icols_unpadded.zext() + + val ich_stride = in_stride + + // Addresses + val dram_addr = Mux(is_zeros, 0.U, + req.dram_addr +& (((b * in_dim * in_dim +& irow*in_dim +& icol) * ich_stride +& ich) * (input_w/8).U).asUInt()) + val spad_addr = req.addr_start.zext() +& (ich / block_size.S) * batches * irows * icols +& b * irows * icols +& irow_padded * icols +& icol_padded + + // Sizes + val I = MuxCase( + Mux(icols_unpadded.zext() -& icol > block_size.S, block_size.S, icols_unpadded.zext() -& icol), + Seq( + (icol < 0.S) -> Mux((0.S-&icol) > block_size.S, block_size.S, 0.S-&icol), + (icol >= icols_unpadded.zext()) -> Mux(icols_unpadded.zext() +& rpad.zext() -& icol > block_size.S, block_size.S, icols_unpadded.zext() +& rpad.zext() -& icol) + ) + ) + val K = Mux(ichs.zext() -& ich > max_ichs_per_mvin, max_ichs_per_mvin, ichs.zext() -& ich) + + // Commands + val config_cmd = Wire(new RoCCCommand) + config_cmd := DontCare + config_cmd.inst.funct := CONFIG_CMD + config_cmd.rs1 := (MVIN_SCALE_IDENTITY << 32.U).asUInt() | (req.derived_params.input_spad_stride << 16.U).asUInt() | (0.U << 3).asUInt() | 1.U + config_cmd.rs2 := ich_stride * (input_w/8).U + + val mvin_cmd = Wire(new RoCCCommand) + mvin_cmd := DontCare + mvin_cmd.inst.funct := LOAD_CMD + mvin_cmd.rs1 := dram_addr + mvin_cmd.rs2 := (I << 48.U).asUInt() | (K << 32.U).asUInt() | spad_addr.asUInt() + + // Inputs and outputs + io.req.ready := state === idle + io.idle := state === idle + io.loop_id := req.loop_id + + io.cmd.valid := state =/= idle && !io.wait_for_prev_loop && !io.rob_overloaded && req.dram_addr =/= 0.U + io.cmd.bits := Mux(state === config, config_cmd, mvin_cmd) + + // Sending outputs + when (req.dram_addr === 0.U) { + state := idle + }.elsewhen (io.cmd.fire()) { + when (state === config) { + state := ld + }.otherwise { + val next_ich = sFloorAdd(ich, max_ichs_per_mvin.asUInt(), ichs.zext(), 0.S) + val next_icol = sFloorAdd(icol, I.asUInt(), (icols_unpadded +& rpad).zext(), 0.S-&lpad.zext(), + next_ich === 0.S) + val next_irow = sFloorAdd(irow, 1.U, (irows_unpadded +& dpad).zext(), 0.S-&upad.zext(), + next_icol === 0.S-&lpad.zext() && next_ich === 0.S) + val next_b = sFloorAdd(b, 1.U, batches.zext(), 0.S, + next_irow === 0.S-&upad.zext() && next_icol === 0.S-&lpad.zext() && next_ich === 0.S) + + ich := next_ich + icol := next_icol + irow := next_irow + b := next_b + + state := Mux(next_b === 0.S && next_irow === 0.S-&upad.zext() && next_icol === 0.S-&lpad.zext() && next_ich === 0.S, + idle, ld) + } + } + + // Accepting requests + when (io.req.fire()) { + req := io.req.bits + state := config + b := 0.S + irow := 0.S -& io.req.bits.inner_bounds.upad.zext() + icol := 0.S -& io.req.bits.inner_bounds.lpad.zext() + ich := 0.S + } +} + +class LoopConvLdWeightReq(val coreMaxAddrBits: Int, val large_iterator_bitwidth: Int, val small_iterator_bitwidth: Int, val tiny_iterator_bitwidth: Int, val max_addr: Int, val concurrent_loops: Int) extends Bundle { + val outer_bounds = new LoopConvOuterBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val inner_bounds = new LoopConvInnerBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val derived_params = new LoopConvDerivedParams(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val addr_end = UInt(log2Up(max_addr).W) + val dram_addr = UInt(coreMaxAddrBits.W) + val loop_id = UInt(log2Up(concurrent_loops).W) + val depthwise = Bool() +} + +class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth: Int, small_iterator_bitwidth: Int, tiny_iterator_bitwidth: Int, max_addr: Int, input_w: Int, + max_block_len: Int, concurrent_loops: Int)(implicit p: Parameters) extends Module { + val MVIN_SCALE_IDENTITY = 0x3f800000.U // TODO get this from configs somehow + + val io = IO(new Bundle { + val req = Flipped(Decoupled(new LoopConvLdWeightReq(coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, max_addr, concurrent_loops))) + val cmd = Decoupled(Output(new RoCCCommand)) + + val idle = Output(Bool()) + val rob_overloaded = Input(Bool()) + val wait_for_prev_loop = Input(Bool()) + + val loop_id = Output(UInt(log2Up(concurrent_loops).W)) + }) + + object State extends ChiselEnum { + val idle, config, ld = Value + } + import State._ + val state = RegInit(idle) + + val req = Reg(new LoopConvLdWeightReq(coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, max_addr, concurrent_loops)) + import req.outer_bounds._ + import req.inner_bounds._ + import req.derived_params._ + + // Derived parameters + val max_ochs_per_mvin = Mux(ochs < (max_block_len * block_size).U, ochs, (max_block_len * block_size).U) + val B_rows = out_channels_per_bank * kcols * krows * kchs + val addr_start = req.addr_end - B_rows + block_size.U // for possible loopconv bug (like the loopmatmul one) + + // Iterators + val och = Reg(UInt(large_iterator_bitwidth.W)) + val krow = Reg(UInt(tiny_iterator_bitwidth.W)) + val kcol = Reg(UInt(tiny_iterator_bitwidth.W)) + val kch = Reg(UInt(large_iterator_bitwidth.W)) + + val och_stride = weight_stride + + // Addresses + val dram_addr = Mux(req.depthwise, req.dram_addr +& ((krow*kernel_dim +& kcol +& kch) * och_stride +& och) * (input_w/8).U, req.dram_addr +& ((krow*kernel_dim*in_channels +& kcol*in_channels +& kch) * och_stride +& och) * (input_w/8).U) + val spad_addr = addr_start + (och / block_size.U) * krows * kcols * kchs + krow * kcols * kchs + kcol * kchs + kch + + // Sizes + val J = Mux(ochs - och > max_ochs_per_mvin, max_ochs_per_mvin, ochs - och) + val K = Mux(kchs - kch > block_size.U, block_size.U, kchs - kch) + + // Commands + val config_cmd = Wire(new RoCCCommand) + config_cmd := DontCare + config_cmd.inst.funct := CONFIG_CMD + config_cmd.rs1 := (MVIN_SCALE_IDENTITY << 32.U).asUInt() | (req.derived_params.weight_spad_stride << 16.U).asUInt() | (1.U << 3).asUInt() | 1.U + config_cmd.rs2 := och_stride * (input_w/8).U + + val mvin_cmd = Wire(new RoCCCommand) + mvin_cmd := DontCare + mvin_cmd.inst.funct := LOAD2_CMD + mvin_cmd.rs1 := dram_addr + mvin_cmd.rs2 := (K << 48.U).asUInt() | (J << 32.U).asUInt() | spad_addr + + // Inputs and outputs + io.req.ready := state === idle + io.idle := state === idle + io.loop_id := req.loop_id + + io.cmd.valid := state =/= idle && !io.wait_for_prev_loop && !io.rob_overloaded && req.dram_addr =/= 0.U + io.cmd.bits := Mux(state === config, config_cmd, mvin_cmd) + + // Sending outputs + when (req.dram_addr === 0.U) { + state := idle + }.elsewhen (io.cmd.fire()) { + when (state === config) { + state := ld + }.otherwise { + val next_kch = floorAdd(kch, block_size.U, kchs) + val next_kcol = floorAdd(kcol, 1.U, kcols, next_kch === 0.U) + val next_krow = floorAdd(krow, 1.U, krows, next_kcol === 0.U && next_kch === 0.U) + val next_och = floorAdd(och, max_ochs_per_mvin, ochs, next_krow === 0.U && next_kcol === 0.U && next_kch === 0.U) + + kch := next_kch + kcol := next_kcol + krow := next_krow + och := next_och + + state := Mux(next_och === 0.U && next_krow === 0.U && next_kcol === 0.U && next_kch === 0.U, + idle, ld) + } + } + + // Accepting requests + when (io.req.fire()) { + req := io.req.bits + state := config + kch := 0.U + kcol := 0.U + krow := 0.U + och := 0.U + } +} + +class LoopConvExecuteReq(val large_iterator_bitwidth: Int, val small_iterator_bitwidth: Int, val tiny_iterator_bitwidth: Int, val max_addr: Int, val max_acc_addr: Int, val concurrent_loops: Int) extends Bundle { + val outer_bounds = new LoopConvOuterBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val inner_bounds = new LoopConvInnerBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val derived_params = new LoopConvDerivedParams(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val a_addr_start = UInt(log2Up(max_addr).W) + val b_addr_end = UInt(log2Up(max_addr).W) + val c_addr_start = UInt(log2Up(max_acc_addr).W) + val loop_id = UInt(log2Up(concurrent_loops).W) +} + +class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_iterator_bitwidth: Int, tiny_iterator_bitwidth: Int, max_addr: Int, + max_acc_addr: Int, concurrent_loops: Int)(implicit p: Parameters) extends Module { + val GARBAGE_ADDR = (~0.U(32.W)).asUInt() + + val io = IO(new Bundle { + val req = Flipped(Decoupled(new LoopConvExecuteReq(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, max_addr, max_acc_addr, concurrent_loops))) + val cmd = Decoupled(Output(new RoCCCommand)) + + val lda_completed = Input(Bool()) + val ldb_completed = Input(Bool()) + val ldd_completed = Input(Bool()) + + val idle = Output(Bool()) + val rob_overloaded = Input(Bool()) + + val loop_id = Output(UInt(log2Up(concurrent_loops).W)) + }) + + object State extends ChiselEnum { + val idle, pre, comp = Value + } + import State._ + val state = RegInit(idle) + + val req = Reg(new LoopConvExecuteReq(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, + max_addr, max_acc_addr, concurrent_loops)) + import req.outer_bounds._ + import req.inner_bounds._ + import req.derived_params._ + + // Derived parameters + val B_rows = out_channels_per_bank * kcols * krows * kchs + + val a_addr_start = req.a_addr_start + val b_addr_start = req.b_addr_end - B_rows + block_size.U //for possible loopconv bug (like loopmatmul) + val d_addr_start = (BigInt(1) << 31).U | req.c_addr_start + val c_addr_start = (BigInt(3) << 30).U | req.c_addr_start + + // Iterators + val och = Reg(UInt(large_iterator_bitwidth.W)) + val krow = Reg(UInt(tiny_iterator_bitwidth.W)) + val kcol = Reg(UInt(tiny_iterator_bitwidth.W)) + val kch = Reg(UInt(large_iterator_bitwidth.W)) + val b = Reg(UInt(large_iterator_bitwidth.W)) + val orow = Reg(UInt(small_iterator_bitwidth.W)) + val ocol = Reg(UInt(small_iterator_bitwidth.W)) + + val irow = orow * stride +& krow + val icol = ocol * stride +& kcol + + val I = Mux(ocols - ocol > block_size.U, block_size.U, ocols - ocol) + val J = Mux(ochs - och > block_size.U, block_size.U, ochs - och) + val K = Mux(kchs - kch > block_size.U, block_size.U, kchs - kch) + + // Addresses + val a_addr = a_addr_start +& (kch / block_size.U) * batches * irows * icols +& b * irows * icols +& irow * icols +& icol + val c_addr = Mux(ex_overwrite && krow === 0.U && kcol === 0.U && kch === 0.U, d_addr_start, c_addr_start) +& + (och / block_size.U) * batches * orows * ocols +& b * orows * ocols +& orow * ocols +& ocol + + val new_weights = b === 0.U && orow === 0.U && ocol === 0.U + val b_addr = Mux(new_weights, + b_addr_start +& (och / block_size.U) * krows * kcols * kchs +& krow * kcols * kchs +& kcol * kchs +& kch, + GARBAGE_ADDR) + + // Commands + val pre_cmd = Wire(new RoCCCommand) + pre_cmd := DontCare + pre_cmd.inst.funct := PRELOAD_CMD + pre_cmd.rs1 := (K << 48) | (J << 32) | b_addr + pre_cmd.rs2 := (I << 48) | (J << 32) | c_addr + + val comp_cmd = Wire(new RoCCCommand()) + comp_cmd := DontCare + comp_cmd.inst.funct := Mux(new_weights, COMPUTE_AND_FLIP_CMD, COMPUTE_AND_STAY_CMD) + comp_cmd.rs1 := (I << 48) | (K << 32) | a_addr + comp_cmd.rs2 := (I << 48) | (J << 32) | GARBAGE_ADDR + + // Inputs and outputs + io.req.ready := state === idle + io.idle := state === idle + + val ld_ahead = io.lda_completed && io.ldb_completed && io.ldd_completed + + io.cmd.valid := state =/= idle && !io.rob_overloaded && ld_ahead + io.cmd.bits := Mux(state === pre, pre_cmd, comp_cmd) + + io.loop_id := req.loop_id + + // Sending outputs + when (io.cmd.fire()) { + when (state === pre) { + state := comp + }.otherwise { + val next_ocol = floorAdd(ocol, block_size.U, ocols) + val next_orow = floorAdd(orow, 1.U, orows, next_ocol === 0.U) + val next_b = floorAdd(b, 1.U, batches, next_orow === 0.U && next_ocol === 0.U) + val next_kch = floorAdd(kch, block_size.U, kchs, + next_b === 0.U && next_orow === 0.U && next_ocol === 0.U) + val next_kcol = floorAdd(kcol, 1.U, kcols, + next_kch === 0.U && next_b === 0.U && next_orow === 0.U && next_ocol === 0.U) + val next_krow = floorAdd(krow, 1.U, krows, + next_kcol === 0.U && next_kch === 0.U && next_b === 0.U && next_orow === 0.U && next_ocol === 0.U) + val next_och = floorAdd(och, block_size.U, ochs, next_krow === 0.U && + next_kcol === 0.U && next_kch === 0.U && next_b === 0.U && next_orow === 0.U && next_ocol === 0.U) + + ocol := next_ocol + orow := next_orow + b := next_b + kch := next_kch + kcol := next_kcol + krow := next_krow + och := next_och + + state := Mux(next_och === 0.U && next_krow === 0.U && next_kcol === 0.U && next_kch === 0.U && next_b === 0.U && + next_orow === 0.U && next_ocol === 0.U, + idle, pre) + } + } + + // Accepting requests + when (io.req.fire()) { + req := io.req.bits + state := pre + + b := 0.U + orow := 0.U + ocol := 0.U + och := 0.U + krow := 0.U + kcol := 0.U + kch := 0.U + } +} + +class LoopConvStReq(val coreMaxAddrBits: Int, val large_iterator_bitwidth: Int, val small_iterator_bitwidth: Int, val tiny_iterator_bitwidth: Int, val max_acc_addr: Int, val concurrent_loops: Int) extends Bundle { + val outer_bounds = new LoopConvOuterBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val inner_bounds = new LoopConvInnerBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val derived_params = new LoopConvDerivedParams(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val addr_start = UInt(log2Up(max_acc_addr).W) + val dram_addr = UInt(coreMaxAddrBits.W) + val dram_addr_pool = UInt(coreMaxAddrBits.W) + val no_pool = Bool() + val both_out = Bool() // output both pooled and unpooled + val partial_sum = Bool() // move out 32 bits of partial sum + val loop_id = UInt(log2Up(concurrent_loops).W) +} + +class LoopConvSt(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth: Int, small_iterator_bitwidth: Int, tiny_iterator_bitwidth: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, max_block_len: Int, concurrent_loops: Int)(implicit p: Parameters) extends Module { + val MVIN_SCALE_IDENTITY = 0x3f800000.U // TODO get this from configs somehow + + val io = IO(new Bundle { + val req = Flipped(Decoupled(new LoopConvStReq(coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth: Int, max_acc_addr, concurrent_loops))) + val cmd = Decoupled(Output(new RoCCCommand)) + + val ex_completed = Input(Bool()) + + val idle = Output(Bool()) + val rob_overloaded = Input(Bool()) + + val loop_id = Output(UInt(log2Up(concurrent_loops).W)) + }) + + object State extends ChiselEnum { + val idle, st, pre_pool_config, pool, post_pool_config = Value + } + import State._ + val state = RegInit(idle) + + val req = Reg(new LoopConvStReq(coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth: Int, max_acc_addr, concurrent_loops)) + import req.outer_bounds._ + import req.inner_bounds._ + import req.derived_params._ + + val acc_addr_start = (BigInt(1) << 31).U | req.addr_start + + // Derived parameters + val skip = req.dram_addr === 0.U + + // Iterators + val b = Reg(UInt(large_iterator_bitwidth.W)) + val orow = Reg(UInt(small_iterator_bitwidth.W)) + val ocol = Reg(UInt(small_iterator_bitwidth.W)) + val och = Reg(UInt(large_iterator_bitwidth.W)) + + //further divide due to squeezenet fire module concatenation + val och_stride = out_stride + // Addresses + val dram_addr = req.dram_addr + ((b*out_dim*out_dim + orow*out_dim + ocol) * och_stride + och) * Mux(req.partial_sum, (acc_w/8).U, (input_w/8).U) + val spad_addr = acc_addr_start +& (och / block_size.U) * batches * orows * ocols +& b * orows * ocols +& orow * ocols +& ocol + + val pool_dram_addr = Mux(req.both_out, req.dram_addr_pool, req.dram_addr) + ((b * pool_out_dim * pool_out_dim) * och_stride + och) * (input_w/8).U + val pool_spad_addr = acc_addr_start +& (och / block_size.U) * batches * orows * ocols +& b * orows * ocols + + // Sizes + val I = Mux(ocols - ocol > block_size.U, block_size.U, ocols - ocol) + val J = Mux(ochs - och > block_size.U, block_size.U, ochs - och) + + val channels = J + + // Commands + val mvout_cmd = Wire(new RoCCCommand) + mvout_cmd := DontCare + mvout_cmd.inst.funct := STORE_CMD + mvout_cmd.rs1 := dram_addr + mvout_cmd.rs2 := Mux(req.partial_sum, (I << 48.U) | (J << 32.U) | spad_addr | (1.U << 29), (I << 48.U) | (J << 32.U) | spad_addr) + + val pre_pool_config_cmd = Wire(new RoCCCommand) + pre_pool_config_cmd := DontCare + pre_pool_config_cmd.inst.funct := CONFIG_CMD + pre_pool_config_cmd.rs1 := (ocols << 56) | (orows << 48) | (pocols << 40) | (porows << 32) | (pool_out_dim << 24) | + (plpad << 10) | (pupad << 8) | (pool_size << 6) | (pool_stride << 4) | // TODO magic numbers + CONFIG_STORE + pre_pool_config_cmd.rs2 := och_stride * (input_w / 8).U + + val post_pool_config_cmd = Wire(new RoCCCommand) + post_pool_config_cmd := DontCare + post_pool_config_cmd.inst.funct := CONFIG_CMD + post_pool_config_cmd.rs1 := CONFIG_STORE + post_pool_config_cmd.rs2 := Mux(req.partial_sum, och_stride * (acc_w / 8).U, och_stride * (input_w / 8).U) + //need 32 bits stride to move out partial sum + + val pool_cmd = Wire(new RoCCCommand) + pool_cmd := DontCare + pool_cmd.inst.funct := STORE_CMD + pool_cmd.rs1 := pool_dram_addr + pool_cmd.rs2 := (channels << 32.U) | pool_spad_addr + + // Inputs and outputs + io.req.ready := state === idle + io.idle := state === idle + io.loop_id := req.loop_id + + io.cmd.valid := state =/= idle && !io.rob_overloaded && !skip && io.ex_completed + io.cmd.bits := MuxLookup(state.asUInt, mvout_cmd, Seq(pre_pool_config.asUInt -> pre_pool_config_cmd, + pool.asUInt -> pool_cmd, post_pool_config.asUInt -> post_pool_config_cmd)) + + val second_pool = RegInit(false.B) //need to output pool next + // Sending outputs + when (skip) { + state := idle + }.elsewhen(io.cmd.fire()) { + when (req.no_pool || (req.both_out && !second_pool)) { // needs normal output first before pool + val next_och = floorAdd(och, block_size.U, ochs) + val next_ocol = floorAdd(ocol, block_size.U, ocols, next_och === 0.U) + val next_orow = floorAdd(orow, 1.U, orows, next_ocol === 0.U && next_och === 0.U) + val next_b = floorAdd(b, 1.U, batches, next_orow === 0.U && next_ocol === 0.U && next_och === 0.U) + + och := next_och + ocol := next_ocol + orow := next_orow + b := next_b + val next_all_zero = next_b === 0.U && next_orow === 0.U && next_ocol === 0.U && next_och === 0.U + state := Mux(next_all_zero, Mux(req.both_out, pre_pool_config, idle), st) + when(next_all_zero && !second_pool && req.both_out){ + second_pool := true.B + } + }.elsewhen(state === pre_pool_config) { + state := pool + }.elsewhen(state === post_pool_config) { + state := idle + second_pool := false.B + }.otherwise { + val next_och = floorAdd(och, block_size.U, ochs) + val next_b = floorAdd(b, 1.U, batches, next_och === 0.U) + + och := next_och + b := next_b + + state := Mux(next_b === 0.U && next_och === 0.U, + post_pool_config, pool) + } + } + + // Accepting requests + when (io.req.fire()) { + req := io.req.bits + state := Mux(io.req.bits.no_pool || io.req.bits.both_out, st, pre_pool_config) + + b := 0.U + orow := 0.U + ocol := 0.U + och := 0.U + } +} + +class LoopConvState(val block_size: Int, val large_iterator_bitwidth: Int, val small_iterator_bitwidth: Int, val tiny_iterator_bitwidth: Int, val coreMaxAddrBits: Int, val max_addr: Int, val max_acc_addr: Int) extends Bundle { + val outer_bounds = new LoopConvOuterBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + val inner_bounds = new LoopConvInnerBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth) + + val bias_dram_addr = UInt(coreMaxAddrBits.W) + val weights_dram_addr = UInt(coreMaxAddrBits.W) + val input_dram_addr = UInt(coreMaxAddrBits.W) + val output_dram_addr = UInt(coreMaxAddrBits.W) + val pool_output_dram_addr = UInt(coreMaxAddrBits.W) + + val no_bias = Bool() + val no_pool = Bool() + val both_out = Bool() // both pool and not pool + val partial_sum_mvout = Bool() + val partial_sum_mvin = Bool() + val depthwise = Bool() + + val configured = Bool() + + val running = Bool() + + val ld_bias_started = Bool() + val ld_input_started = Bool() + val ld_weights_started = Bool() + val ex_started = Bool() + val st_started = Bool() + + val ld_bias_completed = Bool() + val ld_input_completed = Bool() + val ld_weights_completed = Bool() + val ex_completed = Bool() + val st_completed = Bool() + + def all_completed(dummy: Int=0): Bool = ld_bias_completed && ld_input_completed && ld_weights_completed && ex_completed && st_completed + + val a_addr_start = UInt(log2Up(max_addr).W) + val b_addr_end = UInt(log2Up(max_addr).W) + + def derived_params(dummy: Int=0): LoopConvDerivedParams = { + import outer_bounds.stride + import inner_bounds.{batches, pochs, orows, ocols, krows, kcols, upad, dpad, lpad, rpad, kchs} + + val result = Wire(new LoopConvDerivedParams(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth)) + + result.ochs := pochs + + result.irows := orows * stride +& krows - 1.U + result.icols := ocols * stride +& kcols - 1.U + result.irows_unpadded := result.irows - upad - dpad + result.icols_unpadded := result.icols - lpad - rpad + result.ichs := kchs + + result.out_channels_per_bank := result.ochs / block_size.U +& (result.ochs % block_size.U =/= 0.U) + + result.bias_spad_stride := batches * orows * ocols + result.input_spad_stride := batches * result.irows * result.icols + result.weight_spad_stride := krows * kcols * kchs + + result.ex_overwrite := bias_dram_addr =/= 0.U && no_bias + + result + } + + def reset(): Unit = { + configured := false.B + + running := false.B + + ld_bias_started := false.B + ld_input_started := false.B + ld_weights_started := false.B + ex_started := false.B + st_started := false.B + + ld_bias_completed := false.B + ld_input_completed := false.B + ld_weights_completed := false.B + ex_completed := false.B + st_completed := false.B + } +} + +class LoopConv (block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: Int, max_exs: Int, max_sts: Int, + max_addr: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, dma_max_bytes: Int) + (implicit p: Parameters) extends Module { + val large_iterator_bitwidth = 16 + val small_iterator_bitwidth = 8 + val tiny_iterator_bitwidth = 4 + + val max_block_len = (dma_max_bytes / (block_size * (input_w / 8))) max 1 + val max_block_len_acc = (dma_max_bytes / (block_size * (acc_w / 8))) max 1 + + val io = IO(new Bundle { + val in = Flipped(Decoupled(new RoCCCommand)) + val out = Decoupled(new RoCCCommand) + val ld_utilization = Input(UInt(log2Up(rob_size+1).W)) + val st_utilization = Input(UInt(log2Up(rob_size+1).W)) + val ex_utilization = Input(UInt(log2Up(rob_size+1).W)) + val busy = Output(Bool()) + }) + + // Create states + val concurrent_loops = 2 + val loops = Reg(Vec(concurrent_loops, new LoopConvState(block_size, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, coreMaxAddrBits, max_addr, max_acc_addr))) + val head_loop_id = RegInit(0.U(log2Up(concurrent_loops).W)) + val tail_loop_id = (~head_loop_id).asUInt() // This is the loop that we always try to configure if available + val head_loop = loops(head_loop_id) + val tail_loop = loops(tail_loop_id) + + val loop_configured = loops.map(_.configured).reduce(_ || _) + + val loop_being_configured_id = Mux(head_loop.configured, tail_loop_id, head_loop_id) + val loop_being_configured = loops(loop_being_configured_id) + + // Create inner modules + val ld_bias = Module(new LoopConvLdBias(block_size, coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, max_acc_addr, acc_w, max_block_len_acc, concurrent_loops)) + val ld_input = Module(new LoopConvLdInput(block_size, coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, max_addr, input_w, max_block_len, concurrent_loops)) + val ld_weights = Module(new LoopConvLdWeight(block_size, coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, max_addr, input_w, max_block_len, concurrent_loops)) + val ex = Module(new LoopConvExecute(block_size, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, max_addr, max_acc_addr, concurrent_loops)) + val st = Module(new LoopConvSt(block_size, coreMaxAddrBits, large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth, max_acc_addr, input_w, acc_w, max_block_len, concurrent_loops)) + + // Create command queue + val cmd = Queue(io.in) + + io.busy := cmd.valid || loop_configured + + // Create arbiter + val arb = Module(new Arbiter(new RoCCCommand, 5)) + arb.io.in(0) <> st.io.cmd + arb.io.in(1) <> ex.io.cmd + arb.io.in(2) <> ld_bias.io.cmd + arb.io.in(3) <> ld_weights.io.cmd + arb.io.in(4) <> ld_input.io.cmd + val unrolled_cmd = arb.io.out + + // Wire up unrolled command output + val is_loop_run_cmd = cmd.bits.inst.funct === LOOP_CONV_WS + val is_loop_config_cmd = cmd.bits.inst.funct >= LOOP_CONV_WS_CONFIG_1 && cmd.bits.inst.funct <= LOOP_CONV_WS_CONFIG_6 + val is_loop_cmd = is_loop_run_cmd || is_loop_config_cmd + + io.out.bits := Mux(loop_configured, unrolled_cmd.bits, cmd.bits) + io.out.bits.status := cmd.bits.status // TODO This is not guaranteed to be the correct fix! We must fix this + io.out.valid := Mux(loop_configured, unrolled_cmd.valid, cmd.valid && !is_loop_config_cmd && !is_loop_run_cmd) + + cmd.ready := Mux(is_loop_cmd, !loop_being_configured.configured, !loop_configured && io.out.ready) + arb.io.out.ready := io.out.ready + + // Wire up waiting-for-loads signals + val ex_is_waiting_for_loads = loops(ex.io.loop_id).ex_started && !loops(ex.io.loop_id).ex_completed && + !(loops(ex.io.loop_id).ld_input_completed && loops(ex.io.loop_id).ld_weights_completed && + loops(ex.io.loop_id).ld_bias_completed) + + ld_bias.io.wait_for_prev_loop := ex_is_waiting_for_loads && ld_bias.io.loop_id =/= ex.io.loop_id + ld_weights.io.wait_for_prev_loop := ex_is_waiting_for_loads && ld_weights.io.loop_id =/= ex.io.loop_id + ld_input.io.wait_for_prev_loop := ex_is_waiting_for_loads && ld_input.io.loop_id =/= ex.io.loop_id + + // Wire up overloaded signals + ld_bias.io.rob_overloaded := io.ld_utilization >= max_lds.U + ld_input.io.rob_overloaded := io.ld_utilization >= max_lds.U + ld_weights.io.rob_overloaded := io.ld_utilization >= max_lds.U + ex.io.rob_overloaded := io.ex_utilization >= max_exs.U + st.io.rob_overloaded := io.st_utilization >= max_sts.U + + // Wire up iterator inputs + ex.io.lda_completed := (ld_input.io.loop_id =/= ex.io.loop_id) || ld_input.io.idle + ex.io.ldb_completed := (ld_weights.io.loop_id =/= ex.io.loop_id) || ld_weights.io.idle + ex.io.ldd_completed := (ld_bias.io.loop_id =/= ex.io.loop_id) || ld_bias.io.idle + st.io.ex_completed := (ex.io.loop_id =/= st.io.loop_id) || ex.io.idle + + // Create config registers + when(cmd.valid && is_loop_cmd && !loop_being_configured.configured) { + + switch (cmd.bits.inst.funct) { + is (LOOP_CONV_WS_CONFIG_1) { + loop_being_configured.outer_bounds.out_channels := cmd.bits.rs1(63, 48) + loop_being_configured.outer_bounds.in_channels := cmd.bits.rs1(47, 32) + loop_being_configured.outer_bounds.in_dim := cmd.bits.rs1(31, 16) + loop_being_configured.outer_bounds.batch_size := cmd.bits.rs1(15, 0) + + loop_being_configured.outer_bounds.padding := cmd.bits.rs2(63, 48) + loop_being_configured.outer_bounds.stride := cmd.bits.rs2(47, 32) + loop_being_configured.outer_bounds.pool_out_dim := cmd.bits.rs2(31, 16) + loop_being_configured.outer_bounds.out_dim := cmd.bits.rs2(15, 0) + } + + is (LOOP_CONV_WS_CONFIG_2) { + loop_being_configured.outer_bounds.kernel_dim := cmd.bits.rs1(63, 48) + loop_being_configured.outer_bounds.pool_size := cmd.bits.rs1(47, 32) + loop_being_configured.outer_bounds.pool_stride := cmd.bits.rs1(31, 16) + loop_being_configured.outer_bounds.pool_padding := cmd.bits.rs1(15, 0) + + loop_being_configured.inner_bounds.batches := cmd.bits.rs2(63, 48) + loop_being_configured.inner_bounds.porows := cmd.bits.rs2(47, 32) + loop_being_configured.inner_bounds.pocols := cmd.bits.rs2(31, 16) + loop_being_configured.inner_bounds.pochs := cmd.bits.rs2(15, 0) + } + + is (LOOP_CONV_WS_CONFIG_3) { + loop_being_configured.inner_bounds.krows := cmd.bits.rs1(63, 48) + loop_being_configured.inner_bounds.kcols := cmd.bits.rs1(47, 32) + loop_being_configured.inner_bounds.kchs := cmd.bits.rs1(31, 16) + loop_being_configured.inner_bounds.lpad := cmd.bits.rs1(15, 0) + + loop_being_configured.inner_bounds.rpad := cmd.bits.rs2(63, 48) + loop_being_configured.inner_bounds.upad := cmd.bits.rs2(47, 32) + loop_being_configured.inner_bounds.dpad := cmd.bits.rs2(31, 16) + loop_being_configured.inner_bounds.plpad := cmd.bits.rs2(15, 0) + } + + is (LOOP_CONV_WS_CONFIG_4) { + loop_being_configured.inner_bounds.orows := cmd.bits.rs1(63, 48) + loop_being_configured.inner_bounds.prad := cmd.bits.rs1(47, 32) + loop_being_configured.inner_bounds.pupad := cmd.bits.rs1(31, 16) + loop_being_configured.inner_bounds.pdpad := cmd.bits.rs1(15, 0) + + loop_being_configured.outer_bounds.in_stride := cmd.bits.rs2(63, 48) + loop_being_configured.outer_bounds.weight_stride := cmd.bits.rs2(47, 32) + loop_being_configured.outer_bounds.out_stride := cmd.bits.rs2(31, 16) + loop_being_configured.inner_bounds.ocols := cmd.bits.rs2(15, 0) + } + + is (LOOP_CONV_WS_CONFIG_5) { + loop_being_configured.weights_dram_addr := cmd.bits.rs1 + + loop_being_configured.output_dram_addr := cmd.bits.rs2 + } + + is (LOOP_CONV_WS_CONFIG_6) { + loop_being_configured.bias_dram_addr := cmd.bits.rs1 + + loop_being_configured.input_dram_addr := cmd.bits.rs2 + } + + is (LOOP_CONV_WS) { + loop_being_configured.pool_output_dram_addr := cmd.bits.rs1 // added for 2 mvout + loop_being_configured.no_bias := cmd.bits.rs2(61) + loop_being_configured.partial_sum_mvin := cmd.bits.rs2(59) + + loop_being_configured.no_pool := cmd.bits.rs2(0) + loop_being_configured.both_out := cmd.bits.rs2(62) + loop_being_configured.partial_sum_mvout := cmd.bits.rs2(60) + loop_being_configured.depthwise := cmd.bits.rs2(63) + + + loop_being_configured.configured := true.B + } + } + } + + // Wire up request signals + val ld_bias_addr_start = RegInit(0.U(log2Up(max_acc_addr).W)) + val ex_c_addr_start = RegInit(0.U(log2Up(max_acc_addr).W)) + val st_addr_start = RegInit(0.U(log2Up(max_acc_addr).W)) + + val loop_requesting_ld_bias_id = Mux(head_loop.ld_bias_started, tail_loop_id, head_loop_id) + val loop_requesting_ld_bias = loops(loop_requesting_ld_bias_id) + ld_bias.io.req.bits.outer_bounds := loop_requesting_ld_bias.outer_bounds + ld_bias.io.req.bits.inner_bounds := loop_requesting_ld_bias.inner_bounds + ld_bias.io.req.bits.derived_params := loop_requesting_ld_bias.derived_params() + ld_bias.io.req.bits.addr_start := ld_bias_addr_start + ld_bias.io.req.bits.dram_addr := loop_requesting_ld_bias.bias_dram_addr + ld_bias.io.req.bits.no_bias := loop_requesting_ld_bias.no_bias + ld_bias.io.req.bits.partial_sum_mvin := loop_requesting_ld_bias.partial_sum_mvin + ld_bias.io.req.bits.loop_id := loop_requesting_ld_bias_id + + + ld_bias.io.req.valid := !loop_requesting_ld_bias.ld_bias_started && loop_requesting_ld_bias.configured + + when (ld_bias.io.req.fire()) { + loop_requesting_ld_bias.running := true.B + loop_requesting_ld_bias.ld_bias_started := true.B + + // when (loop_requesting_ld_bias.bias_dram_addr =/= 0.U) { + when (loop_requesting_ld_bias.output_dram_addr =/= 0.U) { + ld_bias_addr_start := floorAdd(ld_bias_addr_start, (max_acc_addr / concurrent_loops).U, max_acc_addr.U) + } + } + + val loop_requesting_ld_input_id = Mux(head_loop.ld_input_started, tail_loop_id, head_loop_id) + val loop_requesting_ld_input = loops(loop_requesting_ld_input_id) + ld_input.io.req.bits.outer_bounds := loop_requesting_ld_input.outer_bounds + ld_input.io.req.bits.inner_bounds := loop_requesting_ld_input.inner_bounds + ld_input.io.req.bits.derived_params := loop_requesting_ld_input.derived_params() + ld_input.io.req.bits.addr_start := loop_requesting_ld_input.a_addr_start + ld_input.io.req.bits.dram_addr := loop_requesting_ld_input.input_dram_addr + ld_input.io.req.bits.loop_id := loop_requesting_ld_input_id + ld_input.io.req.valid := !loop_requesting_ld_input.ld_input_started && loop_requesting_ld_input.configured + + when (ld_input.io.req.fire()) { + loop_requesting_ld_input.running := true.B + loop_requesting_ld_input.ld_input_started := true.B + } + + val loop_requesting_ld_weights_id = Mux(head_loop.ld_weights_started, tail_loop_id, head_loop_id) + val loop_requesting_ld_weights = loops(loop_requesting_ld_weights_id) + ld_weights.io.req.bits.outer_bounds := loop_requesting_ld_weights.outer_bounds + ld_weights.io.req.bits.inner_bounds := loop_requesting_ld_weights.inner_bounds + ld_weights.io.req.bits.derived_params := loop_requesting_ld_weights.derived_params() + ld_weights.io.req.bits.addr_end := loop_requesting_ld_weights.b_addr_end + ld_weights.io.req.bits.dram_addr := loop_requesting_ld_weights.weights_dram_addr + ld_weights.io.req.bits.loop_id := loop_requesting_ld_weights_id + ld_weights.io.req.bits.depthwise := loop_requesting_ld_weights.depthwise + + ld_weights.io.req.valid := !loop_requesting_ld_weights.ld_weights_started && loop_requesting_ld_weights.configured + + when (ld_weights.io.req.fire()) { + loop_requesting_ld_weights.running := true.B + loop_requesting_ld_weights.ld_weights_started := true.B + } + + val loop_requesting_ex_id = Mux(head_loop.ex_started, tail_loop_id, head_loop_id) + val loop_requesting_ex = loops(loop_requesting_ex_id) + ex.io.req.bits.outer_bounds := loop_requesting_ex.outer_bounds + ex.io.req.bits.inner_bounds := loop_requesting_ex.inner_bounds + ex.io.req.bits.derived_params := loop_requesting_ex.derived_params() + ex.io.req.bits.a_addr_start := loop_requesting_ex.a_addr_start + ex.io.req.bits.b_addr_end := loop_requesting_ex.b_addr_end + ex.io.req.bits.c_addr_start := ex_c_addr_start + ex.io.req.bits.loop_id := loop_requesting_ex_id + + ex.io.req.valid := !loop_requesting_ex.ex_started && loop_requesting_ex.ld_bias_started && + loop_requesting_ex.ld_input_started && loop_requesting_ex.ld_weights_started && loop_requesting_ex.configured + + when (ex.io.req.fire()) { + loop_requesting_ex.running := true.B + loop_requesting_ex.ex_started := true.B + + when (loop_requesting_ex.output_dram_addr =/= 0.U) { + ex_c_addr_start := floorAdd(ex_c_addr_start, (max_acc_addr / concurrent_loops).U, max_acc_addr.U) + } + } + + val loop_requesting_st_id = Mux(head_loop.st_started, tail_loop_id, head_loop_id) + val loop_requesting_st = loops(loop_requesting_st_id) + st.io.req.bits.outer_bounds := loop_requesting_st.outer_bounds + st.io.req.bits.inner_bounds := loop_requesting_st.inner_bounds + st.io.req.bits.derived_params := loop_requesting_st.derived_params() + st.io.req.bits.addr_start := st_addr_start + st.io.req.bits.dram_addr := loop_requesting_st.output_dram_addr + st.io.req.bits.no_pool := loop_requesting_st.no_pool + st.io.req.bits.both_out := loop_requesting_st.both_out + st.io.req.bits.partial_sum := loop_requesting_st.partial_sum_mvout + st.io.req.bits.loop_id := loop_requesting_st_id + // added for 2 mvout + st.io.req.bits.dram_addr_pool := loop_requesting_st.pool_output_dram_addr + + + st.io.req.valid := !loop_requesting_st.st_started && loop_requesting_st.ex_started && loop_requesting_st.configured + + when (st.io.req.fire()) { + loop_requesting_st.running := true.B + loop_requesting_st.st_started := true.B + + when (loop_requesting_st.output_dram_addr =/= 0.U) { + st_addr_start := floorAdd(st_addr_start, (max_acc_addr / concurrent_loops).U, max_acc_addr.U) + } + } + + // Handle completed signals + when (ld_bias.io.idle && loops(ld_bias.io.loop_id).running && loops(ld_bias.io.loop_id).ld_bias_started) { + loops(ld_bias.io.loop_id).ld_bias_completed := true.B + } + + when (ld_input.io.idle && loops(ld_input.io.loop_id).running && loops(ld_input.io.loop_id).ld_input_started) { + loops(ld_input.io.loop_id).ld_input_completed := true.B + } + + when (ld_weights.io.idle && loops(ld_weights.io.loop_id).running && loops(ld_weights.io.loop_id).ld_weights_started) { + loops(ld_weights.io.loop_id).ld_weights_completed := true.B + } + + when (ex.io.idle && loops(ex.io.loop_id).running && loops(ex.io.loop_id).ex_started) { + loops(ex.io.loop_id).ex_completed := true.B + } + + when (st.io.idle && loops(st.io.loop_id).running && loops(st.io.loop_id).st_started) { + loops(st.io.loop_id).st_completed := true.B + } + + when (head_loop.running && head_loop.all_completed()) { + head_loop.reset() + head_loop_id := ~head_loop_id + } + + // Resets + when (reset.toBool()) { + loops.zipWithIndex.foreach { case (l, i) => + l.reset() + l.a_addr_start := (i * (max_addr / concurrent_loops)).U + l.b_addr_end := ((i+1) * (max_addr / concurrent_loops) - block_size).U + } + } +} + +object LoopConv { + def apply(in: DecoupledIO[RoCCCommand], ld_utilization: UInt, st_utilization: UInt, ex_utilization: UInt, + block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: Int, max_exs: Int, max_sts: Int, + max_addr: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, dma_max_bytes: Int) + (implicit p: Parameters): Tuple2[DecoupledIO[RoCCCommand], Bool] = { + val mod = Module(new LoopConv(block_size, coreMaxAddrBits, rob_size, max_lds, max_exs, max_sts, + max_addr, max_acc_addr, input_w, acc_w, dma_max_bytes)) + mod.io.in <> in + mod.io.ld_utilization := ld_utilization + mod.io.st_utilization := st_utilization + mod.io.ex_utilization := ex_utilization + (mod.io.out, mod.io.busy) + } +} diff --git a/src/main/scala/gemmini/LoopLoader.scala b/src/main/scala/gemmini/LoopLoader.scala new file mode 100644 index 00000000..0f7cabe3 --- /dev/null +++ b/src/main/scala/gemmini/LoopLoader.scala @@ -0,0 +1,308 @@ +package gemmini + +import chisel3._ +import chisel3.util._ +import chisel3.experimental._ +import freechips.rocketchip.tile.RoCCCommand +import freechips.rocketchip.config.Parameters +import GemminiISA._ +import Util._ + +class LoopLoader(block_size: Int, coreMaxAddrBits:Int, max_addr: Int, input_w: Int, dma_max_bytes: Int) + (implicit p: Parameters) extends Module { + val iterator_bitwidth = 16 + val max_block_len = (dma_max_bytes / (block_size * input_w / 8)) max 1 + + val io = IO(new Bundle { + val in = Flipped(Decoupled(new RoCCCommand)) + val out = Decoupled(new RoCCCommand) + val busy = Output(Bool()) + val latency = Output(UInt(iterator_bitwidth.W)) + val alert_cycle = Output(UInt(6.W)) + val pause_turn = Output(UInt(3.W)) + val pause_monitor = Input(Bool()) + }) + //queue for cmd + val cmd = Queue(io.in) + //val is_ldloop = cmd.bits.inst.funct === LOOP_LD + val is_matmul_ldconfig = cmd.bits.inst.funct === LOOP_LD_CONFIG_ADDRS || cmd.bits.inst.funct === LOOP_LD_CONFIG_BOUNDS + val is_conv_ldconfig = cmd.bits.inst.funct === LOOP_CONV_LD_CONFIG_ADDRS || cmd.bits.inst.funct === LOOP_CONV_LD_CONFIG_BOUNDS + + val pause_req = RegInit(false.B) + val lock_tag = RegInit(false.B) + val is_conv = RegInit(false.B) + // for switching between conv and matmul + val loop_tag_conv = RegInit(false.B) + val loop_tag_matmul = RegInit(false.B) + val loop_tag = Mux(is_conv, loop_tag_conv, loop_tag_matmul) + + when(cmd.bits.inst.funct === LOOP_LD_CONFIG_ADDRS || cmd.bits.inst.funct === LOOP_CONV_LD_CONFIG_ADDRS){ + lock_tag := true.B + } // no need to force flip once seen LOOP_LD + when(cmd.bits.inst.funct === LOOP_WS || cmd.bits.inst.funct === LOOP_CONV_WS){ + when(lock_tag){ + lock_tag := false.B + }.otherwise{ + when(is_conv){ + loop_tag_conv := ~loop_tag_conv + }.otherwise{ + loop_tag_matmul := ~loop_tag_matmul + } //force to flip to sync with loop matmul afterwards + } + } + // config states + val latency = RegInit(0.U(iterator_bitwidth.W)) //how many cycles to push + val alert_cycle = RegInit(0.U(6.W)) //raise flag after how much cycles? + val pause_turn = RegInit(1.U(3.W)) // how many turns to wait to pause monitoring TL ports + val dram_base_addr = RegInit(0.U(coreMaxAddrBits.W)) + val row_stride = RegInit(0.U(coreMaxAddrBits.W)) + + val row_iterator = RegInit(0.U(iterator_bitwidth.W))//Mux(req.transpose, j, k) //k + val col_iterator = RegInit(0.U(iterator_bitwidth.W))//Mux(req.transpose, k, j) //j + val max_row_iterator = Reg(UInt(iterator_bitwidth.W)) //Mux(req.transpose, max_j, max_k) + val max_col_iterator = Reg(UInt(iterator_bitwidth.W)) //Mux(req.transpose, max_k, max_j) + + val row_pad = Reg(UInt(iterator_bitwidth.W)) //Mux(req.transpose, pad_j, pad_k) + val col_pad = Reg(UInt(iterator_bitwidth.W)) //Mux(req.transpose, pad_k, pad_j) + + //conv parameters + val out_channels = RegInit(0.U(16.W)) + val in_channels = RegInit(0.U(16.W)) + val kernel_dim = RegInit(0.U(4.W)) + val krows = RegInit(0.U(4.W)) + val kcols = RegInit(0.U(4.W)) + val kchs = RegInit(0.U(16.W)) + val ochs = RegInit(0.U(16.W)) + + // conv Iterators + val och = RegInit(0.U(16.W)) + val krow = RegInit(0.U(4.W)) + val kcol = RegInit(0.U(4.W)) + val kch = RegInit(0.U(16.W)) + + val max_blocks = max_block_len.asUInt() + val AB = RegInit(false.B) //false if B, true if A + val profile = RegInit(false.B) + //ToDo: rotate starting address like LoopMatmul.scala + val A_sp_addr_start = Mux(loop_tag, (max_addr/2).U, 0.U)//RegInit(0.U(log2Up(max_addr).W)) + val B_sp_addr_end = Mux(loop_tag, (max_addr - block_size).U, (max_addr/2 - block_size).U)//RegInit((max_addr/2).U(log2Up(max_addr).W)) + //for conv + val depthwise = RegInit(false.B) + val out_channel_stride = RegInit(0.U(coreMaxAddrBits.W)) + val max_ochs_per_mvin = Mux(ochs < (max_block_len * block_size).U, ochs, (max_block_len * block_size).U) + val out_channels_per_bank = WireInit(0.U(8.W)) + out_channels_per_bank := ochs / block_size.U +& (ochs % block_size.U =/= 0.U) + val B_rows = out_channels_per_bank * kcols * krows * kchs + //val addr_start = B_sp_addr_end - B_rows + block_size.U + + val sp_addr_start = Mux(is_conv, B_sp_addr_end - B_rows + block_size.U, + Mux(AB, A_sp_addr_start, B_sp_addr_end - max_row_iterator * max_col_iterator * block_size.U + block_size.U)) // Todo: need mux with 0 (skip A) + val conv_dram_addr = Mux(depthwise, dram_base_addr +& ((krow*kernel_dim +& kcol +& kch) * out_channel_stride +& och) * (input_w/8).U, dram_base_addr +& ((krow*kernel_dim*in_channels +& kcol*in_channels +& kch) * out_channel_stride +& och) * (input_w/8).U) + val dram_addr = Mux(!is_conv, dram_base_addr + (row_iterator * row_stride + col_iterator) * block_size.U * (input_w/8).U, + conv_dram_addr) + val sp_addr = sp_addr_start + Mux(is_conv, (och / block_size.U) * krows * kcols * kchs + krow * kcols * kchs + kcol * kchs + kch, + (row_iterator * max_col_iterator + col_iterator) * block_size.U) + val blocks = Mux(col_iterator + max_blocks <= max_col_iterator, max_blocks, max_col_iterator-col_iterator) + val cols = (blocks * block_size.U) - Mux(col_iterator + blocks >= max_col_iterator, col_pad, 0.U) + val rows = block_size.U - Mux(max_row_iterator === max_row_iterator-1.U, row_pad, 0.U) + // for conv rows and cols + val J = Mux(ochs - och > max_ochs_per_mvin, max_ochs_per_mvin, ochs - och) + val K = Mux(kchs - kch > block_size.U, block_size.U, kchs - kch) + + object State extends ChiselEnum { + val idle, config, ld = Value //added config for conv + } + import State._ + val state = RegInit(idle) + val configured = RegInit(false.B) + + val unlock_monitor = RegInit(0.U(4.W)) + val unlock_cycle = RegInit(0.U(4.W)) + val enable_bubble = RegInit(false.B) // enable monitoring for cache hits + val conflict_monitor = !(unlock_cycle === 0.U)//!((alert_cycle === 0.U) || (latency === 0.U)) + val conflict_monitor_start = conflict_monitor && Mux(is_conv, (och === 0.U && kch === 0.U && kcol === 0.U && krow === 0.U), (row_iterator === 0.U && col_iterator === 0.U)) && (state === ld) //ToDo: with conv + val conflict_monitor_end = conflict_monitor && Mux(is_conv, (kch + block_size.U >= kchs && kcol === kcols - 1.U && krow === krows - 1.U && och + max_ochs_per_mvin >= ochs), + (row_iterator === max_row_iterator - 1.U && col_iterator >= max_col_iterator - max_blocks)) && (state === ld) + + + val profile_hit = profile && (pause_turn =/= 0.U) + val profile_start = profile_hit && (row_iterator === 0.U && col_iterator === 0.U) + val profile_end = profile_hit && (row_iterator === max_row_iterator - 1.U && col_iterator >= max_col_iterator - max_blocks) + //ToDo: either load A or B (for now just do with B) + val load_cmd = Wire(new RoCCCommand()) + load_cmd := DontCare + load_cmd.inst.funct := Mux(AB, LOAD_CMD, LOAD2_CMD) + load_cmd.rs1 := dram_addr + load_cmd.rs2 := ((conflict_monitor && enable_bubble) << 63).asUInt() | (conflict_monitor_end << 62).asUInt() | (conflict_monitor_start << 61).asUInt() | (rows << 48).asUInt() | (profile_hit << 47).asUInt() | (profile_end << 46).asUInt() | (profile_start << 45).asUInt() | (cols << 32).asUInt() | sp_addr + + //for conv + val MVIN_SCALE_IDENTITY = 0x3f800000.U // TODO get this from configs somehow + val weight_spad_stride = krows * kcols * kchs + val config_cmd = Wire(new RoCCCommand) + config_cmd := DontCare + config_cmd.inst.funct := CONFIG_CMD + config_cmd.rs1 := (MVIN_SCALE_IDENTITY << 32.U).asUInt() | (weight_spad_stride << 16.U).asUInt() | (1.U << 3).asUInt() | 1.U + config_cmd.rs2 := out_channel_stride * (input_w/8).U + //for conv + val mvin_cmd = Wire(new RoCCCommand) + mvin_cmd := DontCare + mvin_cmd.inst.funct := LOAD2_CMD // for now, only weight + mvin_cmd.rs1 := dram_addr + mvin_cmd.rs2 := ((conflict_monitor && enable_bubble) << 63).asUInt() | (conflict_monitor_end << 62).asUInt() | (conflict_monitor_start << 61).asUInt() | (K << 48.U).asUInt() | (J << 32.U).asUInt() | sp_addr + + //val expected_tl_req = (max_addr / (2*2*max_block_len)).asUInt() + io.busy := cmd.valid || configured + io.alert_cycle := alert_cycle + io.latency := latency//Mux(enable_bubble, latency, 1.U) // latency + //enable_bubble := (latency =/= 0.U) //if latency == 0, disable bubble + // not enable bubble (loopld+FSM without bubble) + io.pause_turn := pause_turn + // fix loop_ws command + val loop_ws_state = RegInit(idle) + val is_loop_ws_addr = (cmd.bits.inst.funct === LOOP_WS_CONFIG_ADDRS_AB || cmd.bits.inst.funct === LOOP_CONV_WS_CONFIG_5) // for now, only weight for conv + val fixed_loop_cmd = Wire(new RoCCCommand()) + fixed_loop_cmd := DontCare + fixed_loop_cmd.inst.funct := cmd.bits.inst.funct//LOOP_WS_CONFIG_ADDRS_AB + fixed_loop_cmd.rs1 := Mux(cmd.bits.inst.funct === LOOP_CONV_WS_CONFIG_5, 0.U, Mux(AB, 0.U, cmd.bits.rs1)) //if conv, weight + fixed_loop_cmd.rs2 := Mux(is_conv, cmd.bits.rs2, Mux(AB, cmd.bits.rs2, 0.U)) //for now, not do input for conv + + unlock_monitor := floorAdd(unlock_monitor, 1.U, unlock_cycle + pause_turn - 1.U, pause_req && is_loop_ws_addr & lock_tag && cmd.fire()) + when(!pause_req){ + unlock_monitor := 0.U + } + //when(!configured){ + when((cmd.bits.inst.funct === LOOP_LD_CONFIG_BOUNDS || cmd.bits.inst.funct === LOOP_CONV_LD_CONFIG_BOUNDS) && cmd.valid){ + pause_req := io.pause_monitor + } + + val unlock = unlock_monitor + 1.U >= unlock_cycle // ToDo: change this number + + io.out.bits := Mux(configured, Mux(is_conv, Mux(state === config, config_cmd, mvin_cmd), load_cmd), + Mux(lock_tag && is_loop_ws_addr && (!pause_req || unlock) && (conflict_monitor || profile), fixed_loop_cmd, cmd.bits)) + io.out.bits.status := cmd.bits.status + io.out.valid := Mux(configured, state =/= idle, cmd.valid && !is_matmul_ldconfig && !is_conv_ldconfig) + cmd.ready := Mux(is_matmul_ldconfig || is_conv_ldconfig, !configured, !configured && io.out.ready) + +// when(cmd.valid && is_ldconfig && state === idle && (!pause_req || unlock)){ + when(cmd.valid && is_matmul_ldconfig && state === idle){ + switch(cmd.bits.inst.funct){ + is(LOOP_LD_CONFIG_BOUNDS){ + enable_bubble := cmd.bits.rs2(63) //diable: just loop B without bubble insertion + pause_turn := cmd.bits.rs2(iterator_bitwidth * 3 + 12, iterator_bitwidth * 3 + 10) + alert_cycle := cmd.bits.rs2(iterator_bitwidth * 3 + 5, iterator_bitwidth * 3) + latency := cmd.bits.rs2(iterator_bitwidth * 3 - 1, iterator_bitwidth * 2) //ToDo: give this to DMA + unlock_cycle := cmd.bits.rs2(iterator_bitwidth * 3 + 9, iterator_bitwidth * 3 + 6) + max_col_iterator := cmd.bits.rs2(iterator_bitwidth * 2 - 1, iterator_bitwidth) + max_row_iterator := cmd.bits.rs2(iterator_bitwidth-1, 0) + + AB := cmd.bits.rs1(63) + profile := cmd.bits.rs1(62) //added for profiling cache behavior + col_pad := cmd.bits.rs1(iterator_bitwidth * 2 - 1, iterator_bitwidth) + row_pad := cmd.bits.rs1(iterator_bitwidth-1, 0) + is_conv := false.B + } + is(LOOP_LD_CONFIG_ADDRS){ + when(!pause_req || unlock) { + dram_base_addr := cmd.bits.rs1 + row_stride := cmd.bits.rs2 + when(conflict_monitor || profile) { // if latency == 0, don't unroll + configured := true.B + state := ld + }.otherwise { + loop_tag_matmul := ~loop_tag_matmul + } + } + } + } + }.elsewhen(cmd.valid && is_conv_ldconfig && state === idle){ + switch(cmd.bits.inst.funct){ + is(LOOP_CONV_LD_CONFIG_BOUNDS){ + enable_bubble := cmd.bits.rs2(63) //diable: just loop B without bubble insertion + pause_turn := cmd.bits.rs2(60, 58) + unlock_cycle := cmd.bits.rs2(57, 54) + alert_cycle := cmd.bits.rs2(53, 48) + latency := cmd.bits.rs2(47, 32) //ToDo: give this to DMA + kernel_dim := cmd.bits.rs2(15, 0)//can code more if needed + + krows := cmd.bits.rs1(63, 48) + kcols := cmd.bits.rs1(47, 32) + kchs := cmd.bits.rs1(31, 16) + ochs := cmd.bits.rs1(15, 0) + is_conv := true.B + + // initialize for safety + krow := 0.U + kcol := 0.U + kch := 0.U + och := 0.U + } + is(LOOP_CONV_LD_CONFIG_ADDRS){ + when(!pause_req || unlock) { + dram_base_addr := cmd.bits.rs1 + out_channel_stride := cmd.bits.rs2(47, 32) + depthwise := cmd.bits.rs2(63) + out_channels := cmd.bits.rs2(31, 16) + in_channels := cmd.bits.rs2(15, 0) + //can code more + when(conflict_monitor) { + configured := true.B + state := config // for conv, idle -> config -> ld + }.otherwise{ + loop_tag_conv := ~loop_tag_conv + } + } + } + } + } + + + when(io.out.fire() && state === ld) { + when(!is_conv) { + //matmul loop + val row_blocks = 1.U + val col_blocks = max_blocks + + val next_col = floorAdd(col_iterator, col_blocks, max_col_iterator) + val next_row = floorAdd(row_iterator, row_blocks, max_row_iterator, next_col === 0.U) + + row_iterator := next_row + col_iterator := next_col + + when(next_row === 0.U && next_col === 0.U) { //finished loading + state := idle + configured := false.B + loop_tag_matmul := ~loop_tag_matmul + } + }.otherwise{ + //conv loop + val next_kch = floorAdd(kch, block_size.U, kchs) + val next_kcol = floorAdd(kcol, 1.U, kcols, next_kch === 0.U) + val next_krow = floorAdd(krow, 1.U, krows, next_kcol === 0.U && next_kch === 0.U) + val next_och = floorAdd(och, max_ochs_per_mvin, ochs, next_krow === 0.U && next_kcol === 0.U && next_kch === 0.U) + + kch := next_kch + kcol := next_kcol + krow := next_krow + och := next_och + + when(next_och === 0.U && next_krow === 0.U && next_kcol === 0.U && next_kch === 0.U){ //finished loading + state := idle + configured := false.B + loop_tag_conv := ~loop_tag_conv + } + } + }.elsewhen(io.out.fire() && state === config){ //for conv config + state := ld + } + +} + +object LoopLoader{ + def apply(in: DecoupledIO[RoCCCommand], pause_monitor: Bool, block_size: Int, coreMaxAddrBits: Int, max_addr: Int, input_w: Int, dma_max_bytes: Int) + (implicit p: Parameters): Tuple5[DecoupledIO[RoCCCommand], Bool, UInt, UInt, UInt] = { + val lld = Module(new LoopLoader(block_size, coreMaxAddrBits, max_addr, input_w, dma_max_bytes)) + lld.io.in <> in + lld.io.pause_monitor <> pause_monitor + (lld.io.out, lld.io.busy, lld.io.latency, lld.io.alert_cycle, lld.io.pause_turn) + } +} \ No newline at end of file diff --git a/src/main/scala/gemmini/LoopMatmul.scala b/src/main/scala/gemmini/LoopMatmul.scala index 181202b3..3da10080 100644 --- a/src/main/scala/gemmini/LoopMatmul.scala +++ b/src/main/scala/gemmini/LoopMatmul.scala @@ -77,15 +77,20 @@ class LoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In io.k := k io.idle := state === idle - io.cmd.valid := state =/= idle && !io.rob_overloaded + io.cmd.valid := state =/= idle && !io.rob_overloaded && (req.dram_addr =/= 0.U) io.cmd.bits := mvin_cmd io.loop_id := req.loop_id - when (io.cmd.fire()) { + when (req.dram_addr === 0.U) { + state := idle + }.elsewhen (io.cmd.fire()) { // The order here is k, j, i - val next_i = floorAdd(i, 1.U, req.max_i) - val next_k = floorAdd(k, max_blocks, req.max_k, next_i === 0.U) + val i_blocks = Mux(req.transpose, max_blocks, 1.U) + val k_blocks = Mux(req.transpose, 1.U, max_blocks) + + val next_i = floorAdd(i, i_blocks, req.max_i) + val next_k = floorAdd(k, k_blocks, req.max_k, next_i === 0.U) i := next_i k := next_k @@ -156,7 +161,7 @@ class LoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In val max_col_dim = Mux(req.transpose, req.max_k, req.max_j) val max_blocks = Mux(max_col_dim <= max_block_len.U, max_col_dim, max_block_len.U) - val sp_addr_start = req.addr_end - req.max_k * req.max_j * block_size.U + val sp_addr_start = req.addr_end - req.max_k * req.max_j * block_size.U + block_size.U val dram_addr = req.dram_addr + (row_iterator * req.dram_stride + col_iterator) * block_size.U * (input_w/8).U val sp_addr = sp_addr_start + (row_iterator * max_col_iterator + col_iterator) * block_size.U @@ -175,18 +180,22 @@ class LoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In io.j := j io.idle := state === idle - io.cmd.valid := state =/= idle && !io.rob_overloaded + io.cmd.valid := state =/= idle && !io.rob_overloaded && (req.dram_addr =/= 0.U) io.cmd.bits := mvin_cmd io.loop_id := req.loop_id - when (io.cmd.fire()) { - // The order here is k, j, i - val next_j = floorAdd(j, max_blocks, req.max_j) - val next_k = floorAdd(k, 1.U, req.max_k, next_j === 0.U) + when (req.dram_addr === 0.U) { + state := idle + }.elsewhen (io.cmd.fire()) { // The order here is k, j, i + val j_blocks = Mux(req.transpose, 1.U, max_blocks) + val k_blocks = Mux(req.transpose, max_blocks, 1.U) + + val next_j = floorAdd(j, j_blocks, req.max_j) + val next_k = floorAdd(k, k_blocks, req.max_k, next_j === 0.U) - k := next_k j := next_j + k := next_k when (next_j === 0.U && next_k === 0.U) { state := idle @@ -229,7 +238,7 @@ class LoopMatmulLdD(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In }) object State extends ChiselEnum { - val idle, st = Value + val idle, ld = Value } import State._ val state = RegInit(idle) @@ -270,8 +279,8 @@ class LoopMatmulLdD(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In state := idle }.elsewhen (io.cmd.fire()) { // The order here is k, j, i - val next_i = floorAdd(i, max_blocks, req.max_i) - val next_j = floorAdd(j, 1.U, req.max_j, next_i === 0.U) + val next_i = floorAdd(i, 1.U, req.max_i) + val next_j = floorAdd(j, max_blocks, req.max_j, next_i === 0.U) i := next_i j := next_j @@ -283,7 +292,7 @@ class LoopMatmulLdD(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In when (io.req.fire()) { req := io.req.bits - state := st + state := ld j := 0.U i := 0.U } @@ -308,7 +317,6 @@ class LoopMatmulExecuteReq(val block_size: Int, val coreMaxAddrBits: Int, val it class LoopMatmulExecute(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: Int, max_addr: Int, max_acc_addr: Int, concurrent_loops: Int) (implicit p: Parameters) extends Module { - val MAX_BLOCK_LEN = 4 // TODO get this from configs val GARBAGE_ADDR = (~0.U(32.W)).asUInt() val io = IO(new Bundle { @@ -343,7 +351,7 @@ class LoopMatmulExecute(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth val d_addr_start = (BigInt(1) << 31).U | req.c_addr_start val c_addr_start = (BigInt(3) << 30).U | req.c_addr_start - val b_addr_start = req.b_addr_end - req.max_k * req.max_j * block_size.U + val b_addr_start = req.b_addr_end - req.max_k * req.max_j * block_size.U + block_size.U val k = Reg(UInt(iterator_bitwidth.W)) val j = Reg(UInt(iterator_bitwidth.W)) @@ -439,11 +447,12 @@ class LoopMatmulStCReq(val block_size: Int, val coreMaxAddrBits: Int, val iterat val dram_addr = UInt(coreMaxAddrBits.W) val dram_stride = UInt(coreMaxAddrBits.W) val full_c = Bool() + //val partial_sum = Bool() // to move out partial sum val addr_start = UInt(log2Up(max_acc_addr).W) val loop_id = UInt(log2Up(concurrent_loops).W) } -class LoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, concurrent_loops: Int) +class LoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, max_block_len: Int, concurrent_loops: Int) (implicit p: Parameters) extends Module { val io = IO(new Bundle { val req = Flipped(Decoupled(new LoopMatmulStCReq(block_size, coreMaxAddrBits, iterator_bitwidth, max_acc_addr, concurrent_loops))) @@ -471,6 +480,8 @@ class LoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In val req = Reg(new LoopMatmulStCReq(block_size, coreMaxAddrBits, iterator_bitwidth, max_acc_addr, concurrent_loops)) + val max_blocks = Mux(req.full_c, 1.U, Mux(req.max_j <= max_block_len.U, req.max_j, max_block_len.U)) + val j = Reg(UInt(iterator_bitwidth.W)) val i = Reg(UInt(iterator_bitwidth.W)) @@ -479,7 +490,8 @@ class LoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In val dram_addr = Mux(req.full_c, req.dram_addr + (i * req.dram_stride + j) * block_size.U * (acc_w/8).U, req.dram_addr + (i * req.dram_stride + j) * block_size.U * (input_w/8).U) val sp_addr = acc_addr_start + (i * req.max_j + j) * block_size.U - val cols = block_size.U - Mux(j + 1.U >= req.max_j, req.pad_j, 0.U) + val blocks = Mux(j + max_blocks <= req.max_j, max_blocks, req.max_j-j) + val cols = (blocks * block_size.U) - Mux(j + blocks >= req.max_j, req.pad_j, 0.U) val rows = block_size.U - Mux(i === req.max_i-1.U, req.pad_i, 0.U) val mvout_cmd = Wire(new RoCCCommand) @@ -494,7 +506,11 @@ class LoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In io.idle := state === idle // The order here is k, j, i - val ex_ahead = io.ex_completed || (io.ex_k === req.max_k - 1.U && (io.ex_j > j || (io.ex_j === j && io.ex_i > i))) + // val ex_ahead = io.ex_completed || (io.ex_k === req.max_k - 1.U && (io.ex_j > j || (io.ex_j === j && io.ex_i > i))) + val ex_ahead = io.ex_completed || + (io.ex_k === req.max_k - 1.U && + (io.ex_j >= j + blocks || + ((io.ex_j === j + blocks - 1.U) && io.ex_i > i))) io.cmd.valid := state =/= idle && !io.rob_overloaded && ex_ahead && req.dram_addr =/= 0.U io.cmd.bits := mvout_cmd @@ -506,7 +522,7 @@ class LoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In }.elsewhen (io.cmd.fire()) { // The order here is k, j, i val next_i = floorAdd(i, 1.U, req.max_i) - val next_j = floorAdd(j, 1.U, req.max_j, next_i === 0.U) + val next_j = floorAdd(j, max_blocks, req.max_j, next_i === 0.U) i := next_i j := next_j @@ -548,6 +564,7 @@ class LoopMatmulState(val iterator_bitwidth: Int, val coreMaxAddrBits: Int, val val b_transpose = Bool() val low_d = Bool() + //val partial_sum = Bool() //to moveout partial sum val full_c = Bool() val ex_accumulate = Bool() @@ -595,22 +612,22 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: max_addr: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, dma_max_bytes: Int) (implicit p: Parameters) extends Module { val iterator_bitwidth = 16 - val max_block_len = (dma_max_bytes / (block_size * input_w * 8)) max 1 - val max_block_len_acc = (dma_max_bytes / (block_size * acc_w * 8)) max 1 + val max_block_len = (dma_max_bytes / (block_size * input_w / 8)) max 1 + val max_block_len_acc = (dma_max_bytes / (block_size * acc_w / 8)) max 1 val io = IO(new Bundle { val in = Flipped(Decoupled(new RoCCCommand)) val out = Decoupled(new RoCCCommand) - val ld_utilization = Input(UInt(log2Up(rob_size).W)) - val st_utilization = Input(UInt(log2Up(rob_size).W)) - val ex_utilization = Input(UInt(log2Up(rob_size).W)) + val ld_utilization = Input(UInt(log2Up(rob_size+1).W)) + val st_utilization = Input(UInt(log2Up(rob_size+1).W)) + val ex_utilization = Input(UInt(log2Up(rob_size+1).W)) val busy = Output(Bool()) }) // Create states val concurrent_loops = 2 val loops = Reg(Vec(concurrent_loops, new LoopMatmulState(iterator_bitwidth, coreMaxAddrBits, max_addr, max_acc_addr))) - val head_loop_id = Reg(UInt(log2Up(concurrent_loops).W)) + val head_loop_id = RegInit(0.U(log2Up(concurrent_loops).W)) val tail_loop_id = (~head_loop_id).asUInt() // This is the loop that we always try to configure if available val head_loop = loops(head_loop_id) val tail_loop = loops(tail_loop_id) @@ -625,7 +642,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: val ldB = Module(new LoopMatmulLdB(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, input_w, max_block_len, concurrent_loops)) val ldD = Module(new LoopMatmulLdD(block_size, coreMaxAddrBits, iterator_bitwidth, max_acc_addr, input_w, acc_w, max_block_len, max_block_len_acc, concurrent_loops)) val ex = Module(new LoopMatmulExecute(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, max_acc_addr, concurrent_loops)) - val stC = Module(new LoopMatmulStC(block_size, coreMaxAddrBits, iterator_bitwidth, max_acc_addr, input_w, acc_w, concurrent_loops)) + val stC = Module(new LoopMatmulStC(block_size, coreMaxAddrBits, iterator_bitwidth, max_acc_addr, input_w, acc_w, max_block_len, concurrent_loops)) // Create command queue val cmd = Queue(io.in) @@ -654,7 +671,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: val is_loop_cmd = is_loop_run_cmd || is_loop_config_cmd io.out.bits := Mux(loop_configured, unrolled_cmd.bits, cmd.bits) - io.out.bits.status := cmd.bits.status + io.out.bits.status := cmd.bits.status // TODO This is not guaranteed to be the correct fix! We must fix this io.out.valid := Mux(loop_configured, unrolled_cmd.valid, cmd.valid && !is_loop_config_cmd && !is_loop_run_cmd) cmd.ready := Mux(is_loop_cmd, !loop_being_configured.configured, !loop_configured && io.out.ready) @@ -681,6 +698,9 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: stC.io.ex_j := ex.io.j stC.io.ex_i := ex.io.i + val loops_configured = RegInit(0.U(16.W)) + dontTouch(loops_configured) + // Create config registers when(cmd.valid && is_loop_cmd && !loop_being_configured.configured) { @@ -719,10 +739,13 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: loop_being_configured.ex_accumulate := cmd.bits.rs1(0) loop_being_configured.full_c := cmd.bits.rs1(1) loop_being_configured.low_d := cmd.bits.rs1(2) + //loop_being_configured.partial_sum := cmd.bits.rs1(3) loop_being_configured.a_transpose := cmd.bits.rs2(0) loop_being_configured.b_transpose := cmd.bits.rs2(1) loop_being_configured.configured := true.B + + loops_configured := loops_configured + 1.U } } } @@ -831,6 +854,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: stC.io.req.bits.dram_addr := loop_requesting_st.c_dram_addr stC.io.req.bits.dram_stride := loop_requesting_st.c_dram_stride stC.io.req.bits.full_c := loop_requesting_st.full_c + //stC.io.req.bits.partial_sum := loop_requesting_st.partial_sum stC.io.req.bits.addr_start := st_c_addr_start stC.io.req.bits.loop_id := loop_requesting_st_id diff --git a/src/main/scala/gemmini/Mesh.scala b/src/main/scala/gemmini/Mesh.scala index 5f50c992..074ed445 100644 --- a/src/main/scala/gemmini/Mesh.scala +++ b/src/main/scala/gemmini/Mesh.scala @@ -15,21 +15,26 @@ import chisel3.experimental._ * @param meshColumns */ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, - df: Dataflow.Value, pe_latency: Int, + df: Dataflow.Value, pe_latency: Int, max_simultaneous_matmuls: Int, val tileRows: Int, val tileColumns: Int, val meshRows: Int, val meshColumns: Int) extends Module { val io = IO(new Bundle { - val in_a = Input(Vec(meshRows, Vec(tileRows, inputType))) - val in_b = Input(Vec(meshColumns, Vec(tileColumns, inputType))) - val in_d = Input(Vec(meshColumns, Vec(tileColumns, inputType))) - val in_control = Input(Vec(meshColumns, Vec(tileColumns, new PEControl(accType)))) - val out_b = Output(Vec(meshColumns, Vec(tileColumns, outputType))) - val out_c = Output(Vec(meshColumns, Vec(tileColumns, outputType))) + val in_a = Input(Vec(meshRows, Vec(tileRows, inputType))) + val in_b = Input(Vec(meshColumns, Vec(tileColumns, inputType))) + val in_d = Input(Vec(meshColumns, Vec(tileColumns, inputType))) + val in_control = Input(Vec(meshColumns, Vec(tileColumns, new PEControl(accType)))) + val in_id = Input(Vec(meshColumns, Vec(tileColumns, UInt(log2Up(max_simultaneous_matmuls).W)))) // The unique id of this particular matmul + val in_last = Input(Vec(meshColumns, Vec(tileColumns, Bool()))) + val out_b = Output(Vec(meshColumns, Vec(tileColumns, outputType))) + val out_c = Output(Vec(meshColumns, Vec(tileColumns, outputType))) val in_valid = Input(Vec(meshColumns, Vec(tileColumns, Bool()))) val out_valid = Output(Vec(meshColumns, Vec(tileColumns, Bool()))) + val out_control = Output(Vec(meshColumns, Vec(tileColumns, new PEControl(accType)))) + val out_id = Output(Vec(meshColumns, Vec(tileColumns, UInt(log2Up(max_simultaneous_matmuls).W)))) + val out_last = Output(Vec(meshColumns, Vec(tileColumns, Bool()))) }) // mesh(r)(c) => Tile at row r, column c - val mesh: Seq[Seq[Tile[T]]] = Seq.fill(meshRows, meshColumns)(Module(new Tile(inputType, outputType, accType, df, pe_latency, tileRows, tileColumns))) + val mesh: Seq[Seq[Tile[T]]] = Seq.fill(meshRows, meshColumns)(Module(new Tile(inputType, outputType, accType, df, pe_latency, max_simultaneous_matmuls, tileRows, tileColumns))) val meshT = mesh.transpose // Chain tile_a_out -> tile_a_in (pipeline a across each row) // TODO clock-gate A signals with in_garbage @@ -57,6 +62,7 @@ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, } } // Chain control signals (pipeline across each column) + assert(!(mesh.map(_.map(_.io.bad_dataflow).reduce(_||_)).reduce(_||_))) for (c <- 0 until meshColumns) { meshT(c).foldLeft((io.in_control(c), io.in_valid(c))) { case ((in_ctrl, valid), tile) => @@ -68,6 +74,7 @@ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, (tile.io.out_control, tile.io.out_valid) } } + // Chain in_valid (pipeline across each column) for (c <- 0 until meshColumns) { meshT(c).foldLeft(io.in_valid(c)) { @@ -76,13 +83,35 @@ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, tile.io.out_valid } } + + // Chain in_id (pipeline across each column) + for (c <- 0 until meshColumns) { + meshT(c).foldLeft(io.in_id(c)) { + case (in_id, tile) => + tile.io.in_id := RegNext(in_id) + tile.io.out_id + } + } + + // Chain in_last (pipeline across each column) + for (c <- 0 until meshColumns) { + meshT(c).foldLeft(io.in_last(c)) { + case (in_last, tile) => + tile.io.in_last := RegNext(in_last) + tile.io.out_last + } + } + // Capture out_vec and out_control_vec (connect IO to bottom row of mesh) // (The only reason we have so many zips is because Scala doesn't provide a zipped function for Tuple4) - for (((b, c), (v, tile)) <- ((io.out_b zip io.out_c), (io.out_valid zip mesh.last)).zipped) { + for (((((((b, c), v), ctrl), id), last), tile) <- io.out_b zip io.out_c zip io.out_valid zip io.out_control zip io.out_id zip io.out_last zip mesh.last) { // TODO we pipelined this to make physical design easier. Consider removing these if possible // TODO shouldn't we clock-gate these signals with "garbage" as well? b := RegNext(tile.io.out_b) c := RegNext(tile.io.out_c) v := RegNext(tile.io.out_valid) + ctrl := RegNext(tile.io.out_control) + id := RegNext(tile.io.out_id) + last := RegNext(tile.io.out_last) } } diff --git a/src/main/scala/gemmini/MeshWithDelays.scala b/src/main/scala/gemmini/MeshWithDelays.scala index d400c677..ec094a21 100644 --- a/src/main/scala/gemmini/MeshWithDelays.scala +++ b/src/main/scala/gemmini/MeshWithDelays.scala @@ -6,6 +6,26 @@ import chisel3.util._ import gemmini.Util._ +class MeshWithDelaysReq[T <: Data: Arithmetic, TagT <: TagQueueTag with Data](accType: T, tagType: TagT, block_size: Int) extends Bundle { + val pe_control = new PEControl(accType) + val a_transpose = Bool() + val bd_transpose = Bool() + val total_rows = UInt(log2Up(block_size+1).W) + val tag = tagType + val flush = UInt(2.W) // TODO magic number + + override def cloneType: MeshWithDelaysReq.this.type = new MeshWithDelaysReq(accType, tagType, block_size).asInstanceOf[this.type] +} + +class MeshWithDelaysResp[T <: Data: Arithmetic, TagT <: TagQueueTag with Data](outType: T, meshCols: Int, tileCols: Int, block_size: Int, tagType: TagT) extends Bundle { + val data = Vec(meshCols, Vec(tileCols, outType)) + val total_rows = UInt(log2Up(block_size+1).W) + val tag = tagType + val last = Bool() + + override def cloneType: MeshWithDelaysResp.this.type = new MeshWithDelaysResp(outType, meshCols, tileCols, block_size, tagType).asInstanceOf[this.type] +} + // TODO Add io.out.ready back in. Before it was removed, it didn't work when banking, and it seemed to assume that SRAM outputs stay steady when ren is low // TODO Handle matrices where N1 =/= N2 =/= N3 // TODO do we flush for one cycle more than necessary? @@ -15,7 +35,7 @@ class MeshWithDelays[T <: Data: Arithmetic, U <: TagQueueTag with Data] (inputType: T, val outputType: T, accType: T, tagType: U, df: Dataflow.Value, pe_latency: Int, tileRows: Int, tileColumns: Int, meshRows: Int, meshColumns: Int, - leftBanks: Int, upBanks: Int, outBanks: Int = 1) + leftBanks: Int, upBanks: Int, outBanks: Int = 1, n_simultaneous_matmuls: Int = -1) extends Module { val A_TYPE = Vec(meshRows, Vec(tileRows, inputType)) @@ -24,26 +44,28 @@ class MeshWithDelays[T <: Data: Arithmetic, U <: TagQueueTag with Data] val D_TYPE = Vec(meshColumns, Vec(tileColumns, inputType)) val S_TYPE = Vec(meshColumns, Vec(tileColumns, new PEControl(accType))) - val tagqlen = (if (meshColumns == 1) 4 else 5) * (pe_latency+1) // TODO change the tag-queue so we can make this 3 + assert(meshRows*tileRows == meshColumns*tileColumns) + val block_size = meshRows*tileRows + + val max_simultaneous_matmuls = if (n_simultaneous_matmuls == -1) { + 5 * (pe_latency + 1) + } else { + n_simultaneous_matmuls + } + assert(max_simultaneous_matmuls >= 5 * (pe_latency + 1)) + + val tagqlen = max_simultaneous_matmuls+1 val io = IO(new Bundle { val a = Flipped(Decoupled(A_TYPE)) val b = Flipped(Decoupled(B_TYPE)) val d = Flipped(Decoupled(D_TYPE)) - // TODO make pe_control a ready-valid interface as well - val pe_control = Input(new PEControl(accType)) + val req = Flipped(Decoupled(new MeshWithDelaysReq(accType, tagType.cloneType, block_size))) - val a_transpose = Input(Bool()) - val bd_transpose = Input(Bool()) + val resp = Valid(new MeshWithDelaysResp(outputType, meshColumns, tileColumns, block_size, tagType.cloneType)) - val tag_in = Flipped(Decoupled(tagType)) - val tag_out = Output(tagType) val tags_in_progress = Output(Vec(tagqlen, tagType)) - - val out = Valid(C_TYPE) // TODO make this ready-valid - - val flush = Flipped(Decoupled(UInt(2.W))) }) def shifted[T <: Data](x: Vec[Vec[T]], banks: Int, reverse: Boolean = false) = { @@ -70,33 +92,43 @@ class MeshWithDelays[T <: Data: Arithmetic, U <: TagQueueTag with Data] } } - assert(meshRows*tileRows == meshColumns*tileColumns) - val block_size = meshRows*tileRows + val req = Reg(UDValid(new MeshWithDelaysReq(accType, tagType, block_size))) - val active = RegInit(0.U(1.W)) // Which buffer is currently being read from? - val not_active = (~active).asUInt() + val matmul_id = RegInit(0.U(log2Up(max_simultaneous_matmuls).W)) - val flushing = RegInit(false.B) - val flushing_or_about_to = flushing || io.flush.fire() - - val fire_counter = RegInit(0.U((log2Ceil(block_size) max 1).W)) - val fire_started = RegInit(false.B) + val total_fires = req.bits.total_rows + val fire_counter = RegInit(0.U(log2Up(block_size).W)) val a_buf = RegEnable(io.a.bits, io.a.fire()) val b_buf = RegEnable(io.b.bits, io.b.fire()) val d_buf = RegEnable(io.d.bits, io.d.fire()) - val in_prop_reg = Reg(UInt(1.W)) // TODO inelegant - val in_prop = WireInit(in_prop_reg) - val a_written = RegInit(false.B) val b_written = RegInit(false.B) val d_written = RegInit(false.B) - val tag_written = RegInit(false.B) + val in_prop = Reg(UInt(1.W)) // TODO inelegant - val buffering_done = fire_counter === 0.U && fire_started && tag_written - val waiting_on_non_matrix_inputs = fire_counter === 0.U && !(tag_written || io.tag_in.fire()) // TODO change when more non-matrix inputs are buffered + val input_next_row_into_spatial_array = req.valid && ((a_written && b_written && d_written) || req.bits.flush > 0.U) + + val last_fire = fire_counter === total_fires - 1.U && input_next_row_into_spatial_array + + when (io.req.fire()) { + req.push(io.req.bits) + in_prop := io.req.bits.pe_control.propagate ^ in_prop + matmul_id := wrappingAdd(matmul_id, 1.U, max_simultaneous_matmuls) + }.elsewhen (last_fire) { + req.valid := req.bits.flush > 1.U + req.bits.flush := req.bits.flush - 1.U + } + + when (input_next_row_into_spatial_array) { + a_written := false.B + b_written := false.B + d_written := false.B + + fire_counter := wrappingAdd(fire_counter, 1.U, total_fires) + } when (io.a.fire()) { a_written := true.B @@ -110,66 +142,47 @@ class MeshWithDelays[T <: Data: Arithmetic, U <: TagQueueTag with Data] d_written := true.B } - val next_row_input = (io.a.fire() || a_written) && (io.b.fire() || b_written) && (io.d.fire() || d_written) - - when (next_row_input || flushing_or_about_to) { - a_written := false.B - b_written := false.B - d_written := false.B - - fire_counter := wrappingAdd(fire_counter, 1.U, block_size) - fire_started := true.B // We only need to write to this here, rather than in a "when (buffering_done)" statement - } + io.a.ready := !a_written || input_next_row_into_spatial_array || io.req.ready + io.b.ready := !b_written || input_next_row_into_spatial_array || io.req.ready + io.d.ready := !d_written || input_next_row_into_spatial_array || io.req.ready - io.a.ready := !a_written - io.b.ready := !b_written - io.d.ready := !d_written + assert(req.valid || !input_next_row_into_spatial_array) - val pause = (waiting_on_non_matrix_inputs || !next_row_input) && !flushing_or_about_to + val pause = !req.valid || !input_next_row_into_spatial_array // Transposer - val a_is_from_transposer = Mux(io.pe_control.dataflow === Dataflow.OS.id.U, !io.a_transpose, io.a_transpose) - val b_is_from_transposer = io.pe_control.dataflow === Dataflow.OS.id.U && io.bd_transpose - val d_is_from_transposer = io.pe_control.dataflow === Dataflow.WS.id.U && io.bd_transpose + val a_is_from_transposer = Mux(req.bits.pe_control.dataflow === Dataflow.OS.id.U, !req.bits.a_transpose, req.bits.a_transpose) + val b_is_from_transposer = req.bits.pe_control.dataflow === Dataflow.OS.id.U && req.bits.bd_transpose + val d_is_from_transposer = req.bits.pe_control.dataflow === Dataflow.WS.id.U && req.bits.bd_transpose val transposer = Module(new AlwaysOutTransposer(block_size, inputType)) transposer.io.inRow.valid := !pause && (a_is_from_transposer || b_is_from_transposer || d_is_from_transposer) - // transposer.io.inRow.bits := VecInit( - // Mux(a_is_from_transposer, Mux(io.a.fire(), io.a.bits, a_buf), Mux(io.b.fire(), io.b.bits, b_buf)).flatten) - transposer.io.inRow.bits := MuxCase(VecInit(Mux(io.a.fire(), io.a.bits, a_buf).flatten), Seq( - b_is_from_transposer -> VecInit(Mux(io.b.fire(), io.b.bits, b_buf).flatten), - d_is_from_transposer -> VecInit(Mux(io.d.fire(), io.d.bits, d_buf).flatten.reverse) + transposer.io.inRow.bits := MuxCase(VecInit(a_buf.flatten), Seq( + b_is_from_transposer -> VecInit(b_buf.flatten), + d_is_from_transposer -> VecInit(d_buf.flatten.reverse), )) transposer.io.outCol.ready := true.B val transposer_out = VecInit(transposer.io.outCol.bits.grouped(tileRows).map(t => VecInit(t)).toSeq) // Wire up mesh's IO to this module's IO - val mesh = Module(new Mesh(inputType, outputType, accType, df, pe_latency, tileRows, tileColumns, meshRows, meshColumns)) + val mesh = Module(new Mesh(inputType, outputType, accType, df, pe_latency, max_simultaneous_matmuls, tileRows, tileColumns, meshRows, meshColumns)) // TODO wire only to *_buf here, instead of io.*.bits - - /*val a_shifter_in = WireInit(Mux(io.pe_control.dataflow === Dataflow.OS.id.U, - a_transposed, Mux(io.a.fire(), io.a.bits, a_buf)))*/ - val a_shifter_in = WireInit(Mux(a_is_from_transposer, - transposer_out, Mux(io.a.fire(), io.a.bits, a_buf))) - // val b_shifter_in = WireInit(Mux(io.b.fire(), io.b.bits, b_buf)) - val b_shifter_in = WireInit(Mux(b_is_from_transposer, - transposer_out, Mux(io.b.fire(), io.b.bits, b_buf))) - // val d_shifter_in = Mux(io.d.fire(), io.d.bits, d_buf) + val a_shifter_in = WireInit(Mux(a_is_from_transposer, transposer_out, a_buf)) + val b_shifter_in = WireInit(Mux(b_is_from_transposer, transposer_out, b_buf)) val d_shifter_in = WireInit(Mux(d_is_from_transposer, - VecInit(transposer_out.flatten.reverse.grouped(tileRows).map(VecInit(_)).toSeq), - Mux(io.d.fire(), io.d.bits, d_buf))) + VecInit(transposer_out.flatten.reverse.grouped(tileRows).map(VecInit(_)).toSeq), d_buf)) mesh.io.in_a := shifted(a_shifter_in, leftBanks) mesh.io.in_b := shifted(b_shifter_in, upBanks) mesh.io.in_d := shifted(d_shifter_in, upBanks) mesh.io.in_control.zipWithIndex.foreach { case (ss, i) => - ss.foreach(_.dataflow := ShiftRegister(io.pe_control.dataflow, i * (pe_latency + 1))) + ss.foreach(_.dataflow := ShiftRegister(req.bits.pe_control.dataflow, i * (pe_latency + 1))) ss.foreach(_.propagate := ShiftRegister(in_prop, i * (pe_latency + 1))) } - val result_shift = RegNext(io.pe_control.shift) // TODO will this arrive at the right time if memory isn't pipelined? + val result_shift = RegNext(req.bits.pe_control.shift) // TODO will this arrive at the right time if memory isn't pipelined? mesh.io.in_control.zipWithIndex.foreach { case (ctrl, i) => ctrl.foreach(_.shift := ShiftRegister(result_shift, i * (pe_latency + 1))) } @@ -177,87 +190,73 @@ class MeshWithDelays[T <: Data: Arithmetic, U <: TagQueueTag with Data] val not_paused_vec = VecInit(Seq.fill(meshColumns)(VecInit(Seq.fill(tileColumns)(!pause)))) mesh.io.in_valid := shifted(not_paused_vec, upBanks) - // We want to output C when we're output-stationary, but B when we're weight-stationary - // TODO these would actually overlap when we switch from output-stationary to weight-stationary - // TODO should we use io.m, or the mode output of the mesh? - io.out.bits := shifted(Mux(io.pe_control.dataflow === Dataflow.OS.id.U, mesh.io.out_c, mesh.io.out_b), outBanks, true) - - io.out.valid := shifted(mesh.io.out_valid, outBanks, reverse = true)(0)(0) + val matmul_id_vec = VecInit(Seq.fill(meshColumns)(VecInit(Seq.fill(tileColumns)(matmul_id)))) + mesh.io.in_id := shifted(matmul_id_vec, upBanks) - // Tags - val tag_queue = Module(new TagQueue(tagqlen, tagType)) // TODO understand the actual required size better + val matmul_last_vec = VecInit(Seq.fill(meshColumns)(VecInit(Seq.fill(tileColumns)(last_fire)))) + mesh.io.in_last := shifted(matmul_last_vec, upBanks) - val tag_garbage = Wire(tagType.cloneType) - tag_garbage := DontCare - tag_garbage.make_this_garbage() + // We want to output C when we're output-stationary, but B when we're weight-stationary + // TODO these would actually overlap when we switch from output-stationary to weight-stationary + val out_pe_control = shifted(mesh.io.out_control, outBanks, reverse = true)(0)(0) + io.resp.bits.data := shifted(Mux(out_pe_control.dataflow === Dataflow.OS.id.U, mesh.io.out_c, mesh.io.out_b), outBanks, true) - tag_queue.io.in.bits := Mux(flushing, tag_garbage, io.tag_in.bits) + io.resp.valid := shifted(mesh.io.out_valid, outBanks, reverse = true)(0)(0) - val tag_id_reg = RegInit(0.U(1.W)) // Used to keep track of when we should increment // TODO inelegant - val tag_id = WireInit(tag_id_reg) - val tag_id_delayed = ShiftRegister(tag_id, (meshRows + S_TYPE.size - 1) * (pe_latency + 1) + 1, 0.U, true.B) + val out_last = shifted(mesh.io.out_last, outBanks, reverse = true)(0)(0) + io.resp.bits.last := out_last - tag_queue.io.out.next := tag_id_delayed =/= RegNext(tag_id_delayed, 0.U) + // Tags + class TagWithIdAndTotalRows extends Bundle with TagQueueTag { + val tag = tagType.cloneType + val id = UInt(log2Up(max_simultaneous_matmuls).W) + val total_rows = UInt(log2Up(block_size+1).W) + + override def make_this_garbage(dummy: Int=0): Unit = { + total_rows := block_size.U + tag.make_this_garbage() + } - when (io.tag_in.fire()) { - tag_written := true.B - tag_id := ~tag_id_reg - tag_id_reg := tag_id + override def cloneType: TagWithIdAndTotalRows.this.type = (new TagWithIdAndTotalRows).asInstanceOf[this.type] } - io.tag_in.ready := !tag_written - tag_queue.io.in.valid := io.tag_in.fire() - - io.tag_out := tag_queue.io.out.bits(Mux(io.pe_control.dataflow === Dataflow.OS.id.U, 0.U, 1.U)) - io.tags_in_progress := tag_queue.io.out.all - - // Flipping logic - when(buffering_done && (next_row_input || flushing_or_about_to)) { - active := not_active - io.tag_in.ready := true.B - tag_written := io.tag_in.fire() + val matmul_id_of_output = wrappingAdd(matmul_id, Mux(io.req.bits.pe_control.dataflow === Dataflow.OS.id.U, 3.U, 2.U), max_simultaneous_matmuls) + val matmul_id_of_current = wrappingAdd(matmul_id, 1.U, max_simultaneous_matmuls) - tag_id := ~tag_id_reg - tag_id_reg := tag_id + val tagq = Module(new TagQueue(new TagWithIdAndTotalRows, tagqlen)) + tagq.io.enq.valid := io.req.fire() && io.req.bits.flush === 0.U + tagq.io.enq.bits.tag := io.req.bits.tag + tagq.io.enq.bits.total_rows := DontCare + tagq.io.enq.bits.id := matmul_id_of_output - when (!flushing) { - in_prop := io.pe_control.propagate ^ in_prop_reg - in_prop_reg := in_prop - } - } + val tag_garbage = Wire(tagType.cloneType) + tag_garbage := DontCare + tag_garbage.make_this_garbage() - // Flushing logic - val flush_counter = Reg(UInt(2.W)) + val out_matmul_id = WireInit(shifted(mesh.io.out_id, outBanks, reverse = true)(0)(0)) + io.resp.bits.tag := Mux(tagq.io.deq.valid && out_matmul_id === tagq.io.deq.bits.id, tagq.io.deq.bits.tag, tag_garbage) - io.flush.ready := !flushing - // assert(!(io.flush.valid && !buffering_done)) // TODO get rid of this once we get the ability to ignore D + dontTouch(out_matmul_id) - when (io.flush.fire()) { - flushing := true.B - flush_counter := io.flush.bits + tagq.io.deq.ready := io.resp.valid && io.resp.bits.last && out_matmul_id === tagq.io.deq.bits.id - // Avoid overwriting accumulated values - a_buf := 0.U.asTypeOf(A_TYPE) // TODO make 0 an Arithmetic member function - b_buf := 0.U.asTypeOf(B_TYPE) - a_shifter_in := 0.U.asTypeOf(A_TYPE) - b_shifter_in := 0.U.asTypeOf(B_TYPE) - } + val total_rows_q = Module(new Queue(new TagWithIdAndTotalRows, tagqlen)) + total_rows_q.io.enq.valid := io.req.fire() && io.req.bits.flush === 0.U + total_rows_q.io.enq.bits.tag := DontCare + total_rows_q.io.enq.bits.total_rows := io.req.bits.total_rows + total_rows_q.io.enq.bits.id := matmul_id_of_current - when (flushing) { - Seq(io.a.ready, io.b.ready, io.d.ready, io.tag_in.ready).foreach(_ := false.B) + io.resp.bits.total_rows := Mux(total_rows_q.io.deq.valid && out_matmul_id === total_rows_q.io.deq.bits.id, + total_rows_q.io.deq.bits.total_rows, block_size.U) - tag_written := true.B + total_rows_q.io.deq.ready := io.resp.valid && io.resp.bits.last && out_matmul_id === total_rows_q.io.deq.bits.id - when (buffering_done) { - flush_counter := flush_counter - 1.U - tag_queue.io.in.valid := true.B - } + io.req.ready := (!req.valid || last_fire) && tagq.io.enq.ready && total_rows_q.io.enq.ready + io.tags_in_progress := tagq.io.all.map(_.tag) - val about_to_finish_flushing = flush_counter === 0.U && fire_counter === (block_size-1).U // TODO change when non-square requirement lifted - when (about_to_finish_flushing) { - fire_counter := 0.U - tag_queue.io.in.valid := true.B - flushing := false.B - } + when (reset.toBool()) { + req.valid := false.B } + + assert(!(io.req.fire() && !tagq.io.enq.ready && io.req.bits.flush === 0.U)) } diff --git a/src/main/scala/gemmini/PE.scala b/src/main/scala/gemmini/PE.scala index b912ad34..79944b72 100644 --- a/src/main/scala/gemmini/PE.scala +++ b/src/main/scala/gemmini/PE.scala @@ -17,7 +17,7 @@ class PEControl[T <: Data : Arithmetic](accType: T) extends Bundle { * A PE implementing a MAC operation. Configured as fully combinational when integrated into a Mesh. * @param width Data width of operands */ -class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value, latency: Int) +class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value, latency: Int, max_simultaneous_matmuls: Int) (implicit ev: Arithmetic[T]) extends Module { // Debugging variables import ev._ @@ -32,8 +32,16 @@ class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value, val in_control = Input(new PEControl(accType)) val out_control = Output(new PEControl(accType)) + val in_id = Input(UInt(log2Up(max_simultaneous_matmuls).W)) + val out_id = Output(UInt(log2Up(max_simultaneous_matmuls).W)) + + val in_last = Input(Bool()) + val out_last = Output(Bool()) + val in_valid = Input(Bool()) val out_valid = Output(Bool()) + + val bad_dataflow = Output(Bool()) }) val cType = if (df == Dataflow.WS) inputType else accType @@ -46,12 +54,16 @@ class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value, val dataflow = ShiftRegister(io.in_control.dataflow, latency) val prop = ShiftRegister(io.in_control.propagate, latency) val shift = ShiftRegister(io.in_control.shift, latency) + val id = ShiftRegister(io.in_id, latency) + val last = ShiftRegister(io.in_last, latency) val valid = ShiftRegister(io.in_valid, latency) // TODO should we clockgate the rest of the ShiftRegisters based on the values in this ShiftRegisters io.out_a := a io.out_control.dataflow := dataflow io.out_control.propagate := prop io.out_control.shift := shift + io.out_id := id + io.out_last := last io.out_valid := valid val last_s = RegEnable(prop, valid) @@ -66,6 +78,7 @@ class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value, val COMPUTE = 0.U(1.W) val PROPAGATE = 1.U(1.W) + io.bad_dataflow := false.B when ((df == Dataflow.OS).B || ((df == Dataflow.BOTH).B && dataflow === OUTPUT_STATIONARY)) { when(prop === PROPAGATE) { io.out_c := (c1 >> shift_offset).clippedToWidthOf(outputType) @@ -89,7 +102,8 @@ class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value, c2 := d } }.otherwise { - assert(false.B, "unknown dataflow") + io.bad_dataflow := true.B + //assert(false.B, "unknown dataflow") io.out_c := DontCare io.out_b := DontCare } diff --git a/src/main/scala/gemmini/ROB.scala b/src/main/scala/gemmini/ROB.scala index ccc6dbd2..f02f43c4 100644 --- a/src/main/scala/gemmini/ROB.scala +++ b/src/main/scala/gemmini/ROB.scala @@ -3,66 +3,84 @@ package gemmini import chisel3._ import chisel3.util._ - import freechips.rocketchip.tile.RoCCCommand - import GemminiISA._ import Util._ -//import midas.targetutils.FpgaDebug // TODO unify this class with GemminiCmdWithDeps -class ROBIssue[T <: Data](cmd_t: T, nEntries: Int) extends Bundle { +class ROBIssue[T <: Data](cmd_t: T, rob_entries: Int) extends Bundle { val valid = Output(Bool()) val ready = Input(Bool()) val cmd = Output(cmd_t.cloneType) - val rob_id = Output(UInt(log2Up(nEntries).W)) + val rob_id = Output(UInt(log2Up(rob_entries).W)) def fire(dummy: Int=0) = valid && ready - override def cloneType: this.type = new ROBIssue(cmd_t, nEntries).asInstanceOf[this.type] + override def cloneType: this.type = new ROBIssue(cmd_t, rob_entries).asInstanceOf[this.type] } // TODO we don't need to store the full command in here. We should be able to release the command directly into the relevant controller and only store the associated metadata in the ROB. This would reduce the size considerably -class ROB(cmd_t: RoCCCommand, nEntries: Int, local_addr_t: LocalAddr, block_rows: Int, block_cols: Int) extends Module { +class ROB[T <: Data : Arithmetic, U <: Data, V <: Data](config: GemminiArrayConfig[T, U, V], cmd_t: RoCCCommand) extends Module { + import config._ + + val block_rows = tileRows * meshRows + val block_cols = tileColumns * meshColumns + val io = IO(new Bundle { val alloc = Flipped(Decoupled(cmd_t.cloneType)) - val completed = Flipped(Valid(UInt(log2Up(nEntries).W))) + val completed = Flipped(Valid(UInt(log2Up(rob_entries).W))) val issue = new Bundle { - val ld = new ROBIssue(cmd_t, nEntries) - val st = new ROBIssue(cmd_t, nEntries) - val ex = new ROBIssue(cmd_t, nEntries) + val ld = new ROBIssue(cmd_t, rob_entries) + val st = new ROBIssue(cmd_t, rob_entries) + val ex = new ROBIssue(cmd_t, rob_entries) } - val ld_utilization = Output(UInt(log2Up(nEntries).W)) - val st_utilization = Output(UInt(log2Up(nEntries).W)) - val ex_utilization = Output(UInt(log2Up(nEntries).W)) + val ld_utilization = Output(UInt(log2Up(rob_entries+1).W)) + val st_utilization = Output(UInt(log2Up(rob_entries+1).W)) + val ex_utilization = Output(UInt(log2Up(rob_entries+1).W)) val busy = Output(Bool()) val solitary_preload = Input(Bool()) // TODO very hacky. from ExecuteController, to prevent infinite fence stalls. remove later }) + // TODO make this a ChiselEnum val ldq :: stq :: exq :: Nil = Enum(3) val q_t = ldq.cloneType + class OpT extends Bundle { + val start = local_addr_t.cloneType + val end = local_addr_t.cloneType + val wraps_around = Bool() + + def overlaps(other: OpT): Bool = { + ((other.start <= start && (start < other.end || other.wraps_around)) || + (start <= other.start && (other.start < end || wraps_around))) && + !(start.is_garbage() || other.start.is_garbage()) // TODO the "is_garbage" check might not really be necessary + } + } + + val instructions_allocated = RegInit(0.U(32.W)) + when (io.alloc.fire()) { + instructions_allocated := instructions_allocated + 1.U + } + dontTouch(instructions_allocated) + class Entry extends Bundle { val q = q_t.cloneType val is_config = Bool() - val op1 = UDValid(local_addr_t.cloneType) - val op2 = UDValid(local_addr_t.cloneType) - // val op3 = UDValid(local_addr_t.cloneType) - - val dst = UDValid(new Bundle { - val start = local_addr_t.cloneType - val len = UInt(8.W) // TODO magic number + val opa = UDValid(new OpT) + val opa_is_dst = Bool() + val opb = UDValid(new OpT) - def end(dummy: Int = 0): LocalAddr = start + len * block_rows.U - def wraps_around(dummy: Int = 0): Bool = start.add_with_overflow(len * block_rows.U)._2 - }) + // val op1 = UDValid(new OpT) + // val op1 = UDValid(new OpT) + // val op2 = UDValid(new OpT) + // val dst = UDValid(new OpT) val issued = Bool() @@ -70,31 +88,68 @@ class ROB(cmd_t: RoCCCommand, nEntries: Int, local_addr_t: LocalAddr, block_rows val cmd = cmd_t.cloneType - val deps = Vec(nEntries, Bool()) + val deps = Vec(rob_entries, Bool()) def ready(dummy: Int = 0): Bool = !deps.reduce(_ || _) + + // Debugging signals + val allocated_at = UInt(instructions_allocated.getWidth.W) } + val full_entries = Reg(Vec(rob_full_entries, UDValid(new Entry))) + val partial_entries = Reg(Vec(rob_partial_entries, UDValid(new Entry))) - val entries = Reg(Vec(nEntries, UDValid(new Entry))) + val entries = full_entries ++ partial_entries val empty = !entries.map(_.valid).reduce(_ || _) val full = entries.map(_.valid).reduce(_ && _) - // io.busy := !empty + // TODO we could also check for a solitary preload by recording the last instruction that was allocated, rather than + // reading all entries to check for preloads, which is an O(n) operation in terms of area cost val utilization = PopCount(entries.map(_.valid)) val solitary_preload = utilization === 1.U && entries.map(e => e.valid && e.bits.cmd.inst.funct === PRELOAD_CMD).reduce(_ || _) io.busy := !empty && !(solitary_preload && io.solitary_preload) - // Read in commands to the buffer - io.alloc.ready := !full - val last_allocated = Reg(UInt(log2Up(nEntries).W)) + // Config values set by programmer + val a_stride = Reg(UInt(16.W)) // TODO magic numbers // TODO we also need to check the transpose to see how many rows we're reading + val ld_block_strides = Reg(Vec(load_states, UInt(block_stride_bits.W))) + val st_block_stride = block_rows.U + val pooling_is_enabled = Reg(Bool()) val new_entry = Wire(new Entry) new_entry := DontCare - val new_entry_id = MuxCase((nEntries-1).U, entries.zipWithIndex.map { case (e, i) => !e.valid -> i.U }) + val new_full_allocs = Wire(Vec(rob_full_entries, Bool())) + new_full_allocs.foreach(_ := false.B) + val new_partial_allocs = Wire(Vec(rob_partial_entries, Bool())) + new_partial_allocs.foreach(_ := false.B) + val new_entry_oh = new_full_allocs ++ new_partial_allocs val alloc_fire = io.alloc.fire() - when (io.alloc.fire()) { + val raws_probe = WireInit(0.U(rob_entries.W)) + val waws_probe = WireInit(0.U(rob_entries.W)) + val wars_probe = WireInit(0.U(rob_entries.W)) + val older_in_same_q_probe = WireInit(0.U(rob_entries.W)) + val is_st_and_must_wait_for_prior_ex_config_probe = WireInit(0.U(rob_entries.W)) + val is_ex_config_and_must_wait_for_prior_st_probe = WireInit(0.U(rob_entries.W)) + + val wars_op1_probe = WireInit(0.U(rob_entries.W)) + val wars_op2_probe = WireInit(0.U(rob_entries.W)) + val raws_op1_probe = WireInit(0.U(rob_entries.W)) + val raws_op2_probe = WireInit(0.U(rob_entries.W)) + + dontTouch(raws_probe) + dontTouch(waws_probe) + dontTouch(wars_probe) + dontTouch(wars_op1_probe) + dontTouch(wars_op2_probe) + dontTouch(raws_op1_probe) + dontTouch(raws_op2_probe) + dontTouch(older_in_same_q_probe) + dontTouch(is_st_and_must_wait_for_prior_ex_config_probe) + dontTouch(is_ex_config_and_must_wait_for_prior_st_probe) + + dontTouch(new_entry) + io.alloc.ready := false.B + when (io.alloc.valid) { val spAddrBits = 32 val cmd = io.alloc.bits val funct = cmd.inst.funct @@ -106,23 +161,93 @@ class ROB(cmd_t: RoCCCommand, nEntries: Int, local_addr_t: LocalAddr, block_rows new_entry.is_config := funct === CONFIG_CMD - new_entry.op1.valid := funct === PRELOAD_CMD || funct_is_compute - new_entry.op1.bits := cmd.rs1.asTypeOf(local_addr_t) + val op1 = Wire(UDValid(new OpT)) + op1.valid := false.B + op1.bits := DontCare + val op2 = Wire(UDValid(new OpT)) + op2.valid := false.B + op2.bits := DontCare + val dst = Wire(UDValid(new OpT)) + dst.valid := false.B + dst.bits := DontCare + assert(!(op1.valid && op2.valid && dst.valid)) + + new_entry.opa_is_dst := dst.valid + when (dst.valid) { + new_entry.opa := dst + new_entry.opb := Mux(op1.valid, op1, op2) + } .otherwise { + new_entry.opa := Mux(op1.valid, op1, op2) + new_entry.opb := op2 + } - new_entry.op2.valid := funct_is_compute || funct === STORE_CMD - new_entry.op2.bits := cmd.rs2.asTypeOf(local_addr_t) + op1.valid := funct === PRELOAD_CMD || funct_is_compute + op1.bits.start := cmd.rs1.asTypeOf(local_addr_t) + when (funct === PRELOAD_CMD) { + val preload_rows = cmd.rs1(48 + log2Up(block_rows + 1) - 1, 48) + op1.bits.end := op1.bits.start + preload_rows + op1.bits.wraps_around := op1.bits.start.add_with_overflow(preload_rows)._2 + }.otherwise { + val compute_rows = cmd.rs1(48 + log2Up(block_rows + 1) - 1, 48) * a_stride + op1.bits.end := op1.bits.start + compute_rows + op1.bits.wraps_around := op1.bits.start.add_with_overflow(compute_rows)._2 + } - // new_entry.op3.valid := funct_is_compute - // new_entry.op3.bits := cmd.rs1(63, 32).asTypeOf(local_addr_t) + op2.valid := funct_is_compute || funct === STORE_CMD + op2.bits.start := cmd.rs2.asTypeOf(local_addr_t) + when (funct_is_compute) { + val compute_rows = cmd.rs2(48 + log2Up(block_rows + 1) - 1, 48) + op2.bits.end := op2.bits.start + compute_rows + op2.bits.wraps_around := op2.bits.start.add_with_overflow(compute_rows)._2 + }.elsewhen (pooling_is_enabled) { + // If pooling is enabled, then we assume that this command simply mvouts everything in this accumulator bank from + // start to the end of the bank + val acc_bank = op2.bits.start.acc_bank() + + val next_bank_addr = WireInit(0.U.asTypeOf(local_addr_t)) + next_bank_addr.is_acc_addr := true.B + next_bank_addr.data := (acc_bank + 1.U) << local_addr_t.accBankRowBits + + op2.bits.end := next_bank_addr + op2.bits.wraps_around := next_bank_addr.acc_bank() === 0.U + }.otherwise { + val block_stride = st_block_stride + + val mvout_cols = cmd.rs2(32 + mvout_cols_bits - 1, 32) + val mvout_rows = cmd.rs2(48 + mvout_rows_bits - 1, 48) + + val mvout_mats = mvout_cols / block_cols.U + (mvout_cols % block_cols.U =/= 0.U) + val total_mvout_rows = ((mvout_mats - 1.U) * block_stride) + mvout_rows + + op2.bits.end := op2.bits.start + total_mvout_rows + op2.bits.wraps_around := pooling_is_enabled || op2.bits.start.add_with_overflow(total_mvout_rows)._2 + } - val mvin_mvout_len = cmd.rs2(48, spAddrBits) - new_entry.dst.valid := funct === PRELOAD_CMD || funct === LOAD_CMD || funct === LOAD2_CMD || funct === LOAD3_CMD - new_entry.dst.bits.start := cmd.rs2(31, 0).asTypeOf(local_addr_t) - new_entry.dst.bits.len := Mux(funct === PRELOAD_CMD, 1.U, mvin_mvout_len / block_cols.U + (mvin_mvout_len % block_cols.U =/= 0.U)) + dst.valid := funct === PRELOAD_CMD || funct === LOAD_CMD || funct === LOAD2_CMD || funct === LOAD3_CMD + dst.bits.start := cmd.rs2(31, 0).asTypeOf(local_addr_t) + when (funct === PRELOAD_CMD) { + val preload_rows = cmd.rs2(48 + log2Up(block_rows + 1) - 1, 48) + dst.bits.end := dst.bits.start + preload_rows + dst.bits.wraps_around := dst.bits.start.add_with_overflow(preload_rows)._2 + }.otherwise { + val id = MuxCase(0.U, Seq((new_entry.cmd.inst.funct === LOAD2_CMD) -> 1.U, + (new_entry.cmd.inst.funct === LOAD3_CMD) -> 2.U)) + val block_stride = ld_block_strides(id) + + val mvin_cols = cmd.rs2(32 + mvin_cols_bits - 1, 32) + val mvin_rows = cmd.rs2(48 + mvin_rows_bits - 1, 48) + + val mvin_mats = mvin_cols / block_cols.U + (mvin_cols % block_cols.U =/= 0.U) + val total_mvin_rows = ((mvin_mats - 1.U) * block_stride) + mvin_rows + + dst.bits.end := dst.bits.start + total_mvin_rows + dst.bits.wraps_around := dst.bits.start.add_with_overflow(total_mvin_rows)._2 + } val is_load = funct === LOAD_CMD || funct === LOAD2_CMD || funct === LOAD3_CMD || (funct === CONFIG_CMD && config_cmd_type === CONFIG_LOAD) val is_store = funct === STORE_CMD || (funct === CONFIG_CMD && config_cmd_type === CONFIG_STORE) val is_ex = funct === PRELOAD_CMD || funct_is_compute || (funct === CONFIG_CMD && (config_cmd_type === CONFIG_EX || config_cmd_type === CONFIG_IM2COL)) + val is_im2col = funct === CONFIG_CMD && config_cmd_type === CONFIG_IM2COL // im2col commands are a subset of ex commands, so they still go in the ex queue new_entry.q := Mux1H(Seq( is_load -> ldq, @@ -130,80 +255,134 @@ class ROB(cmd_t: RoCCCommand, nEntries: Int, local_addr_t: LocalAddr, block_rows is_ex -> exq )) - val raws = entries.map { e => - // We search for all entries which write to an address which we read from - e.valid && e.bits.dst.valid && e.bits.q =/= new_entry.q && ( - (new_entry.op1.valid && e.bits.dst.bits.start <= new_entry.op1.bits && (e.bits.dst.bits.end() > new_entry.op1.bits || e.bits.dst.bits.wraps_around())) || - (new_entry.op2.valid && e.bits.dst.bits.start <= new_entry.op2.bits && (e.bits.dst.bits.end() > new_entry.op2.bits || e.bits.dst.bits.wraps_around()))) /* || - (new_entry.op3.valid && e.bits.dst.bits.start <= new_entry.op3.bits && e.bits.dst.bits.end() > new_entry.op3.bits)) */ - } + assert(is_load || is_store || is_ex) + // This can be RAW op1/op2 <- dst, or WAW dst <- dst + val opa_matches_opa = VecInit(entries.map { e => e.valid && e.bits.opa.valid && new_entry.opa.bits.overlaps(e.bits.opa.bits) }) + // This can be WAR dst <- op1/op2 + val opa_matches_opb = VecInit(entries.map { e => e.valid && e.bits.opb.valid && new_entry.opa.bits.overlaps(e.bits.opb.bits) }) + // This can be RAW op2 <- dst + val opb_matches_opa = VecInit(entries.map { e => e.valid && e.bits.opa.valid && new_entry.opb.bits.overlaps(e.bits.opa.bits) }) - val wars = entries.map { e => - // We search for all entries which read from an address that we write to - e.valid && new_entry.dst.valid && e.bits.q =/= new_entry.q && ( - (e.bits.op1.valid && new_entry.dst.bits.start <= e.bits.op1.bits && (new_entry.dst.bits.end() > e.bits.op1.bits || new_entry.dst.bits.wraps_around())) || - (e.bits.op2.valid && new_entry.dst.bits.start <= e.bits.op2.bits && (new_entry.dst.bits.end() > e.bits.op2.bits || new_entry.dst.bits.wraps_around()))) /* || - (e.bits.op3.valid && new_entry.dst.bits.start <= e.bits.op3.bits && new_entry.dst.bits.end() > e.bits.op3.bits)) */ - } + val op1_matches_opa = VecInit((entries zip (opa_matches_opa zip opb_matches_opa)).map { case (e, (a, b)) => + e.valid && op1.valid && Mux(dst.valid, b, a) + }) + val op2_matches_opa = VecInit((entries zip (opa_matches_opa zip opb_matches_opa)).map { case (e, (a, b)) => + e.valid && op2.valid && Mux(dst.valid || op1.valid, b, a) + }) + val dst_matches_opa = VecInit((entries zip opa_matches_opa).map { case (e, a) => + e.valid && dst.valid && a + }) + val dst_matches_opb = VecInit((entries zip opa_matches_opb).map { case (e, b) => + e.valid && dst.valid && b + }) + + val op1_raws_opa = VecInit((entries zip op1_matches_opa).map { case (e, m) => + m && op1.valid && e.bits.q =/= new_entry.q && e.bits.opa_is_dst + }) + val op2_raws_opa = VecInit((entries zip op2_matches_opa).map { case (e, m) => + m && op2.valid && e.bits.q =/= new_entry.q && e.bits.opa_is_dst + }) + val raws = VecInit((op1_raws_opa zip op2_raws_opa).map { case (a, b) => a || b }) - val waws = entries.map { e => - def is_accumulative(laddr: LocalAddr): Bool = laddr.is_acc_addr && laddr.accumulate + val dst_wars_opa = VecInit((entries zip dst_matches_opa).map { case (e, m) => + m && dst.valid && e.bits.q =/= new_entry.q && !e.bits.opa_is_dst + }) + val dst_wars_opb = VecInit((entries zip dst_matches_opb).map { case (e, m) => + m && dst.valid && e.bits.q =/= new_entry.q + }) + val wars = VecInit((dst_wars_opa zip dst_wars_opb).map { case (a, b) => a || b }) - // We search for all entries which write to an address that we write to - e.valid && new_entry.dst.valid && e.bits.dst.valid && e.bits.q =/= new_entry.q && - !(is_accumulative(new_entry.dst.bits.start) && is_accumulative(e.bits.dst.bits.start)) && - ((new_entry.dst.bits.start <= e.bits.dst.bits.start && (new_entry.dst.bits.end() > e.bits.dst.bits.start || new_entry.dst.bits.wraps_around())) || - (e.bits.dst.bits.start <= new_entry.dst.bits.start && (e.bits.dst.bits.end() > new_entry.dst.bits.start || e.bits.dst.bits.wraps_around()))) - } + val dst_waws_opa = VecInit((entries zip dst_matches_opa).map { case (e, m) => + m && dst.valid && (e.bits.q =/= new_entry.q || new_entry.q === ldq) && e.bits.opa_is_dst + }) + val waws = dst_waws_opa - val older_in_same_q = entries.map { e => + val older_in_same_q = VecInit(entries.map { e => e.valid && e.bits.q === new_entry.q && !e.bits.issued - } + }) - val is_st_and_must_wait_for_prior_ex_config = entries.map { e => + val is_st_and_must_wait_for_prior_ex_config = VecInit(entries.map { e => e.valid && new_entry.q === stq && !new_entry.is_config && e.bits.q === exq && e.bits.is_config - } + }) - val is_ex_config_and_must_wait_for_prior_st = entries.map { e => + val is_ex_config_and_must_wait_for_prior_st = VecInit(entries.map { e => e.valid && new_entry.q === exq && new_entry.is_config && e.bits.q === stq && !e.bits.is_config - } + }) new_entry.deps := (Cat(raws) | Cat(wars) | Cat(waws) | Cat(older_in_same_q) | Cat(is_st_and_must_wait_for_prior_ex_config) | Cat(is_ex_config_and_must_wait_for_prior_st)).asBools().reverse + raws_probe := Cat(raws.reverse) + waws_probe := Cat(waws.reverse) + wars_probe := Cat(wars.reverse) + older_in_same_q_probe := Cat(older_in_same_q.reverse) + is_st_and_must_wait_for_prior_ex_config_probe := Cat(is_st_and_must_wait_for_prior_ex_config.reverse) + is_ex_config_and_must_wait_for_prior_st_probe := Cat(is_ex_config_and_must_wait_for_prior_st.reverse) + + new_entry.allocated_at := instructions_allocated + new_entry.complete_on_issue := new_entry.is_config && new_entry.q =/= exq - entries(new_entry_id).valid := true.B - entries(new_entry_id).bits := new_entry + val is_full = PopCount(Seq(dst.valid, op1.valid, op2.valid)) > 1.U + val full_alloc_id = MuxCase((rob_full_entries-1).U, full_entries.zipWithIndex.map { case (e, i) => !e.valid -> i.U }) + val partial_alloc_id = MuxCase((rob_partial_entries-1).U, partial_entries.zipWithIndex.map { case (e, i) => !e.valid -> i.U }) + + when (!is_full && !partial_entries(partial_alloc_id).valid) { + io.alloc.ready := true.B + partial_entries(partial_alloc_id).valid := true.B + partial_entries(partial_alloc_id).bits := new_entry + partial_entries(partial_alloc_id).bits.opb.valid := false.B + partial_entries(partial_alloc_id).bits.opb.bits := DontCare + new_partial_allocs(partial_alloc_id) := true.B + } .elsewhen (!full_entries(full_alloc_id).valid) { + io.alloc.ready := true.B + full_entries(full_alloc_id).valid := true.B + full_entries(full_alloc_id).bits := new_entry + new_full_allocs(full_alloc_id) := true.B + } - last_allocated := new_entry_id + when (io.alloc.fire()) { + when (new_entry.is_config && new_entry.q === exq && !is_im2col) { + a_stride := new_entry.cmd.rs1(31, 16) // TODO magic numbers // TODO this needs to be kept in sync with ExecuteController.scala + }.elsewhen(new_entry.is_config && new_entry.q === ldq) { + val id = new_entry.cmd.rs1(4,3) // TODO magic numbers + val block_stride = new_entry.cmd.rs1(31, 16) // TODO magic numbers + ld_block_strides(id) := block_stride + }.elsewhen(new_entry.is_config && new_entry.q === stq) { + val pool_stride = new_entry.cmd.rs1(5, 4) // TODO magic numbers + pooling_is_enabled := pool_stride =/= 0.U + } + } } // Issue commands which are ready to be issued Seq((ldq, io.issue.ld), (stq, io.issue.st), (exq, io.issue.ex)).foreach { case (q, io) => - val issue_id = MuxCase((nEntries-1).U, entries.zipWithIndex.map { case (e, i) => - (e.valid && e.bits.ready() && !e.bits.issued && e.bits.q === q) -> i.U - }) + val issue_valids = entries.map(e => e.valid && e.bits.ready() && !e.bits.issued && e.bits.q === q) + val issue_sel = PriorityEncoderOH(issue_valids) + val issue_id = OHToUInt(issue_sel) + val issue_entry = Mux1H(issue_sel, entries) - io.valid := entries.map(e => e.valid && e.bits.ready() && !e.bits.issued && e.bits.q === q).reduce(_ || _) - io.cmd := entries(issue_id).bits.cmd - io.rob_id := issue_id + io.valid := issue_valids.reduce(_||_) + io.cmd := issue_entry.bits.cmd + io.rob_id := OHToUInt(issue_sel) when (io.fire()) { - entries(issue_id).bits.issued := true.B - // Clear out all the dependency bits for instructions which depend on the same queue entries.zipWithIndex.foreach { case (e, i) => - val is_same_q = Mux(alloc_fire && new_entry_id === i.U, - new_entry.q === entries(issue_id).bits.q, - e.bits.q === entries(issue_id).bits.q) + val is_same_q = Mux(alloc_fire && new_entry_oh(i), + new_entry.q === issue_entry.bits.q, + e.bits.q === issue_entry.bits.q) - when (is_same_q || entries(issue_id).bits.complete_on_issue) { + when (is_same_q || issue_entry.bits.complete_on_issue) { e.bits.deps(issue_id) := false.B } } - - entries(issue_id).valid := !entries(issue_id).bits.complete_on_issue + for ((e, i) <- entries.zipWithIndex) { + when (issue_sel(i)) { + e.bits.issued := true.B + e.valid := !e.bits.complete_on_issue + } + } } } @@ -211,8 +390,12 @@ class ROB(cmd_t: RoCCCommand, nEntries: Int, local_addr_t: LocalAddr, block_rows when (io.completed.fire()) { entries.foreach(_.bits.deps(io.completed.bits) := false.B) - entries(io.completed.bits).valid := false.B - assert(entries(io.completed.bits).valid) + for ((e, i) <- entries.zipWithIndex) { + when (i.U === io.completed.bits) { + e.valid := false.B + assert(e.valid) + } + } } // val utilization = PopCount(entries.map(e => e.valid)) @@ -227,7 +410,14 @@ class ROB(cmd_t: RoCCCommand, nEntries: Int, local_addr_t: LocalAddr, block_rows io.st_utilization := utilization_st_q io.ex_utilization := utilization_ex_q - val packed_deps = VecInit(entries.map(e => Cat(e.bits.deps))) + val valids = VecInit(entries.map(_.valid)) + val functs = VecInit(entries.map(_.bits.cmd.inst.funct)) + val issueds = VecInit(entries.map(_.bits.issued)) + val packed_deps = VecInit(entries.map(e => Cat(e.bits.deps.reverse))) + + dontTouch(valids) + dontTouch(functs) + dontTouch(issueds) dontTouch(packed_deps) val pop_count_packed_deps = VecInit(entries.map(e => Mux(e.valid, PopCount(e.bits.deps), 0.U))) @@ -236,15 +426,18 @@ class ROB(cmd_t: RoCCCommand, nEntries: Int, local_addr_t: LocalAddr, block_rows dontTouch(pop_count_packed_deps) dontTouch(min_pop_count) - val cycles_since_issue = RegInit(0.U(16.W)) + val cycles_since_issue = RegInit(0.U(20.W)) - when (io.issue.ld.fire() || io.issue.st.fire() || io.issue.ex.fire() || !io.busy) { + when (io.issue.ld.fire() || io.issue.st.fire() || io.issue.ex.fire() || !io.busy || io.completed.fire()) { cycles_since_issue := 0.U }.elsewhen(io.busy) { cycles_since_issue := cycles_since_issue + 1.U } - assert(cycles_since_issue < 10000.U, "pipeline stall") + assert(cycles_since_issue < 100000.U, "pipeline stall") + for (e <- entries) { + dontTouch(e.bits.allocated_at) + } val cntr = Counter(10000000) when (cntr.inc()) { @@ -256,7 +449,6 @@ class ROB(cmd_t: RoCCCommand, nEntries: Int, local_addr_t: LocalAddr, block_rows printf(p"Utilization st q: $utilization_st_q\n") printf(p"Utilization ex q: $utilization_ex_q\n") printf(p"Packed deps: $packed_deps\n") - printf(p"Last allocated: $last_allocated\n\n") } when (reset.asBool()) { diff --git a/src/main/scala/gemmini/Scratchpad.scala b/src/main/scala/gemmini/Scratchpad.scala index 39596cc3..7c2fc926 100644 --- a/src/main/scala/gemmini/Scratchpad.scala +++ b/src/main/scala/gemmini/Scratchpad.scala @@ -15,17 +15,24 @@ class ScratchpadMemReadRequest[U <: Data](local_addr_t: LocalAddr, scale_t_bits: val vaddr = UInt(coreMaxAddrBits.W) val laddr = local_addr_t.cloneType - val len = UInt(16.W) // TODO don't use a magic number for the width here + val cols = UInt(16.W) // TODO don't use a magic number for the width here val repeats = UInt(16.W) // TODO don't use a magic number for the width here - val scale = UInt(scale_t_bits.W) - val has_acc_bitwidth = Bool() - + val all_zeros = Bool() + val block_stride = UInt(16.W) // TODO magic numbers val cmd_id = UInt(8.W) // TODO don't use a magic number here - val status = new MStatus + //for bank conflict monitoring + val monitor_conflict = Bool() + val monitor_conflict_start = Bool() + val monitor_conflict_end = Bool() + + val profile_conflict = Bool() + val profile_conflict_start = Bool() + val profile_conflict_end = Bool() + override def cloneType: this.type = new ScratchpadMemReadRequest(local_addr_t, scale_t_bits).asInstanceOf[this.type] } @@ -35,9 +42,9 @@ class ScratchpadMemWriteRequest(local_addr_t: LocalAddr) val laddr = local_addr_t.cloneType val len = UInt(16.W) // TODO don't use a magic number for the width here + val block = UInt(8.W) // TODO don't use a magic number for the width here val cmd_id = UInt(8.W) // TODO don't use a magic number here - val status = new MStatus // Pooling variables @@ -95,7 +102,7 @@ class ScratchpadWriteIO(val n: Int, val w: Int, val mask_len: Int) extends Bundl val data = Output(UInt(w.W)) } -class ScratchpadBank(n: Int, w: Int, mem_pipeline: Int, aligned_to: Int) extends Module { +class ScratchpadBank(n: Int, w: Int, mem_pipeline: Int, aligned_to: Int, single_ported: Boolean) extends Module { // This is essentially a pipelined SRAM with the ability to stall pipeline stages require(w % aligned_to == 0 || w < aligned_to) @@ -107,9 +114,11 @@ class ScratchpadBank(n: Int, w: Int, mem_pipeline: Int, aligned_to: Int) extends val write = Flipped(new ScratchpadWriteIO(n, w, mask_len)) }) - // val mem = SyncReadMem(n, UInt(w.W)) val mem = SyncReadMem(n, Vec(mask_len, mask_elem)) + // When the scratchpad is single-ported, the writes take precedence + val singleport_busy_with_write = single_ported.B && io.write.en + when (io.write.en) { if (aligned_to >= w) mem.write(io.write.addr, io.write.data.asTypeOf(Vec(mask_len, mask_elem))) @@ -119,7 +128,13 @@ class ScratchpadBank(n: Int, w: Int, mem_pipeline: Int, aligned_to: Int) extends val raddr = io.read.req.bits.addr val ren = io.read.req.fire() - val rdata = mem.read(raddr, ren).asUInt() + val rdata = if (single_ported) { + assert(!(ren && io.write.en)) + mem.read(raddr, ren && !io.write.en).asUInt() + } else { + mem.read(raddr, ren).asUInt() + } + val fromDMA = io.read.req.bits.fromDMA // Make a queue which buffers the result of an SRAM read if it can't immediately be consumed @@ -129,7 +144,7 @@ class ScratchpadBank(n: Int, w: Int, mem_pipeline: Int, aligned_to: Int) extends q.io.enq.bits.fromDMA := RegNext(fromDMA) val q_will_be_empty = (q.io.count +& q.io.enq.fire()) - q.io.deq.fire() === 0.U - io.read.req.ready := q_will_be_empty + io.read.req.ready := q_will_be_empty && !singleport_busy_with_write // Build the rest of the resp pipeline val rdata_p = Pipeline(q.io.deq, mem_pipeline) @@ -183,9 +198,16 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, // Accumulator ports val acc = new Bundle { - val read = Flipped(Vec(acc_banks, new AccumulatorReadIO(acc_bank_entries, log2Up(accType.getWidth), Vec(meshColumns, Vec(tileColumns, inputType)), Vec(meshColumns, Vec(tileColumns, accType)), acc_scale_args.multiplicand_t))) - // val write = Flipped(Vec(acc_banks, new AccumulatorWriteReq(acc_bank_entries, Vec(meshColumns, Vec(tileColumns, accType))))) - val write = Flipped(Vec(acc_banks, Decoupled(new AccumulatorWriteReq(acc_bank_entries, Vec(meshColumns, Vec(tileColumns, accType)))))) + val read_req = Flipped(Vec(acc_banks, Decoupled(new AccumulatorReadReq( + acc_bank_entries, log2Up(accType.getWidth), acc_scale_args.multiplicand_t + )))) + val read_resp = Vec(acc_banks, Decoupled(new AccumulatorScaleResp( + Vec(meshColumns, Vec(tileColumns, inputType)), + Vec(meshColumns, Vec(tileColumns, accType)) + ))) + val write = Flipped(Vec(acc_banks, Decoupled(new AccumulatorWriteReq( + acc_bank_entries, Vec(meshColumns, Vec(tileColumns, accType)) + )))) } // TLB ports @@ -194,20 +216,48 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, // Misc. ports val busy = Output(Bool()) val flush = Input(Bool()) + + // for detecting conflicts + val latency_in = Input(UInt(16.W)) + val alert_cycles_in = Input(UInt(6.W)) + val latency_out = Output(UInt(16.W)) + val alert_cycles_out = Output(UInt(6.W)) + val pause_turn_in = Input(UInt(3.W)) + val pause_turn_out = Output(UInt(3.W)) + + //for pausing monitoring + val pause_out = Output(Bool()) }) val write_dispatch_q = Queue(io.dma.write.req) - write_dispatch_q.ready := false.B - + // Write scale queue is necessary to maintain in-order requests to accumulator scale unit + // Writes from main SPAD just flow directly between scale_q and issue_q, while writes + // From acc are ordered + val write_scale_q = Module(new Queue(new ScratchpadMemWriteRequest(local_addr_t), mem_pipeline)) val write_issue_q = Module(new Queue(new ScratchpadMemWriteRequest(local_addr_t), mem_pipeline+1, pipe=true)) val read_issue_q = Module(new Queue(new ScratchpadMemReadRequest(local_addr_t, mvin_scale_t_bits), mem_pipeline+1, pipe=true)) // TODO can't this just be a normal queue? + write_scale_q.io.enq.valid := false.B + write_scale_q.io.enq.bits := write_dispatch_q.bits + write_scale_q.io.deq.ready := false.B + write_issue_q.io.enq.valid := false.B - write_issue_q.io.enq.bits := write_dispatch_q.bits + write_issue_q.io.enq.bits := write_scale_q.io.deq.bits + + + // Garbage can immediately fire between dispatch_q and scale_q + when (write_dispatch_q.bits.laddr.is_garbage()) { + write_scale_q.io.enq <> write_dispatch_q + } + // Non-acc or garbage can immediately fire between scale_q and issue_q + when (write_scale_q.io.deq.bits.laddr.is_garbage() || !write_scale_q.io.deq.bits.laddr.is_acc_addr) { + write_issue_q.io.enq <> write_scale_q.io.deq + } + val writeData = Wire(Valid(UInt((spad_w max acc_w).W))) - writeData.valid := false.B + writeData.valid := write_issue_q.io.deq.bits.laddr.is_garbage() writeData.bits := DontCare val fullAccWriteData = Wire(UInt(acc_w.W)) fullAccWriteData := DontCare @@ -215,8 +265,8 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, write_issue_q.io.deq.bits.laddr.is_acc_addr && write_issue_q.io.deq.bits.laddr.read_full_acc_row val writeData_is_all_zeros = write_issue_q.io.deq.bits.laddr.is_garbage() - writer.module.io.req.valid := write_issue_q.io.deq.valid && (writeData.valid || writeData_is_all_zeros) - write_issue_q.io.deq.ready := writer.module.io.req.ready && (writeData.valid || writeData_is_all_zeros) + writer.module.io.req.valid := write_issue_q.io.deq.valid && writeData.valid + write_issue_q.io.deq.ready := writer.module.io.req.ready && writeData.valid writer.module.io.req.bits.vaddr := write_issue_q.io.deq.bits.vaddr writer.module.io.req.bits.len := Mux(writeData_is_full_width, write_issue_q.io.deq.bits.len * (accType.getWidth / 8).U, @@ -225,43 +275,83 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, writeData_is_all_zeros -> 0.U, writeData_is_full_width -> fullAccWriteData )) + writer.module.io.req.bits.block := write_issue_q.io.deq.bits.block writer.module.io.req.bits.status := write_issue_q.io.deq.bits.status writer.module.io.req.bits.pool_en := write_issue_q.io.deq.bits.pool_en writer.module.io.req.bits.store_en := write_issue_q.io.deq.bits.store_en - // FpgaDebug(write_issue_q.io.deq.bits.laddr.data) - // FpgaDebug(write_issue_q.io.deq.bits.laddr.accumulate) - // FpgaDebug(write_issue_q.io.deq.bits.laddr.is_acc_addr) - io.dma.write.resp.valid := false.B io.dma.write.resp.bits.cmd_id := write_dispatch_q.bits.cmd_id + when (write_dispatch_q.bits.laddr.is_garbage() && write_dispatch_q.fire()) { + io.dma.write.resp.valid := true.B + } read_issue_q.io.enq <> io.dma.read.req + val zero_writer = Module(new ZeroWriter(config, new ScratchpadMemReadRequest(local_addr_t, mvin_scale_t_bits))) + + when (io.dma.read.req.bits.all_zeros) { + read_issue_q.io.enq.valid := false.B + io.dma.read.req.ready := zero_writer.io.req.ready + } + + zero_writer.io.req.valid := io.dma.read.req.valid && io.dma.read.req.bits.all_zeros + zero_writer.io.req.bits.laddr := io.dma.read.req.bits.laddr + zero_writer.io.req.bits.cols := io.dma.read.req.bits.cols + zero_writer.io.req.bits.block_stride := io.dma.read.req.bits.block_stride + zero_writer.io.req.bits.tag := io.dma.read.req.bits + + zero_writer.io.resp.ready := false.B + reader.module.io.req.valid := read_issue_q.io.deq.valid read_issue_q.io.deq.ready := reader.module.io.req.ready reader.module.io.req.bits.vaddr := read_issue_q.io.deq.bits.vaddr reader.module.io.req.bits.spaddr := Mux(read_issue_q.io.deq.bits.laddr.is_acc_addr, read_issue_q.io.deq.bits.laddr.full_acc_addr(), read_issue_q.io.deq.bits.laddr.full_sp_addr()) - reader.module.io.req.bits.len := read_issue_q.io.deq.bits.len + reader.module.io.req.bits.len := read_issue_q.io.deq.bits.cols reader.module.io.req.bits.repeats := read_issue_q.io.deq.bits.repeats reader.module.io.req.bits.scale := read_issue_q.io.deq.bits.scale reader.module.io.req.bits.is_acc := read_issue_q.io.deq.bits.laddr.is_acc_addr reader.module.io.req.bits.accumulate := read_issue_q.io.deq.bits.laddr.accumulate reader.module.io.req.bits.has_acc_bitwidth := read_issue_q.io.deq.bits.has_acc_bitwidth + reader.module.io.req.bits.block_stride := read_issue_q.io.deq.bits.block_stride reader.module.io.req.bits.status := read_issue_q.io.deq.bits.status reader.module.io.req.bits.cmd_id := read_issue_q.io.deq.bits.cmd_id + //for bank conflict monitoring + reader.module.io.req.bits.monitor_conflict := read_issue_q.io.deq.bits.monitor_conflict + reader.module.io.req.bits.monitor_conflict_end := read_issue_q.io.deq.bits.monitor_conflict_end + reader.module.io.req.bits.monitor_conflict_start := read_issue_q.io.deq.bits.monitor_conflict_start + reader.module.io.req.bits.profile_conflict_end := read_issue_q.io.deq.bits.profile_conflict_end + reader.module.io.req.bits.profile_conflict_start := read_issue_q.io.deq.bits.profile_conflict_start + reader.module.io.req.bits.profile_conflict := read_issue_q.io.deq.bits.profile_conflict + when(reset.toBool()){ + reader.module.io.req.bits.profile_conflict := false.B + reader.module.io.req.bits.profile_conflict_start := false.B + reader.module.io.req.bits.profile_conflict_end := false.B + reader.module.io.req.bits.monitor_conflict := false.B + reader.module.io.req.bits.monitor_conflict_end := false.B + reader.module.io.req.bits.monitor_conflict_start := false.B + } - val (mvin_scale_in, mvin_scale_out) = VectorScalarMultiplier(config.mvin_scale_args, config.inputType, config.meshColumns * config.tileColumns, chiselTypeOf(reader.module.io.resp.bits), is_acc = false) - val (mvin_scale_acc_in, mvin_scale_acc_out) = if (mvin_scale_shared) (mvin_scale_in, mvin_scale_out) else - VectorScalarMultiplier(config.mvin_scale_acc_args, config.accType, config.meshColumns * config.tileColumns, chiselTypeOf(reader.module.io.resp.bits), is_acc = true) + val (mvin_scale_in, mvin_scale_out) = VectorScalarMultiplier( + config.mvin_scale_args, + config.inputType, config.meshColumns * config.tileColumns, chiselTypeOf(reader.module.io.resp.bits), + is_acc = false + ) + val (mvin_scale_acc_in, mvin_scale_acc_out) = if (mvin_scale_shared) (mvin_scale_in, mvin_scale_out) else ( + VectorScalarMultiplier( + config.mvin_scale_acc_args, + config.accType, config.meshColumns * config.tileColumns, chiselTypeOf(reader.module.io.resp.bits), + is_acc = true + ) + ) mvin_scale_in.valid := reader.module.io.resp.valid && (mvin_scale_shared.B || !reader.module.io.resp.bits.is_acc || (reader.module.io.resp.bits.is_acc && !reader.module.io.resp.bits.has_acc_bitwidth)) mvin_scale_in.bits.in := reader.module.io.resp.bits.data.asTypeOf(chiselTypeOf(mvin_scale_in.bits.in)) mvin_scale_in.bits.scale := reader.module.io.resp.bits.scale.asTypeOf(mvin_scale_t) - mvin_scale_in.bits.repeats := reader.module.io.resp.bits.rows + mvin_scale_in.bits.repeats := reader.module.io.resp.bits.repeats mvin_scale_in.bits.last := reader.module.io.resp.bits.last mvin_scale_in.bits.tag := reader.module.io.resp.bits @@ -272,7 +362,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, (reader.module.io.resp.bits.is_acc && reader.module.io.resp.bits.has_acc_bitwidth) mvin_scale_acc_in.bits.in := reader.module.io.resp.bits.data.asTypeOf(chiselTypeOf(mvin_scale_acc_in.bits.in)) mvin_scale_acc_in.bits.scale := reader.module.io.resp.bits.scale.asTypeOf(mvin_scale_acc_t) - mvin_scale_acc_in.bits.repeats := reader.module.io.resp.bits.rows + mvin_scale_acc_in.bits.repeats := reader.module.io.resp.bits.repeats mvin_scale_acc_in.bits.last := reader.module.io.resp.bits.last mvin_scale_acc_in.bits.tag := reader.module.io.resp.bits @@ -284,9 +374,22 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, val mvin_scale_finished = mvin_scale_out.fire() && mvin_scale_out.bits.last val mvin_scale_acc_finished = mvin_scale_acc_out.fire() && mvin_scale_acc_out.bits.last - io.dma.read.resp.valid := mvin_scale_finished || mvin_scale_acc_finished - io.dma.read.resp.bits.cmd_id := Mux(mvin_scale_finished, mvin_scale_out.bits.tag.cmd_id, mvin_scale_acc_out.bits.tag.cmd_id) - io.dma.read.resp.bits.bytesRead := Mux(mvin_scale_finished, mvin_scale_out.bits.tag.bytes_read, mvin_scale_acc_out.bits.tag.bytes_read) + val zero_writer_finished = zero_writer.io.resp.fire() && zero_writer.io.resp.bits.last + + val zero_writer_bytes_read = Mux(zero_writer.io.resp.bits.laddr.is_acc_addr, + zero_writer.io.resp.bits.tag.cols * (accType.getWidth / 8).U, + zero_writer.io.resp.bits.tag.cols * (inputType.getWidth / 8).U) + + // For DMA read responses, mvin_scale gets first priority, then mvin_scale_acc, and then zero_writer + io.dma.read.resp.valid := mvin_scale_finished || mvin_scale_acc_finished || zero_writer_finished + + io.dma.read.resp.bits.cmd_id := MuxCase(zero_writer.io.resp.bits.tag.cmd_id, Seq( + mvin_scale_finished -> mvin_scale_out.bits.tag.cmd_id, + mvin_scale_acc_finished -> mvin_scale_acc_out.bits.tag.cmd_id)) + + io.dma.read.resp.bits.bytesRead := MuxCase(zero_writer_bytes_read, Seq( + mvin_scale_finished -> mvin_scale_out.bits.tag.bytes_read, + mvin_scale_acc_finished -> mvin_scale_acc_out.bits.tag.bytes_read)) io.tlb(0) <> writer.module.io.tlb io.tlb(1) <> reader.module.io.tlb @@ -294,10 +397,19 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, writer.module.io.flush := io.flush reader.module.io.flush := io.flush - io.busy := writer.module.io.busy || reader.module.io.busy || write_issue_q.io.deq.valid + //for monitoring conflicts + io.latency_out := io.latency_in + io.alert_cycles_out := io.alert_cycles_in + io.pause_turn_out := io.pause_turn_in + io.pause_out := reader.module.io.pause_out + reader.module.io.latency_in := io.latency_out + reader.module.io.alert_cycles_in := io.alert_cycles_out + reader.module.io.pause_turn_in := io.pause_turn_out - { - val banks = Seq.fill(sp_banks) { Module(new ScratchpadBank(sp_bank_entries, spad_w, mem_pipeline, aligned_to)) } + io.busy := writer.module.io.busy || reader.module.io.busy || write_issue_q.io.deq.valid || write_scale_q.io.deq.valid || write_dispatch_q.valid + + { + val banks = Seq.fill(sp_banks) { Module(new ScratchpadBank(sp_bank_entries, spad_w, mem_pipeline, aligned_to, config.sp_singleported)) } val bank_ios = VecInit(banks.map(_.io)) // Getting the output of the bank that's about to be issued to the writer @@ -314,10 +426,12 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, val exread = ex_read_req.valid // TODO we tie the write dispatch queue's, and write issue queue's, ready and valid signals together here - val dmawrite = write_dispatch_q.valid && write_issue_q.io.enq.ready && + val dmawrite = write_dispatch_q.valid && write_scale_q.io.enq.ready && + !write_dispatch_q.bits.laddr.is_garbage() && + !(bio.write.en && config.sp_singleported.B) && !write_dispatch_q.bits.laddr.is_acc_addr && write_dispatch_q.bits.laddr.sp_bank() === i.U - bio.read.req.valid := exread || (dmawrite && !write_dispatch_q.bits.laddr.is_garbage()) + bio.read.req.valid := exread || dmawrite ex_read_req.ready := bio.read.req.ready // The ExecuteController gets priority when reading from SRAMs @@ -328,9 +442,9 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, bio.read.req.bits.addr := write_dispatch_q.bits.laddr.sp_row() bio.read.req.bits.fromDMA := true.B - when (bio.read.req.fire() || write_dispatch_q.bits.laddr.is_garbage()) { + when (bio.read.req.fire()) { write_dispatch_q.ready := true.B - write_issue_q.io.enq.valid := true.B + write_scale_q.io.enq.valid := true.B io.dma.write.resp.valid := true.B } @@ -357,7 +471,13 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, val dmaread = mvin_scale_out.valid && !mvin_scale_out.bits.tag.is_acc && laddr.sp_bank() === i.U - bio.write.en := exwrite || dmaread + // We need to make sure that we don't try to return a dma read resp from both zero_writer and either mvin_scale + // or mvin_acc_scale at the same time. The scalers always get priority in those cases + val zerowrite = zero_writer.io.resp.valid && !zero_writer.io.resp.bits.laddr.is_acc_addr && + zero_writer.io.resp.bits.laddr.sp_bank() === i.U && + !((mvin_scale_out.valid && mvin_scale_out.bits.last) || (mvin_scale_acc_out.valid && mvin_scale_acc_out.bits.last)) + + bio.write.en := exwrite || dmaread || zerowrite when (exwrite) { bio.write.addr := io.srams.write(i).addr @@ -369,6 +489,17 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, bio.write.mask := mvin_scale_out.bits.tag.mask take ((spad_w / (aligned_to * 8)) max 1) mvin_scale_out.ready := true.B // TODO we combinationally couple valid and ready signals + }.elsewhen (zerowrite) { + bio.write.addr := zero_writer.io.resp.bits.laddr.sp_row() + bio.write.data := 0.U + bio.write.mask := { + val n = inputType.getWidth / 8 + val mask = zero_writer.io.resp.bits.mask + val expanded = VecInit(mask.flatMap(e => Seq.fill(n)(e))) + expanded + } + + zero_writer.io.resp.ready := true.B // TODO we combinationally couple valid and ready signals }.otherwise { bio.write.addr := DontCare bio.write.data := DontCare @@ -377,32 +508,64 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, } } + val acc_row_t = Vec(meshColumns, Vec(tileColumns, accType)) + val spad_row_t = Vec(meshColumns, Vec(tileColumns, inputType)) + + val acc_scale_unit = Module(new AccumulatorScale( + acc_row_t, + spad_row_t, + acc_scale_args.multiplicand_t, + log2Up(accType.getWidth), + acc_read_small_width, + acc_read_full_width, + acc_scale_args + )) + + acc_scale_unit.io.in.valid := false.B + acc_scale_unit.io.in.bits := DontCare + val dma_resp_ready = ( + writer.module.io.req.ready && + write_issue_q.io.deq.bits.laddr.is_acc_addr && + !write_issue_q.io.deq.bits.laddr.is_garbage() + ) + acc_scale_unit.io.out.ready := false.B + when (acc_scale_unit.io.out.bits.fromDMA && dma_resp_ready) { + acc_scale_unit.io.out.ready := true.B + writeData.valid := acc_scale_unit.io.out.valid + writeData.bits := acc_scale_unit.io.out.bits.data.asUInt + fullAccWriteData := acc_scale_unit.io.out.bits.full_data.asUInt + } + for (i <- 0 until acc_banks) { + io.acc.read_resp(i).valid := false.B + io.acc.read_resp(i).bits := acc_scale_unit.io.out.bits + when (!acc_scale_unit.io.out.bits.fromDMA && acc_scale_unit.io.out.bits.acc_bank_id === i.U) { + acc_scale_unit.io.out.ready := io.acc.read_resp(i).ready + io.acc.read_resp(i).valid := acc_scale_unit.io.out.valid + } + } + { - val acc_row_t = Vec(meshColumns, Vec(tileColumns, accType)) - val spad_row_t = Vec(meshColumns, Vec(tileColumns, inputType)) - val banks = Seq.fill(acc_banks) { Module(new AccumulatorMem(acc_bank_entries, acc_row_t, spad_row_t, mem_pipeline, acc_scale_args, acc_read_small_width, acc_read_full_width)) } + val banks = Seq.fill(acc_banks) { Module(new AccumulatorMem( + acc_bank_entries, acc_row_t, acc_scale_args, + acc_singleported, num_acc_sub_banks + )) } val bank_ios = VecInit(banks.map(_.io)) // Getting the output of the bank that's about to be issued to the writer val bank_issued_io = bank_ios(write_issue_q.io.deq.bits.laddr.acc_bank()) - when (write_issue_q.io.deq.bits.laddr.is_acc_addr) { - writeData.valid := bank_issued_io.read.resp.valid && bank_issued_io.read.resp.bits.fromDMA - writeData.bits := bank_issued_io.read.resp.bits.data.asUInt() - fullAccWriteData := bank_issued_io.read.resp.bits.full_data.asUInt() - } - // Reading from the Accumulator banks bank_ios.zipWithIndex.foreach { case (bio, i) => - val ex_read_req = io.acc.read(i).req + val ex_read_req = io.acc.read_req(i) val exread = ex_read_req.valid // TODO we tie the write dispatch queue's, and write issue queue's, ready and valid signals together here - val dmawrite = write_dispatch_q.valid && write_issue_q.io.enq.ready && + val dmawrite = write_dispatch_q.valid && write_scale_q.io.enq.ready && + !write_dispatch_q.bits.laddr.is_garbage() && write_dispatch_q.bits.laddr.is_acc_addr && write_dispatch_q.bits.laddr.acc_bank() === i.U - bio.read.req.valid := exread || (dmawrite && !write_dispatch_q.bits.laddr.is_garbage()) + bio.read.req.valid := exread || dmawrite bio.read.req.bits.scale := ex_read_req.bits.scale bio.read.req.bits.relu6_shift := ex_read_req.bits.relu6_shift bio.read.req.bits.act := ex_read_req.bits.act @@ -418,28 +581,41 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, bio.read.req.bits.full := write_dispatch_q.bits.laddr.read_full_acc_row bio.read.req.bits.fromDMA := true.B - when (bio.read.req.fire() || write_dispatch_q.bits.laddr.is_garbage()) { + when (bio.read.req.fire()) { write_dispatch_q.ready := true.B - write_issue_q.io.enq.valid := true.B + write_scale_q.io.enq.valid := true.B io.dma.write.resp.valid := true.B } }.otherwise { bio.read.req.bits := DontCare } + bio.read.resp.ready := false.B + + + when (write_scale_q.io.deq.valid && + acc_scale_unit.io.in.ready && + bio.read.resp.valid && + write_issue_q.io.enq.ready && + write_scale_q.io.deq.bits.laddr.is_acc_addr && + !write_scale_q.io.deq.bits.laddr.is_garbage() && + write_scale_q.io.deq.bits.laddr.acc_bank() === i.U) { + write_scale_q.io.deq.ready := true.B + acc_scale_unit.io.in.valid := true.B + bio.read.resp.ready := true.B + write_issue_q.io.enq.valid := true.B + + acc_scale_unit.io.in.bits := bio.read.resp.bits + acc_scale_unit.io.in.bits.acc_bank_id := i.U + } - val ex_read_resp = io.acc.read(i).resp - val dma_resp_ready = writer.module.io.req.ready && - write_issue_q.io.deq.bits.laddr.is_acc_addr && write_issue_q.io.deq.bits.laddr.acc_bank() === i.U && // I believe we don't need to check that write_issue_q is valid here, because if the accumulator bank's resp is valid, then that means that the write_issue_q's deq should also be valid - !write_issue_q.io.deq.bits.laddr.is_garbage() - - bio.read.resp.ready := Mux(bio.read.resp.bits.fromDMA, dma_resp_ready, ex_read_resp.ready) - ex_read_resp.valid := bio.read.resp.valid // TODO should we AND this with fromDMA? - ex_read_resp.bits := bio.read.resp.bits } // Writing to the accumulator banks bank_ios.zipWithIndex.foreach { case (bio, i) => + // Order of precedence during writes is ExecuteController, and then mvin_scale, and then mvin_scale_acc, and + // then zero_writer + val exwrite = io.acc.write(i).valid io.acc.write(i).ready := true.B assert(!(exwrite && !bio.write.ready), "Execute controller write to AccumulatorMem was skipped") @@ -447,48 +623,85 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, val from_mvin_scale = mvin_scale_out.valid && mvin_scale_out.bits.tag.is_acc val from_mvin_scale_acc = mvin_scale_acc_out.valid && mvin_scale_acc_out.bits.tag.is_acc - val mvin_scale_acc_laddr = mvin_scale_acc_out.bits.tag.addr.asTypeOf(local_addr_t) + mvin_scale_acc_out.bits.row val mvin_scale_laddr = mvin_scale_out.bits.tag.addr.asTypeOf(local_addr_t) + mvin_scale_out.bits.row + val mvin_scale_acc_laddr = mvin_scale_acc_out.bits.tag.addr.asTypeOf(local_addr_t) + mvin_scale_acc_out.bits.row - val dmaread_bank = Mux(from_mvin_scale_acc, mvin_scale_acc_laddr.acc_bank(), - mvin_scale_laddr.acc_bank()) - val dmaread_row = Mux(from_mvin_scale_acc, mvin_scale_acc_laddr.acc_row(), mvin_scale_laddr.acc_row()) + val dmaread_bank = Mux(from_mvin_scale, mvin_scale_laddr.acc_bank(), + mvin_scale_acc_laddr.acc_bank()) + val dmaread_row = Mux(from_mvin_scale, mvin_scale_laddr.acc_row(), mvin_scale_acc_laddr.acc_row()) // We need to make sure that we don't try to return a dma read resp from both mvin_scale and mvin_scale_acc // at the same time. mvin_scale always gets priority in this cases - val mvin_scale_out_last = mvin_scale_out.valid && mvin_scale_out.bits.last + val spad_last = mvin_scale_out.valid && mvin_scale_out.bits.last && !mvin_scale_out.bits.tag.is_acc val dmaread = (from_mvin_scale || from_mvin_scale_acc) && - dmaread_bank === i.U && - (mvin_scale_same.B || from_mvin_scale || !mvin_scale_out_last) + dmaread_bank === i.U /* && + (mvin_scale_same.B || from_mvin_scale || !spad_dmaread_last) */ + + // We need to make sure that we don't try to return a dma read resp from both zero_writer and either mvin_scale + // or mvin_acc_scale at the same time. The scalers always get priority in those cases + val zerowrite = zero_writer.io.resp.valid && zero_writer.io.resp.bits.laddr.is_acc_addr && + zero_writer.io.resp.bits.laddr.acc_bank() === i.U && + !((mvin_scale_out.valid && mvin_scale_out.bits.last) || (mvin_scale_acc_out.valid && mvin_scale_acc_out.bits.last)) + val consecutive_write_block = RegInit(false.B) + if (acc_singleported) { + val consecutive_write_sub_bank = RegInit(0.U((1 max log2Ceil(num_acc_sub_banks)).W)) + when (bio.write.fire() && bio.write.bits.acc && + (bio.write.bits.addr(log2Ceil(num_acc_sub_banks)-1,0) === consecutive_write_sub_bank)) { + consecutive_write_block := true.B + } .elsewhen (bio.write.fire() && bio.write.bits.acc) { + consecutive_write_block := false.B + consecutive_write_sub_bank := bio.write.bits.addr(log2Ceil(num_acc_sub_banks)-1,0) + } .otherwise { + consecutive_write_block := false.B + } + } + bio.write.valid := false.B - bio.write.valid := exwrite || dmaread - bio.write.bits.acc := Mux(exwrite, io.acc.write(i).bits.acc, - Mux(from_mvin_scale_acc, mvin_scale_acc_out.bits.tag.accumulate, mvin_scale_out.bits.tag.accumulate)) - bio.write.bits.addr := Mux(exwrite, io.acc.write(i).bits.addr, dmaread_row) + bio.write.bits.acc := MuxCase(zero_writer.io.resp.bits.laddr.accumulate, + Seq(exwrite -> io.acc.write(i).bits.acc, + from_mvin_scale -> mvin_scale_out.bits.tag.accumulate, + from_mvin_scale_acc -> mvin_scale_acc_out.bits.tag.accumulate)) + + bio.write.bits.addr := MuxCase(zero_writer.io.resp.bits.laddr.acc_row(), + Seq(exwrite -> io.acc.write(i).bits.addr, + (from_mvin_scale || from_mvin_scale_acc) -> dmaread_row)) when (exwrite) { + bio.write.valid := true.B bio.write.bits.data := io.acc.write(i).bits.data bio.write.bits.mask := io.acc.write(i).bits.mask - }.elsewhen (dmaread && bio.write.fire()) { - bio.write.bits.data := Mux(from_mvin_scale_acc, - mvin_scale_acc_out.bits.out.asTypeOf(acc_row_t), - VecInit(mvin_scale_out.bits.out.map(e => e.withWidthOf(accType))).asTypeOf(acc_row_t)) + }.elsewhen (dmaread && !spad_last && !consecutive_write_block) { + bio.write.valid := true.B + bio.write.bits.data := Mux(from_mvin_scale, + VecInit(mvin_scale_out.bits.out.map(e => e.withWidthOf(accType))).asTypeOf(acc_row_t), + mvin_scale_acc_out.bits.out.asTypeOf(acc_row_t)) bio.write.bits.mask := - Mux(from_mvin_scale_acc, - mvin_scale_acc_out.bits.tag.mask, + Mux(from_mvin_scale, { val n = accType.getWidth / inputType.getWidth val mask = mvin_scale_out.bits.tag.mask take ((spad_w / (aligned_to * 8)) max 1) val expanded = VecInit(mask.flatMap(e => Seq.fill(n)(e))) expanded - }) + }, + mvin_scale_acc_out.bits.tag.mask) - when (from_mvin_scale_acc) { - mvin_scale_acc_out.ready := true.B + when(from_mvin_scale) { + mvin_scale_out.ready := bio.write.ready }.otherwise { - mvin_scale_out.ready := true.B + mvin_scale_acc_out.ready := bio.write.ready + } + }.elsewhen (zerowrite && !spad_last && !consecutive_write_block) { + bio.write.valid := true.B + bio.write.bits.data := 0.U.asTypeOf(acc_row_t) + bio.write.bits.mask := { + val n = accType.getWidth / 8 + val mask = zero_writer.io.resp.bits.mask + val expanded = VecInit(mask.flatMap(e => Seq.fill(n)(e))) + expanded } + + zero_writer.io.resp.ready := bio.write.ready }.otherwise { bio.write.bits.data := DontCare bio.write.bits.mask := DontCare diff --git a/src/main/scala/gemmini/StoreController.scala b/src/main/scala/gemmini/StoreController.scala index 07399a16..98584bca 100644 --- a/src/main/scala/gemmini/StoreController.scala +++ b/src/main/scala/gemmini/StoreController.scala @@ -35,8 +35,13 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm val stride = Reg(UInt(coreMaxAddrBits.W)) val block_rows = meshRows * tileRows + val block_stride = block_rows.U + val block_cols = meshColumns * tileColumns + val max_blocks = (dma_maxbytes / (block_cols * inputType.getWidth / 8)) max 1 + //val row_counter = RegInit(0.U(log2Ceil(block_rows).W)) - val row_counter = RegInit(0.U(12.W)) + val row_counter = RegInit(0.U(12.W)) // TODO magic number + val block_counter = RegInit(0.U(8.W)) // TODO magic number // Pooling variables val pool_stride = Reg(UInt(2.W)) // When this is 0, pooling is disabled // TODO magic number @@ -69,8 +74,9 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm val cmd = Queue(io.cmd, st_queue_length) val vaddr = cmd.bits.cmd.rs1 val localaddr = cmd.bits.cmd.rs2.asTypeOf(local_addr_t) - val cols = cmd.bits.cmd.rs2(32 + mvout_len_bits - 1, 32) // TODO magic numbers + val cols = cmd.bits.cmd.rs2(32 + mvout_cols_bits - 1, 32) // TODO magic numbers val rows = cmd.bits.cmd.rs2(48 + mvout_rows_bits - 1, 48) // TODO magic numbers + val blocks = (cols / block_cols.U) + (cols % block_cols.U =/= 0.U) val config_stride = cmd.bits.cmd.rs2 val config_pool_stride = cmd.bits.cmd.rs1(5, 4) // TODO magic numbers val config_pool_size = cmd.bits.cmd.rs1(7, 6) // TODO magic numbers @@ -84,7 +90,8 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm val mstatus = cmd.bits.cmd.status - val localaddr_plus_row_counter = localaddr + row_counter + val current_vaddr = vaddr + row_counter * stride + val current_localaddr = localaddr + (block_counter * block_stride + row_counter) val pool_row_addr = localaddr + (orow * pool_ocols +& ocol) when (orow_is_negative || ocol_is_negative || orow >= pool_orows || ocol >= pool_ocols) { @@ -106,31 +113,32 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm val rob_id = UInt(log2Up(rob_entries).W) } - val cmd_tracker_max_rows = (block_rows max + val cmd_tracker_max_rows = ((block_rows * max_blocks) max (((1 << pool_orows.getWidth)-1) * ((1 << pool_ocols.getWidth)-1) + 2*((1 << pool_lpad.getWidth)-1) + 2*((1 << pool_upad.getWidth)-1))) min ((config.sp_banks * config.sp_bank_entries) max (config.acc_banks * config.acc_bank_entries)) - val cmd_tracker = Module(new DMAReadCommandTracker(nCmds, cmd_tracker_max_rows, deps_t)) + val cmd_tracker = Module(new DMACommandTracker(nCmds, cmd_tracker_max_rows, deps_t)) // DMA IO wiring io.dma.req.valid := (control_state === waiting_for_command && cmd.valid && DoStore && cmd_tracker.io.alloc.ready) || control_state === waiting_for_dma_req_ready || - (control_state === sending_rows && row_counter =/= 0.U) || // TODO Do we really have to check whether the counters should be 0 here? + (control_state === sending_rows && (block_counter =/= 0.U || row_counter =/= 0.U)) || (control_state === pooling && (wcol_counter =/= 0.U || wrow_counter =/= 0.U || pocol_counter =/= 0.U || porow_counter =/= 0.U)) - io.dma.req.bits.vaddr := Mux(pooling_is_enabled || mvout_1d_enabled, pool_vaddr, vaddr + row_counter * stride) - io.dma.req.bits.laddr := Mux(pooling_is_enabled, pool_row_addr, localaddr_plus_row_counter) //Todo: laddr for 1D? + io.dma.req.bits.vaddr := Mux(pooling_is_enabled || mvout_1d_enabled, pool_vaddr, current_vaddr) + io.dma.req.bits.laddr := Mux(pooling_is_enabled, pool_row_addr, current_localaddr) //Todo: laddr for 1D? - io.dma.req.bits.len := cols + io.dma.req.bits.len := Mux(block_counter === blocks - 1.U, ((cols - 1.U) % block_cols.U) + 1.U, block_cols.U) + io.dma.req.bits.block := block_counter io.dma.req.bits.status := mstatus io.dma.req.bits.pool_en := pooling_is_enabled && (wrow_counter =/= 0.U || wcol_counter =/= 0.U) - io.dma.req.bits.store_en := !pooling_is_enabled || - (wrow_counter === pool_size - 1.U && wcol_counter === pool_size - 1.U) + io.dma.req.bits.store_en := Mux(pooling_is_enabled, wrow_counter === pool_size - 1.U && wcol_counter === pool_size - 1.U, + block_counter === blocks - 1.U) // Command tracker IO cmd_tracker.io.alloc.valid := control_state === waiting_for_command && cmd.valid && DoStore - cmd_tracker.io.alloc.bits.bytes_to_read := Mux(!pooling_is_enabled, Mux(mvout_1d_enabled, mvout_1d_rows, rows), pool_total_rows) // TODO do we have to add upad and lpad to this? + cmd_tracker.io.alloc.bits.bytes_to_read := Mux(!pooling_is_enabled, Mux(mvout_1d_enabled, mvout_1d_rows, rows*blocks), pool_total_rows) // TODO do we have to add upad and lpad to this? cmd_tracker.io.alloc.bits.tag.rob_id := cmd.bits.rob_id.bits cmd_tracker.io.request_returned.valid := io.dma.resp.fire() // TODO use a bundle connect @@ -155,13 +163,18 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm pocol_counter := wrappingAdd(pocol_counter, 1.U, pool_ocols) porow_counter := wrappingAdd(porow_counter, 1.U, pool_orows, pocol_counter === pool_ocols - 1.U) } - row_counter := Mux(mvout_1d_enabled, wrappingAdd(row_counter, 1.U, mvout_1d_rows), wrappingAdd(row_counter, 1.U, rows)) + + block_counter := wrappingAdd(block_counter, 1.U, blocks) + row_counter := Mux(mvout_1d_enabled, wrappingAdd(row_counter, 1.U, mvout_1d_rows), wrappingAdd(row_counter, 1.U, rows, block_counter === blocks - 1.U)) }.otherwise { wcol_counter := wrappingAdd(wcol_counter, 1.U, pool_size) wrow_counter := wrappingAdd(wrow_counter, 1.U, pool_size, wcol_counter === pool_size - 1.U) pocol_counter := wrappingAdd(pocol_counter, 1.U, pool_pocols, wrow_counter === pool_size - 1.U && wcol_counter === pool_size - 1.U) porow_counter := wrappingAdd(porow_counter, 1.U, pool_porows, pocol_counter === pool_pocols - 1.U && wrow_counter === pool_size - 1.U && wcol_counter === pool_size - 1.U) } + + assert(!(io.dma.req.bits.laddr.read_full_acc_row && blocks > 1.U), "Block-mvouts are not permitted when moving out full accumulator data") + assert(!((pooling_is_enabled || mvout_1d_enabled) && blocks > 1.U), "Block-mvouts are not permitted when pooling") } // Control logic @@ -201,11 +214,13 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm } is (sending_rows) { - // TODO Is it really possible for row_counter to be 0 here? - val last_row = row_counter === 0.U || (Mux(mvout_1d_enabled, row_counter === mvout_1d_rows - 1.U, row_counter === rows - 1.U) && io.dma.req.fire()) + val last_block = block_counter === blocks - 1.U && io.dma.req.fire() + val last_row = Mux(mvout_1d_enabled, row_counter === mvout_1d_rows - 1.U, row_counter === rows - 1.U) && io.dma.req.fire() //normal mvout: row, 1D mvout: orows*ocols - when (last_row) { + val only_one_dma_req = block_counter === 0.U && row_counter === 0.U // This is a special case when only one DMA request is made + + when ((last_block && last_row) || only_one_dma_req) { control_state := waiting_for_command cmd.ready := true.B } diff --git a/src/main/scala/gemmini/TagQueue.scala b/src/main/scala/gemmini/TagQueue.scala index e7460b2e..3c516ff0 100644 --- a/src/main/scala/gemmini/TagQueue.scala +++ b/src/main/scala/gemmini/TagQueue.scala @@ -8,45 +8,45 @@ trait TagQueueTag { def make_this_garbage(dummy: Int = 0): Unit } -class TagQueue[T <: TagQueueTag with Data](entries: Int, t: T) extends Module { +class TagQueue[T <: Data with TagQueueTag](t: T, entries: Int) extends Module { val io = IO(new Bundle { - val in = new Bundle { - val valid = Input(Bool()) - val bits = Input(t) - } - - val out = new Bundle { - val next = Input(Bool()) - val bits = Output(Vec(2, t)) - val all = Output(Vec(entries, t)) - } - - // This should really be a constructor parameter, but Chisel errors out when it is - // val garbage = Input(t) + val enq = Flipped(Decoupled(t.cloneType)) + val deq = Decoupled(t.cloneType) + val all = Output(Vec(entries, t.cloneType)) }) - // val regs = RegInit(VecInit(Seq.fill(entries)(io.garbage))) val regs = Reg(Vec(entries, t.cloneType)) - val raddr = RegInit(0.U((log2Ceil(entries) max 1).W)) - val waddr = RegInit(3.U((log2Ceil(entries) max 1).W)) + val raddr = RegInit(0.U(log2Up(entries).W)) + val waddr = RegInit(0.U(log2Up(entries).W)) + val len = RegInit(0.U(log2Up(entries+1).W)) - val raddr_inc = wrappingAdd(raddr, 1.U, entries) - val raddr_inc2 = wrappingAdd(raddr, 2.U, entries) + val empty = len === 0.U + val full = len === entries.U - io.out.bits(0) := Mux(io.out.next, regs(raddr_inc), regs(raddr)) - io.out.bits(1) := Mux(io.out.next, regs(raddr_inc2), regs(raddr_inc)) - io.out.all := regs + io.enq.ready := !full + io.deq.valid := !empty + io.deq.bits := regs(raddr) + io.all := regs - when (io.in.valid) { + when (io.enq.fire()) { + regs(waddr) := io.enq.bits waddr := wrappingAdd(waddr, 1.U, entries) - regs(waddr) := io.in.bits } - when (io.out.next) { - raddr := raddr_inc + when (io.deq.fire()) { + regs(raddr).make_this_garbage() + raddr := wrappingAdd(raddr, 1.U, entries) + } + + when (io.enq.fire() && !io.deq.fire()) { + len := len + 1.U + }.elsewhen(!io.enq.fire() && io.deq.fire()) { + len := len - 1.U } when (reset.toBool()) { regs.foreach(_.make_this_garbage()) } + + assert(len <= entries.U) } diff --git a/src/main/scala/gemmini/Tile.scala b/src/main/scala/gemmini/Tile.scala index 69b606b8..59807893 100644 --- a/src/main/scala/gemmini/Tile.scala +++ b/src/main/scala/gemmini/Tile.scala @@ -12,22 +12,31 @@ import chisel3.util._ * @param rows Number of PEs on each row * @param columns Number of PEs on each column */ -class Tile[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, df: Dataflow.Value, pe_latency: Int, val rows: Int, val columns: Int) extends Module { +class Tile[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, df: Dataflow.Value, pe_latency: Int, max_simultaneous_matmuls: Int, val rows: Int, val columns: Int) extends Module { val io = IO(new Bundle { val in_a = Input(Vec(rows, inputType)) val in_b = Input(Vec(columns, outputType)) // This is the output of the tile next to it val in_d = Input(Vec(columns, outputType)) + val in_control = Input(Vec(columns, new PEControl(accType))) + val in_id = Input(Vec(columns, UInt(log2Up(max_simultaneous_matmuls).W))) + val in_last = Input(Vec(columns, Bool())) + val out_a = Output(Vec(rows, inputType)) val out_c = Output(Vec(columns, outputType)) val out_b = Output(Vec(columns, outputType)) + val out_control = Output(Vec(columns, new PEControl(accType))) + val out_id = Output(Vec(columns, UInt(log2Up(max_simultaneous_matmuls).W))) + val out_last = Output(Vec(columns, Bool())) val in_valid = Input(Vec(columns, Bool())) val out_valid = Output(Vec(columns, Bool())) + + val bad_dataflow = Output(Bool()) }) - val tile = Seq.fill(rows, columns)(Module(new PE(inputType, outputType, accType, df, pe_latency))) + val tile = Seq.fill(rows, columns)(Module(new PE(inputType, outputType, accType, df, pe_latency, max_simultaneous_matmuls))) val tileT = tile.transpose // TODO: abstract hori/vert broadcast, all these connections look the same @@ -76,13 +85,34 @@ class Tile[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, df: } } + // Broadcast 'id' vertically across the Tile + for (c <- 0 until columns) { + tileT(c).foldLeft(io.in_id(c)) { + case (id, pe) => + pe.io.in_id := id + pe.io.out_id + } + } + + // Broadcast 'last' vertically across the Tile + for (c <- 0 until columns) { + tileT(c).foldLeft(io.in_last(c)) { + case (last, pe) => + pe.io.in_last := last + pe.io.out_last + } + } + // Drive the Tile's bottom IO for (c <- 0 until columns) { io.out_c(c) := tile(rows-1)(c).io.out_c io.out_b(c) := tile(rows-1)(c).io.out_b io.out_control(c) := tile(rows-1)(c).io.out_control + io.out_id(c) := tile(rows-1)(c).io.out_id + io.out_last(c) := tile(rows-1)(c).io.out_last io.out_valid(c) := tile(rows-1)(c).io.out_valid } + io.bad_dataflow := tile.map(_.map(_.io.bad_dataflow).reduce(_||_)).reduce(_||_) // Drive the Tile's right IO for (r <- 0 until rows) { diff --git a/src/main/scala/gemmini/Util.scala b/src/main/scala/gemmini/Util.scala index 593d2070..15c15edc 100644 --- a/src/main/scala/gemmini/Util.scala +++ b/src/main/scala/gemmini/Util.scala @@ -35,6 +35,15 @@ object Util { Mux(u +& v > max, max, u + v) } + def satAdd(u: UInt, v: UInt, max_plus_one: UInt, en: Bool = true.B): UInt = { + val max = max_plus_one - 1.U + + MuxCase(u + v, Seq( + (!en) -> u, + ((u +& v) > max) -> max + )) + } + def floorAdd(u: UInt, n: UInt, max_plus_one: UInt, en: Bool = true.B): UInt = { val max = max_plus_one - 1.U @@ -44,6 +53,15 @@ object Util { )) } + def sFloorAdd(s: SInt, n: UInt, max_plus_one: SInt, min: SInt, en: Bool = true.B): SInt = { + val max = max_plus_one - 1.S + + MuxCase(s + n.zext(), Seq( + (!en) -> s, + ((s +& n.zext()) > max) -> min + )) + } + def wrappingSub(u: UInt, n: UInt, max_plus_one: Int): UInt = { val max = max_plus_one - 1 assert(n <= max.U, "cannot wrapSub when n is larger than max") @@ -82,6 +100,11 @@ object Util { Mux(enable, next, buf) } + def RegEnableThru[T <: Data](next: T, init: T, enable: Bool): T = { + val buf = RegEnable(next, init, enable) + Mux(enable, next, buf) + } + def maxOf(u1: UInt, u2: UInt): UInt = { Mux(u1 > u2, u1, u2) } diff --git a/src/main/scala/gemmini/VectorScalarMultiplier.scala b/src/main/scala/gemmini/VectorScalarMultiplier.scala index 4e86f61a..d1cefcb3 100644 --- a/src/main/scala/gemmini/VectorScalarMultiplier.scala +++ b/src/main/scala/gemmini/VectorScalarMultiplier.scala @@ -24,15 +24,33 @@ class VectorScalarMultiplierResp[T <: Data, Tag <: Data](block_cols: Int, t: T, override def cloneType: VectorScalarMultiplierResp.this.type = new VectorScalarMultiplierResp(block_cols, t, tag_t).asInstanceOf[this.type] } -// Currently, this class only supports multiplications of scratchpad inputs, rather than accumulator inputs -// class VectorScalarMultiplier[T <: Data, U <: Data, Tag <: Data](config: GemminiArrayConfig[T, U], tag_t: Tag) extends Module { - // import config._ - // val block_cols = meshColumns * tileColumns -class VectorScalarMultiplier[T <: Data, U <: Data, Tag <: Data](mvin_scale_args: Option[ScaleArguments[T, U]], block_cols: Int, t: T, tag_t: Tag) extends Module { - - val u = mvin_scale_args match { - case Some(ScaleArguments(_, _, multiplicand_t, _, _)) => multiplicand_t - case None => Bool() // TODO make this a 0-width UInt +class DataWithIndex[T <: Data, U <: Data](t: T, u: U) extends Bundle { + val data = t.cloneType + val scale = u.cloneType + val id = UInt(2.W) // TODO hardcoded + val index = UInt() + override def cloneType: DataWithIndex.this.type = new DataWithIndex(t, u).asInstanceOf[this.type] +} + +class ScalePipe[T <: Data, U <: Data](t: T, mvin_scale_args: ScaleArguments[T, U]) extends Module { + val u = mvin_scale_args.multiplicand_t + val io = IO(new Bundle { + val in = Input(Valid(new DataWithIndex(t, u))) + val out = Output(Valid(new DataWithIndex(t, u))) + }) + val latency = mvin_scale_args.latency + val out = WireInit(io.in) + out.bits.data := mvin_scale_args.scale_func(io.in.bits.data, io.in.bits.scale.asTypeOf(u)) + io.out := Pipe(out, latency) +} + +class VectorScalarMultiplier[T <: Data, U <: Data, Tag <: Data]( + mvin_scale_args: Option[ScaleArguments[T, U]], block_cols: Int, t: T, tag_t: Tag +) extends Module { + + val (u, num_scale_units, always_identity) = mvin_scale_args match { + case Some(ScaleArguments(_, _, multiplicand_t, num_scale_units, _, _)) => (multiplicand_t, num_scale_units, false) + case None => (Bool(), -1, true) // TODO make this a 0-width UInt } val io = IO(new Bundle { @@ -40,48 +58,155 @@ class VectorScalarMultiplier[T <: Data, U <: Data, Tag <: Data](mvin_scale_args: val resp = Decoupled(new VectorScalarMultiplierResp(block_cols, t, tag_t)) }) - val req = Reg(UDValid(chiselTypeOf(io.req.bits))) - - io.req.ready := !req.valid || (req.bits.repeats === 0.U && io.resp.fire()) - io.resp.valid := req.valid - io.resp.bits.tag := req.bits.tag - io.resp.bits.last := req.bits.repeats === 0.U && req.bits.last - io.resp.bits.row := req.bits.repeats - io.resp.bits.out := (mvin_scale_args match { - case Some(ScaleArguments(mvin_scale_func, _, multiplicand_t, _, _)) => - req.bits.in.map(x => mvin_scale_func(x, req.bits.scale.asTypeOf(multiplicand_t))) + val width = block_cols + val latency = mvin_scale_args match { + case Some(ScaleArguments(_, latency, _, _, _, _)) => latency + case None => 0 + } - case None => req.bits.in - }) + val in = Reg(Valid(new VectorScalarMultiplierReq(block_cols, t, u, tag_t))) + val in_fire = WireInit(false.B) + io.req.ready := !in.valid || (in.bits.repeats === 0.U && in_fire) when (io.req.fire()) { - req.push(io.req.bits) - }.elsewhen(io.resp.fire()) { - when (req.bits.repeats === 0.U) { - req.pop() - }.otherwise { - req.bits.repeats := req.bits.repeats - 1.U + in.valid := io.req.valid + in.bits := io.req.bits + } .elsewhen (in_fire) { + when (in.bits.repeats === 0.U) { + in.valid := false.B } + in.bits.repeats := in.bits.repeats - 1.U + } + when (reset.asBool) { + in.valid := false.B } - when (reset.toBool()) { - req.pop() + + if (num_scale_units == -1) { + val pipe = Module(new Pipeline( + new VectorScalarMultiplierResp(block_cols, t, tag_t), + latency + )()) + io.resp <> pipe.io.out + in_fire := pipe.io.in.fire() + + pipe.io.in.valid := in.valid + pipe.io.in.bits.tag := in.bits.tag + pipe.io.in.bits.last := in.bits.repeats === 0.U && in.bits.last + pipe.io.in.bits.row := in.bits.repeats + pipe.io.in.bits.out := (mvin_scale_args match { + case Some(ScaleArguments(mvin_scale_func, _, multiplicand_t, _, _, _)) => + in.bits.in.map(x => mvin_scale_func(x, in.bits.scale.asTypeOf(multiplicand_t))) + case None => in.bits.in + }) + } else { + val nEntries = 3 + val regs = Reg(Vec(nEntries, Valid(new VectorScalarMultiplierReq(block_cols, t, u, tag_t)))) + val out_regs = Reg(Vec(nEntries, new VectorScalarMultiplierResp(block_cols, t, tag_t))) + + val fired_masks = Reg(Vec(nEntries, Vec(width, Bool()))) + val completed_masks = Reg(Vec(nEntries, Vec(width, Bool()))) + val head_oh = RegInit(1.U(nEntries.W)) + val tail_oh = RegInit(1.U(nEntries.W)) + + io.resp.valid := Mux1H(head_oh.asBools, (regs zip completed_masks).map({case (r,c) => r.valid && c.reduce(_&&_)})) + io.resp.bits := Mux1H(head_oh.asBools, out_regs) + when (io.resp.fire()) { + for (i <- 0 until nEntries) { + when (head_oh(i)) { + regs(i).valid := false.B + } + } + head_oh := (head_oh << 1) | head_oh(nEntries-1) + } + in_fire := (in.valid && + (!Mux1H(tail_oh.asBools, regs.map(_.valid)) || (tail_oh === head_oh && io.resp.fire())) + ) + when (in_fire) { + for (i <- 0 until nEntries) { + when (tail_oh(i)) { + regs(i).valid := true.B + regs(i).bits := in.bits + out_regs(i).tag := in.bits.tag + out_regs(i).last := in.bits.repeats === 0.U && in.bits.last + out_regs(i).row := in.bits.repeats + out_regs(i).out := in.bits.in + val identity = (u match { + case u: UInt => Arithmetic.UIntArithmetic.cast(u).identity + case s: SInt => Arithmetic.SIntArithmetic.cast(s).identity + case f: Float => Arithmetic.FloatArithmetic.cast(f).identity + case b: Bool => 1.U(1.W) + }) + fired_masks(i).foreach(_ := in.bits.scale.asUInt === identity.asUInt || always_identity.B) + completed_masks(i).foreach(_ := in.bits.scale.asUInt === identity.asUInt || always_identity.B) + } + } + tail_oh := (tail_oh << 1) | tail_oh(nEntries-1) + } + + + + val inputs = Seq.fill(width*nEntries) { Wire(Decoupled(new DataWithIndex(t, u))) } + for (i <- 0 until nEntries) { + for (w <- 0 until width) { + val input = inputs(i*width+w) + input.valid := regs(i).valid && !fired_masks(i)(w) + input.bits.data := regs(i).bits.in(w) + input.bits.scale := regs(i).bits.scale.asTypeOf(u) + input.bits.id := i.U + input.bits.index := w.U + when (input.fire()) { + fired_masks(i)(w) := true.B + } + } + } + for (i <- 0 until num_scale_units) { + val arbIn = inputs.zipWithIndex.filter({ case (_, w) => w % num_scale_units == i }).map(_._1) + val arb = Module(new RRArbiter(new DataWithIndex(t, u), arbIn.length)) + arb.io.in <> arbIn + arb.io.out.ready := true.B + val arbOut = Reg(Valid(new DataWithIndex(t, u))) + arbOut.valid := arb.io.out.valid + arbOut.bits := arb.io.out.bits + when (reset.asBool) { + arbOut.valid := false.B + } + + + val pipe = Module(new ScalePipe(t, mvin_scale_args.get)) + pipe.io.in := arbOut + val pipe_out = pipe.io.out + for (j <- 0 until nEntries) { + for (w <- 0 until width) { + if ((j*width+w) % num_scale_units == i) { + when (pipe_out.fire() && pipe_out.bits.id === j.U && pipe_out.bits.index === w.U) { + out_regs(j).out(w) := pipe_out.bits.data + completed_masks(j)(w) := true.B + } + } + } + } + } + when (reset.asBool) { + regs.foreach(_.valid := false.B) + } + + } + + } object VectorScalarMultiplier { // Returns the input and output IO of the module (together with the pipeline) - def apply[T <: Data, U <: Data, Tag <: Data](scale_args: Option[ScaleArguments[T, U]], t: T, cols: Int, tag_t: Tag, is_acc: Boolean, is_mvin: Boolean=true) = { + def apply[T <: Data, U <: Data, Tag <: Data]( + scale_args: Option[ScaleArguments[T, U]], + t: T, cols: Int, tag_t: Tag, + is_acc: Boolean, + is_mvin: Boolean=true + ) = { assert(!is_acc || is_mvin) - val vsm = Module(new VectorScalarMultiplier(scale_args, cols, t, tag_t)) - - val in = vsm.io.req - val out = scale_args match { - case Some(ScaleArguments(_, latency, _, _, _)) => Pipeline(vsm.io.resp, latency) - case None => vsm.io.resp - } - - (in, out) + (vsm.io.req, vsm.io.resp) } } diff --git a/src/main/scala/gemmini/XactTracker.scala b/src/main/scala/gemmini/XactTracker.scala index af020ed9..9eee539a 100644 --- a/src/main/scala/gemmini/XactTracker.scala +++ b/src/main/scala/gemmini/XactTracker.scala @@ -13,7 +13,8 @@ class XactTrackerEntry[U <: Data](maxShift: Int, spadWidth: Int, accWidth: Int, val accumulate = Bool() val has_acc_bitwidth = Bool() val scale = UInt(mvin_scale_t_bits.W) - val rows = UInt(16.W) // TODO magic number + val repeats = UInt(16.W) // TODO magic number + val block_stride = UInt(16.W) // TODO magic number val spad_row_offset = UInt(log2Up(spadWidth max accWidth).W) val lg_len_req = UInt(log2Up(log2Up(maxReqBytes+1)+1).W) val bytes_to_read = UInt(log2Up(maxReqBytes+1).W) diff --git a/src/main/scala/gemmini/ZeroWriter.scala b/src/main/scala/gemmini/ZeroWriter.scala new file mode 100644 index 00000000..c2e97b36 --- /dev/null +++ b/src/main/scala/gemmini/ZeroWriter.scala @@ -0,0 +1,70 @@ +package gemmini + +import chisel3._ +import chisel3.util._ + +import Util._ + +class ZeroWriterReq[Tag <: Data](laddr_t: LocalAddr, max_cols: Int, tag_t: Tag) extends Bundle { + val laddr = laddr_t + val cols = UInt(log2Up(max_cols+1).W) + val block_stride = UInt(16.W) // TODO magic number + val tag = tag_t + + override def cloneType: ZeroWriterReq.this.type = new ZeroWriterReq(laddr_t.cloneType, max_cols, tag_t.cloneType).asInstanceOf[this.type] +} + +class ZeroWriterResp[Tag <: Data](laddr_t: LocalAddr, block_cols: Int, tag_t: Tag) extends Bundle { + val laddr = laddr_t.cloneType + val mask = Vec(block_cols, Bool()) + val last = Bool() + val tag = tag_t + + override def cloneType: ZeroWriterResp.this.type = new ZeroWriterResp(laddr_t, block_cols, tag_t.cloneType).asInstanceOf[this.type] +} + +class ZeroWriter[T <: Data, U <: Data, V <: Data, Tag <: Data](config: GemminiArrayConfig[T, U, V], tag_t: Tag) + extends Module { + import config._ + + val block_cols = meshColumns * tileColumns + val max_cols = (dma_maxbytes / (inputType.getWidth / 8)) max block_cols + + val io = IO(new Bundle { + val req = Flipped(Decoupled(new ZeroWriterReq(local_addr_t, max_cols, tag_t))) + val resp = Decoupled(new ZeroWriterResp(local_addr_t, block_cols, tag_t)) + }) + + val req = Reg(UDValid(new ZeroWriterReq(local_addr_t, max_cols, tag_t))) + + val col_counter = Reg(UInt(log2Up(max_cols).W)) + + io.req.ready := !req.valid + + io.resp.valid := req.valid + io.resp.bits.laddr := req.bits.laddr + req.bits.block_stride * (col_counter / block_cols.U) + io.resp.bits.mask.zipWithIndex.foreach { case (m, i) => m := col_counter + i.U < req.bits.cols } + io.resp.bits.last := col_counter +& block_cols.U >= req.bits.cols + io.resp.bits.tag := req.bits.tag + + when (io.resp.fire()) { + val next_col_counter = floorAdd(col_counter, block_cols.U, req.bits.cols) + + col_counter := next_col_counter + + when (next_col_counter === 0.U) { + req.pop() + io.req.ready := true.B + } + } + + when (io.req.fire()) { + req.push(io.req.bits) + + col_counter := 0.U + } + + when (reset.toBool()) { + req.pop() + } +}