33from cocotb .triggers import ClockCycles , RisingEdge
44import numpy as np
55
def saturate_to_s8(x):
    """Clamp value to 8-bit signed range [-128, 127]."""
    return max(-128, min(127, int(x)))


def get_expected_matmul(A, B, transpose=False, relu=False):
    """Return the reference 2x2 product of A and B as a flat row-major list.

    Args:
        A: four integers, a 2x2 matrix flattened row-major.
        B: four integers, a 2x2 matrix flattened row-major.
        transpose: when true, B is transposed before the multiplication.
        relu: when true, negative products are replaced by 0 before
            saturation.

    Returns:
        A list of four Python ints ``[C00, C01, C10, C11]``, each clamped to
        the signed 8-bit range by ``saturate_to_s8``.
    """
    # Widen to int64 up front: if a caller hands in a narrow integer array
    # (e.g. int8), the matmul would otherwise wrap around before the clamp.
    A_mat = np.asarray(A, dtype=np.int64).reshape(2, 2)
    B_mat = np.asarray(B, dtype=np.int64).reshape(2, 2)
    if transpose:
        B_mat = B_mat.T
    result = A_mat @ B_mat
    if relu:
        result = np.maximum(result, 0)
    return [saturate_to_s8(val) for val in result.flatten().tolist()]
1119
1220async def load_matrix (dut , matrix , sel ):
1321 """
@@ -25,19 +33,65 @@ async def load_matrix(dut, matrix, sel):
2533 dut .uio_in .value = 0
2634 await RisingEdge (dut .clk )
2735
async def read_signed_output(dut, transpose=0, relu=0):
    """Drive the instruction bits and read four signed result bytes.

    Args:
        dut: cocotb handle exposing ``clk``, ``uio_in``, ``uo_out`` and
            ``_log``.
        transpose: 0 or 1; placed on bit 1 of ``uio_in``.
        relu: 0 or 1; placed on bit 2 of ``uio_in``.

    Returns:
        A list of four ints in the order C[0][0], C[0][1], C[1][0], C[1][1],
        each the byte sampled from ``uo_out`` reinterpreted as two's
        complement signed 8-bit.
    """
    # The instruction word does not change during the readout, so compute and
    # drive it once; the signal keeps its value until it is written again.
    control_word = (transpose << 1) | (relu << 2)
    dut.uio_in.value = control_word

    # Let three clock cycles elapse before sampling the first output.
    await ClockCycles(dut.clk, 3)

    results = []
    for i in range(4):
        await ClockCycles(dut.clk, 1)
        # int() works on both legacy and current cocotb value objects,
        # unlike the deprecated ``.integer`` attribute.
        val_unsigned = int(dut.uo_out.value)
        # Reinterpret the raw byte as two's-complement signed 8-bit.
        val_signed = val_unsigned - 256 if val_unsigned >= 128 else val_unsigned
        results.append(val_signed)
        dut._log.info(f"Read C[{i // 2}][{i % 2}] = {val_signed}")
    return results
4051
@cocotb.test()
async def test_relu_transpose(dut):
    """Check the ReLU path, first without and then with transposing B.

    Each case loads A and B, computes the reference result with
    ``get_expected_matmul`` and compares it element by element against the
    four values returned by ``read_signed_output``.
    """
    dut._log.info("Start")
    # Clock with a period of 10 us on dut.clk.
    clock = Clock(dut.clk, 10, units="us")
    cocotb.start_soon(clock.start())

    # Reset: hold rst_n low for five cycles, then release and wait five more.
    dut.ena.value = 1
    dut.ui_in.value = 0
    dut.uio_in.value = 0
    dut.rst_n.value = 0
    await ClockCycles(dut.clk, 5)
    dut.rst_n.value = 1
    await ClockCycles(dut.clk, 5)

    # (A, B, transpose, message logged on success); matrices are row-major,
    # i.e. [M00, M01, M10, M11]. ReLU is enabled in both cases.
    cases = (
        ([5, -6, 7, 8], [8, 9, 6, 8], False, "First part passed"),
        ([1, 2, 3, 4], [5, 6, 7, 8], True, "ReLU + Transpose test passed!"),
    )

    for A, B, transpose, done_msg in cases:
        await load_matrix(dut, A, sel=0)
        await load_matrix(dut, B, sel=1)

        expected = get_expected_matmul(A, B, transpose=transpose, relu=True)
        results = await read_signed_output(dut, transpose=int(transpose), relu=1)

        for i, (got, want) in enumerate(zip(results, expected)):
            assert got == want, f"C[{i // 2}][{i % 2}] = {got} != expected {want}"

        dut._log.info(done_msg)
94+
4195@cocotb .test ()
4296async def test_numeric_limits (dut ):
4397 dut ._log .info ("Start" )
@@ -54,7 +108,25 @@ async def test_numeric_limits(dut):
54108 await ClockCycles (dut .clk , 5 )
55109
56110 A = [5 , - 6 , 7 , 8 ] # row-major
57- B = [8 , 9 , 6 , - 7 ] # row-major: [B00, B01, B10, B11]
111+ B = [8 , 12 , 9 , - 7 ] # row-major: [B00, B01, B10, B11]
112+
113+ await load_matrix (dut , A , sel = 0 )
114+ await load_matrix (dut , B , sel = 1 )
115+
116+ expected = get_expected_matmul (A , B )
117+ results = []
118+
119+ # Wait for systolic array to compute
120+
121+ results = await read_signed_output (dut )
122+
123+ for i in range (4 ):
124+ assert results [i ] == expected [i ], f"C[{ i // 2 } ][{ i % 2 } ] = { results [i ]} != expected { expected [i ]} "
125+
126+ dut ._log .info ("Passed large positive values" )
127+
128+ A = [5 , - 6 , 7 , 8 ] # row-major
129+ B = [8 , - 12 , 9 , - 7 ] # row-major: [B00, B01, B10, B11]
58130
59131 await load_matrix (dut , A , sel = 0 )
60132 await load_matrix (dut , B , sel = 1 )
0 commit comments