Skip to content

Commit 30ea560

Browse files
Fused Transpose + ReLU (#2)
* relu and transpose * widen accumulation to 12 bits, which improves precision (e.g. if an intermediate sum overflows the 8-bit range but still fits in 12 bits, a subsequent negative contribution can bring it back into the 8-bit range, so the final value is correct)
1 parent 4398a95 commit 30ea560

6 files changed

Lines changed: 144 additions & 46 deletions

File tree

info.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ pinout:
4949

5050
# Bidirectional pins
5151
uio[0]: "LOAD_EN (input)"
52-
uio[1]: "Unused"
53-
uio[2]: "Unused"
52+
uio[1]: "TRANSPOSE (input)"
53+
uio[2]: "ACTIVATION (input)"
5454
uio[3]: "Unused"
5555
uio[4]: "Unused"
5656
uio[5]: "Unused"

src/PE.v

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,30 +3,31 @@ module PE #(
33
)(
44
input wire clk,
55
input wire rst,
6-
input wire clear, // clears accumulators between computations...
7-
6+
input wire clear,
87
input wire signed [WIDTH-1:0] a_in,
98
input wire signed [WIDTH-1:0] b_in,
109

1110
output reg signed [WIDTH-1:0] a_out,
1211
output reg signed [WIDTH-1:0] b_out,
1312

14-
output reg signed [WIDTH-1:0] c_out
13+
output reg signed [WIDTH+3:0] c_out
1514
);
1615

1716
always @(posedge clk or posedge rst) begin
1817
if (rst) begin
19-
c_out <= 0;
20-
a_out <= 0;
21-
b_out <= 0;
18+
a_out <= 0;
19+
b_out <= 0;
20+
c_out <= 0;
2221
end else if (clear) begin
23-
c_out <= 0;
24-
a_out <= 0;
25-
b_out <= 0;
22+
a_out <= 0;
23+
b_out <= 0;
24+
c_out <= 0;
2625
end else begin
27-
c_out <= c_out + a_in * b_in;
28-
a_out <= a_in;
29-
b_out <= b_in;
26+
a_out <= a_in;
27+
b_out <= b_in;
28+
29+
c_out <= c_out + (a_in * b_in);
3030
end
3131
end
32+
3233
endmodule

src/mmu_feeder.v

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@ module mmu_feeder (
66
input wire en,
77
input wire [2:0] mmu_cycle,
88

9+
input wire transpose,
10+
911
/* Memory module interface */
1012
input wire [7:0] weight0, weight1, weight2, weight3,
1113
input wire [7:0] input0, input1, input2, input3,
1214

1315
/* systolic array -> feeder */
14-
input wire [7:0] c00, c01, c10, c11,
16+
input wire signed [11:0] c00, c01, c10, c11,
1517

1618
/* feeder -> mmu */
1719
output reg clear,
@@ -31,6 +33,18 @@ module mmu_feeder (
3133
// Output counter for selecting c_out
3234
reg [1:0] output_count;
3335

36+
function [7:0] saturate_to_s8;
37+
input signed [11:0] val;
38+
begin
39+
if (val > 127)
40+
saturate_to_s8 = 8'sd127;
41+
else if (val < -128)
42+
saturate_to_s8 = -8'sd128;
43+
else
44+
saturate_to_s8 = val[7:0];
45+
end
46+
endfunction
47+
3448
// Sequential logic for control and data outputs
3549
always @(posedge clk or posedge rst) begin
3650
if (rst) begin
@@ -55,8 +69,6 @@ module mmu_feeder (
5569
end else begin
5670
output_count <= 0;
5771
end
58-
59-
// Input assignments based on mmu_cycle
6072
case (mmu_cycle)
6173
3'b000: begin
6274
a_data0 <= weight0;
@@ -65,8 +77,13 @@ module mmu_feeder (
6577
3'b001: begin
6678
a_data0 <= weight1;
6779
a_data1 <= weight2;
68-
b_data0 <= input2;
69-
b_data1 <= input1;
80+
if (transpose) begin
81+
b_data0 <= input1;
82+
b_data1 <= input2;
83+
end else begin
84+
b_data0 <= input2;
85+
b_data1 <= input1;
86+
end
7087
end
7188
3'b010: begin
7289
a_data1 <= weight3;
@@ -86,10 +103,10 @@ module mmu_feeder (
86103
host_outdata = 8'b0; // Default to avoid latch
87104
if (en) begin
88105
case (output_count)
89-
2'b00: host_outdata = c00;
90-
2'b01: host_outdata = c01;
91-
2'b10: host_outdata = c10;
92-
2'b11: host_outdata = c11;
106+
2'b00: host_outdata = saturate_to_s8(c00);
107+
2'b01: host_outdata = saturate_to_s8(c01);
108+
2'b10: host_outdata = saturate_to_s8(c10);
109+
2'b11: host_outdata = saturate_to_s8(c11);
93110
default: host_outdata = 8'b0;
94111
endcase
95112
end

src/systolic_array_2x2.v

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,23 @@ module systolic_array_2x2 #(
55
input wire rst,
66
input wire clear,
77

8+
input wire activation,
9+
810
input wire [WIDTH-1:0] a_data0,
911
input wire [WIDTH-1:0] a_data1,
1012
input wire [WIDTH-1:0] b_data0,
1113
input wire [WIDTH-1:0] b_data1,
1214

13-
output wire [WIDTH-1:0] c00,
14-
output wire [WIDTH-1:0] c01,
15-
output wire [WIDTH-1:0] c10,
16-
output wire [WIDTH-1:0] c11
15+
output wire [WIDTH+3:0] c00,
16+
output wire [WIDTH+3:0] c01,
17+
output wire [WIDTH+3:0] c10,
18+
output wire [WIDTH+3:0] c11
1719
);
1820

1921
// Internal signals between PEs
2022
wire [WIDTH-1:0] a_wire [0:1][0:2];
2123
wire [WIDTH-1:0] b_wire [0:2][0:1];
22-
wire [7:0] c_array [0:1][0:1];
24+
wire signed [WIDTH+3:0] c_array [0:1][0:1];
2325

2426
// Input loading at top-left
2527
assign a_wire[0][0] = a_data0;
@@ -31,7 +33,7 @@ module systolic_array_2x2 #(
3133
generate
3234
for (i = 0; i < 2; i = i + 1) begin : row
3335
for (j = 0; j < 2; j = j + 1) begin : col
34-
PE #(.WIDTH(8)) pe_inst (
36+
PE #(.WIDTH(WIDTH)) pe_inst (
3537
.clk(clk),
3638
.rst(rst),
3739
.clear(clear),
@@ -45,8 +47,10 @@ module systolic_array_2x2 #(
4547
end
4648
endgenerate
4749

48-
assign c00 = c_array[0][0];
49-
assign c01 = c_array[0][1];
50-
assign c10 = c_array[1][0];
51-
assign c11 = c_array[1][1];
52-
endmodule
50+
// Combinational logic for output with optional ReLU
51+
assign c00 = activation ? (c_array[0][0] < 0 ? 0 : c_array[0][0]) : c_array[0][0];
52+
assign c01 = activation ? (c_array[0][1] < 0 ? 0 : c_array[0][1]) : c_array[0][1];
53+
assign c10 = activation ? (c_array[1][0] < 0 ? 0 : c_array[1][0]) : c_array[1][0];
54+
assign c11 = activation ? (c_array[1][1] < 0 ? 0 : c_array[1][1]) : c_array[1][1];
55+
56+
endmodule

src/tpu.v

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ module tt_um_tpu (
1717
);
1818

1919
wire instruction = uio_in[0];
20+
wire transpose = uio_in[1];
21+
wire activation = uio_in[2];
2022

2123
wire compute_en; // internal signal
2224
reg clear; // reset of PEs only
@@ -28,7 +30,7 @@ module tt_um_tpu (
2830
wire [7:0] weight0, weight1, weight2, weight3;
2931
wire [7:0] input0, input1, input2, input3;
3032

31-
wire [7:0] outputs [0:3]; // raw accumulations (16-bit)
33+
wire [11:0] outputs [0:3]; // raw accumulations (12-bit)
3234
wire [7:0] out_data; // sent to CPU
3335
// Ports of the systolic Array
3436
wire [7:0] a_data0, b_data0, a_data1, b_data1;
@@ -60,6 +62,7 @@ module tt_um_tpu (
6062
.clk(clk),
6163
.rst(~rst_n),
6264
.clear(clear),
65+
.activation(activation),
6366
.a_data0(a_data0),
6467
.a_data1(a_data1),
6568
.b_data0(b_data0),
@@ -75,6 +78,7 @@ module tt_um_tpu (
7578
.rst(~rst_n),
7679
.en(compute_en),
7780
.mmu_cycle(mmu_cycle),
81+
.transpose(transpose),
7882
.weight0(weight0), .weight1(weight1), .weight2(weight2), .weight3(weight3),
7983
.input0(input0), .input1(input1), .input2(input2), .input3(input3),
8084
.c00(outputs[0]),
@@ -94,6 +98,6 @@ module tt_um_tpu (
9498
assign uio_out = {done, 7'b0};
9599
assign uio_oe = 8'b10000000;
96100

97-
wire _unused = &{ena, uio_in[7:1]};
101+
wire _unused = &{ena, uio_in[7:3]};
98102

99103
endmodule

test/test.py

Lines changed: 82 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,19 @@
33
from cocotb.triggers import ClockCycles, RisingEdge
44
import numpy as np
55

6-
def get_expected_matmul(A, B):
7-
"""
8-
Args: lists A, B as flattened row-major matrices
9-
"""
10-
return (np.array(A).reshape(2, 2) @ np.array(B).reshape(2, 2)).flatten().tolist()
6+
def saturate_to_s8(x):
7+
"""Clamp value to 8-bit signed range [-128, 127]."""
8+
return max(-128, min(127, int(x)))
9+
10+
def get_expected_matmul(A, B, transpose=False, relu=False):
11+
A_mat = np.array(A).reshape(2, 2)
12+
B_mat = np.array(B).reshape(2, 2)
13+
if transpose:
14+
B_mat = B_mat.T
15+
result = A_mat @ B_mat
16+
if relu:
17+
result = np.maximum(result, 0)
18+
return [saturate_to_s8(val) for val in result.flatten().tolist()]
1119

1220
async def load_matrix(dut, matrix, sel):
1321
"""
@@ -25,19 +33,65 @@ async def load_matrix(dut, matrix, sel):
2533
dut.uio_in.value = 0
2634
await RisingEdge(dut.clk)
2735

28-
async def read_signed_output(dut):
29-
# Wait for first outputs to propagate
30-
await ClockCycles(dut.clk, 3)
36+
async def read_signed_output(dut, transpose=0, relu=0):
37+
# Hold the transpose/ReLU control bits on uio_in for three clock cycles before reading
38+
for i in range(3):
39+
dut.uio_in.value = (transpose << 1) | (relu << 2)
40+
await ClockCycles(dut.clk, 1)
41+
3142
results = []
3243
for i in range(4):
33-
dut.uio_in.value = 0
44+
dut.uio_in.value = (transpose << 1) | (relu << 2)
3445
await ClockCycles(dut.clk, 1)
3546
val_unsigned = dut.uo_out.value.integer
3647
val_signed = val_unsigned if val_unsigned < 128 else val_unsigned - 256
3748
results.append(val_signed)
3849
dut._log.info(f"Read C[{i//2}][{i%2}] = {val_signed}")
3950
return results
4051

52+
@cocotb.test()
53+
async def test_relu_transpose(dut):
54+
dut._log.info("Start")
55+
clock = Clock(dut.clk, 10, units="us")
56+
cocotb.start_soon(clock.start())
57+
58+
# Reset
59+
dut.ena.value = 1
60+
dut.ui_in.value = 0
61+
dut.uio_in.value = 0
62+
dut.rst_n.value = 0
63+
await ClockCycles(dut.clk, 5)
64+
dut.rst_n.value = 1
65+
await ClockCycles(dut.clk, 5)
66+
67+
A = [5, -6, 7, 8] # row-major
68+
B = [8, 9, 6, 8] # row-major: [B00, B01, B10, B11]
69+
70+
await load_matrix(dut, A, sel=0)
71+
await load_matrix(dut, B, sel=1)
72+
73+
expected = get_expected_matmul(A, B, transpose=False, relu=True)
74+
results = await read_signed_output(dut, transpose=0, relu=1)
75+
76+
for i in range(4):
77+
assert results[i] == expected[i], f"C[{i//2}][{i%2}] = {results[i]} != expected {expected[i]}"
78+
79+
dut._log.info("First part passed")
80+
81+
A = [1, 2, 3, 4]
82+
B = [5, 6, 7, 8]
83+
84+
await load_matrix(dut, A, sel=0)
85+
await load_matrix(dut, B, sel=1)
86+
87+
expected = get_expected_matmul(A, B, transpose=True, relu=True)
88+
results = await read_signed_output(dut, transpose=1, relu=1)
89+
90+
for i in range(4):
91+
assert results[i] == expected[i], f"C[{i//2}][{i%2}] = {results[i]} != expected {expected[i]}"
92+
93+
dut._log.info("ReLU + Transpose test passed!")
94+
4195
@cocotb.test()
4296
async def test_numeric_limits(dut):
4397
dut._log.info("Start")
@@ -54,7 +108,25 @@ async def test_numeric_limits(dut):
54108
await ClockCycles(dut.clk, 5)
55109

56110
A = [5, -6, 7, 8] # row-major
57-
B = [8, 9, 6, -7] # row-major: [B00, B01, B10, B11]
111+
B = [8, 12, 9, -7] # row-major: [B00, B01, B10, B11]
112+
113+
await load_matrix(dut, A, sel=0)
114+
await load_matrix(dut, B, sel=1)
115+
116+
expected = get_expected_matmul(A, B)
117+
results = []
118+
119+
# Wait for systolic array to compute
120+
121+
results = await read_signed_output(dut)
122+
123+
for i in range(4):
124+
assert results[i] == expected[i], f"C[{i//2}][{i%2}] = {results[i]} != expected {expected[i]}"
125+
126+
dut._log.info("Passed large positive values")
127+
128+
A = [5, -6, 7, 8] # row-major
129+
B = [8, -12, 9, -7] # row-major: [B00, B01, B10, B11]
58130

59131
await load_matrix(dut, A, sel=0)
60132
await load_matrix(dut, B, sel=1)

0 commit comments

Comments
 (0)