Skip to content

Commit d1a49fe

Browse files
relu works
1 parent d748b98 commit d1a49fe

3 files changed

Lines changed: 22 additions & 34 deletions

File tree

src/PE.v

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@ module PE #(
44
input wire clk,
55
input wire rst,
66
input wire clear,
7-
input wire transpose_en, // Enables fused transpose
8-
input wire relu_en, // Enables fused ReLU
9-
7+
input wire relu_en,
108
input wire signed [WIDTH-1:0] a_in,
119
input wire signed [WIDTH-1:0] b_in,
1210

@@ -16,8 +14,6 @@ module PE #(
1614
output reg signed [WIDTH-1:0] c_out
1715
);
1816

19-
reg signed [WIDTH-1:0] product;
20-
2117
always @(posedge clk or posedge rst) begin
2218
if (rst) begin
2319
a_out <= 0;
@@ -28,18 +24,10 @@ module PE #(
2824
b_out <= 0;
2925
c_out <= 0;
3026
end else begin
31-
// Fused transpose
32-
a_out <= transpose_en ? b_in : a_in;
33-
b_out <= transpose_en ? a_in : b_in;
34-
35-
// Compute product
36-
product = a_in * b_in;
27+
a_out <= a_in;
28+
b_out <= b_in;
3729

38-
// Apply ReLU if enabled
39-
if (relu_en && product < 0)
40-
c_out <= 0;
41-
else
42-
c_out <= c_out + product;
30+
c_out <= c_out + (a_in * b_in);
4331
end
4432
end
4533

src/systolic_array_2x2.v

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ module systolic_array_2x2 #(
2222
// Internal signals between PEs
2323
wire [WIDTH-1:0] a_wire [0:1][0:2];
2424
wire [WIDTH-1:0] b_wire [0:2][0:1];
25-
wire [7:0] c_array [0:1][0:1];
25+
wire signed [WIDTH-1:0] c_array [0:1][0:1];
2626

2727
// Input loading at top-left
2828
assign a_wire[0][0] = a_data0;
@@ -34,11 +34,10 @@ module systolic_array_2x2 #(
3434
generate
3535
for (i = 0; i < 2; i = i + 1) begin : row
3636
for (j = 0; j < 2; j = j + 1) begin : col
37-
PE #(.WIDTH(8)) pe_inst (
37+
PE #(.WIDTH(WIDTH)) pe_inst (
3838
.clk(clk),
3939
.rst(rst),
4040
.clear(clear),
41-
.transpose_en(transpose),
4241
.relu_en(activation),
4342
.a_in(a_wire[i][j]),
4443
.b_in(b_wire[i][j]),
@@ -50,8 +49,10 @@ module systolic_array_2x2 #(
5049
end
5150
endgenerate
5251

53-
assign c00 = c_array[0][0];
54-
assign c01 = c_array[0][1];
55-
assign c10 = c_array[1][0];
56-
assign c11 = c_array[1][1];
57-
endmodule
52+
// Combinational logic for output with optional ReLU
53+
assign c00 = activation ? (c_array[0][0] < 0 ? 0 : c_array[0][0]) : c_array[0][0];
54+
assign c01 = activation ? (c_array[0][1] < 0 ? 0 : c_array[0][1]) : c_array[0][1];
55+
assign c10 = activation ? (c_array[1][0] < 0 ? 0 : c_array[1][0]) : c_array[1][0];
56+
assign c11 = activation ? (c_array[1][1] < 0 ? 0 : c_array[1][1]) : c_array[1][1];
57+
58+
endmodule

test/test.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,15 @@ async def load_matrix(dut, matrix, sel):
2929
dut.uio_in.value = 0
3030
await RisingEdge(dut.clk)
3131

32-
async def read_signed_output(dut, transpose=False, relu=False):
32+
async def read_signed_output(dut, transpose=0, relu=0):
3333
# Apply instruction signal just before reading
34-
dut.uio_in.value = 0b00000001 | (transpose << 1) | (relu << 2)
35-
await ClockCycles(dut.clk, 1)
36-
dut.uio_in.value = 0
37-
await ClockCycles(dut.clk, 2) # allow systolic array to compute
34+
for i in range(3):
35+
dut.uio_in.value = (transpose << 1) | (relu << 2)
36+
await ClockCycles(dut.clk, 1)
3837

3938
results = []
4039
for i in range(4):
41-
dut.uio_in.value = 0 # Read mode
40+
dut.uio_in.value = (transpose << 1) | (relu << 2)
4241
await ClockCycles(dut.clk, 1)
4342
val_unsigned = dut.uo_out.value.integer
4443
val_signed = val_unsigned if val_unsigned < 128 else val_unsigned - 256
@@ -61,14 +60,14 @@ async def test_relu_transpose(dut):
6160
dut.rst_n.value = 1
6261
await ClockCycles(dut.clk, 5)
6362

64-
A = [1, -2, 3, -4] # row-major
65-
B = [5, 6, -7, 8] # row-major
63+
A = [5, -6, 7, 8] # row-major
64+
B = [8, 9, 6, 8] # row-major: [B00, B01, B10, B11]
6665

6766
await load_matrix(dut, A, sel=0)
6867
await load_matrix(dut, B, sel=1)
6968

70-
expected = get_expected_matmul(A, B, transpose=True, relu=True)
71-
results = await read_signed_output(dut)
69+
expected = get_expected_matmul(A, B, transpose=False, relu=True)
70+
results = await read_signed_output(dut, relu=1)
7271

7372
for i in range(4):
7473
assert results[i] == expected[i], f"C[{i//2}][{i%2}] = {results[i]} != expected {expected[i]}"

0 commit comments

Comments
 (0)