Skip to content

Commit c1aa6b4

Browse files
add test for bigger matrix multiplication
1 parent bbd43ac commit c1aa6b4

1 file changed

Lines changed: 189 additions & 25 deletions

File tree

test/tpu/test_tpu.py

Lines changed: 189 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def get_expected_matmul(A, B, transpose=False, relu=False):
1717
result = np.maximum(result, 0)
1818
return [saturate_to_s8(val) for val in result.flatten().tolist()]
1919

20-
async def load_matrix(dut, matrix, sel, transpose=0, relu=0):
20+
async def load_matrix(dut, matrix, transpose=0, relu=0):
2121
"""
2222
Load a 2x2 matrix into the DUT.
2323
@@ -42,26 +42,155 @@ async def read_signed_output(dut, transpose=0, relu=0):
4242
dut._log.info(f"Read C[{i//2}][{i%2}] = {val_signed}")
4343
return results
4444

45+
def get_expected_large_matmul(A, B):
    """
    Software reference model for a matrix product computed in 2x2 tiles.

    Both operands are saturated element-wise with `saturate_to_s8`, zero-padded
    so that every dimension is a multiple of 2, and multiplied tile by tile.
    Each 2x2 tile product is saturated element-wise with `saturate_to_s8`
    before it is added to the running total of its output tile. The running
    total itself is kept as a plain NumPy integer and is NOT saturated again,
    so entries of the returned array may lie outside the signed 8-bit range.

    Args:
        A: 2-D NumPy array of shape (m, n).
        B: 2-D NumPy array of shape (n, p).

    Returns:
        NumPy integer array of shape (m, p).

    Raises:
        AssertionError: if the inner dimensions of A and B differ.
    """
    saturate = np.vectorize(saturate_to_s8)
    A_saturated = saturate(A)
    B_saturated = saturate(B)

    m, n = A.shape
    n_b, p = B.shape
    assert n == n_b, "Incompatible dimensions"

    # Pad every dimension up to the next multiple of 2.
    m_padded = ((m + 1) // 2) * 2
    n_padded = ((n + 1) // 2) * 2
    p_padded = ((p + 1) // 2) * 2

    A_pad = np.zeros((m_padded, n_padded), dtype=int)
    B_pad = np.zeros((n_padded, p_padded), dtype=int)
    A_pad[:m, :n] = A_saturated
    B_pad[:n, :p] = B_saturated

    # dtype=int is NumPy's default integer; its width is platform dependent.
    result = np.zeros((m_padded, p_padded), dtype=int)

    for i in range(0, m_padded, 2):
        for j in range(0, p_padded, 2):
            # Fresh running total for this 2x2 output tile.
            acc_block = np.zeros((2, 2), dtype=int)

            for k in range(0, n_padded, 2):
                # Padded rows/columns are zero, so out-of-range indices
                # contribute nothing and need no explicit bounds checks.
                partial = A_pad[i:i + 2, k:k + 2] @ B_pad[k:k + 2, j:j + 2]

                # Each tile product is saturated before accumulation; the
                # sum across k-tiles is not saturated.
                acc_block[:, :] = acc_block + saturate(partial)

            # Writing into the padded area is harmless: it is cropped below.
            result[i:i + 2, j:j + 2] = acc_block

    # Return only the original (unpadded) shape.
    return result[:m, :p]
112+
113+
def check_expected(A, B, result):
    """
    Compare a DUT result for large matrices against the software reference.

    Prints the plain NumPy product `A @ B`, the DUT result and the reference
    returned by `get_expected_large_matmul`, then asserts that the DUT result
    equals the reference element for element.

    Args:
        A: left operand (NumPy array).
        B: right operand (NumPy array).
        result: array produced by the DUT run.

    Raises:
        AssertionError: if `result` differs from the reference.
    """
    # Show the plain product and the DUT output before building the reference.
    for shown in (A @ B, result):
        print(shown)

    reference = get_expected_large_matmul(A, B)
    print(reference)

    np.testing.assert_array_equal(
        result,
        reference,
        err_msg="Matrix multiplication result does not match expected",
    )
122+
123+
async def read_matrix_output(dut, results_large, i, j, transpose=0, relu=0):
    """
    Read four output values from the DUT and add them into `results_large`.

    One value is sampled from `dut.uo_out` per clock cycle. Each raw value is
    reinterpreted as a signed 8-bit number (values of 128 and above have 256
    subtracted) and added, in row-major order, to the 2x2 tile of
    `results_large` whose top-left corner is (i, j).

    Args:
        dut: cocotb handle of the design under test.
        results_large: 2-D array that is updated in place.
        i: row index of the tile's top-left corner.
        j: column index of the tile's top-left corner.
        transpose: flag driven onto bit 1 of `dut.uio_in`.
        relu: flag driven onto bit 2 of `dut.uio_in`.
    """
    control_word = (transpose << 1) | (relu << 2)

    for idx in range(4):
        # The control word is re-driven before every sampled cycle.
        dut.uio_in.value = control_word
        await ClockCycles(dut.clk, 1)

        raw = dut.uo_out.value.integer
        val_signed = raw - 256 if raw >= 128 else raw

        # Row-major position of this value inside the 2x2 tile.
        row_off, col_off = divmod(idx, 2)
        row = i + row_off
        col = j + col_off

        results_large[row, col] += val_signed
        dut._log.info(f"Read C[{row}][{col}] = {val_signed}")
133+
134+
async def matmul(dut, A, B):
    """
    Multiply A by B on the DUT, one 2x2 tile at a time.

    Both operands are zero-padded up to even dimensions. For every output
    tile (i, j) and every tile offset k along the shared dimension, the
    matching 2x2 tiles of A and B are sent with `load_matrix`, and the values
    read back by `read_matrix_output` are added into the result array.

    Args:
        dut: cocotb handle of the design under test.
        A: 2-D NumPy array of shape (m, n).
        B: 2-D NumPy array of shape (n, p).

    Returns:
        The top-left (m, p) region of the accumulated result array.

    Raises:
        AssertionError: if the inner dimensions of A and B differ.
    """
    rows_a, inner_a = A.shape
    inner_b, cols_b = B.shape
    assert inner_a == inner_b, "Matrix dimensions must be compatible for multiplication"

    # Round every dimension up to the next even number.
    rows_even = rows_a + rows_a % 2
    inner_even = inner_a + inner_a % 2
    cols_even = cols_b + cols_b % 2

    # Zero-padded copies of the operands.
    lhs = np.zeros((rows_even, inner_even), dtype=int)
    lhs[:rows_a, :inner_a] = A
    rhs = np.zeros((inner_even, cols_even), dtype=int)
    rhs[:inner_a, :cols_b] = B

    # Accumulator for the values read back from the DUT.
    results_large = np.zeros((rows_even, cols_even), dtype=int)

    for i in range(0, rows_even, 2):
        for j in range(0, cols_even, 2):
            for k in range(0, inner_even, 2):
                # Row-major flattened 2x2 tiles, as plain Python lists.
                a_tile = lhs[i:i + 2, k:k + 2].ravel().tolist()
                b_tile = rhs[k:k + 2, j:j + 2].ravel().tolist()

                # A tile first, then B tile, then read the partial result.
                await load_matrix(dut, a_tile)
                await load_matrix(dut, b_tile)
                await read_matrix_output(dut, results_large, i, j)

    # Crop the padding away.
    return results_large[:rows_a, :cols_b]
173+
45174
@cocotb.test()
46175
async def test_relu_transpose(dut):
47176
dut._log.info("Start")
48-
clock = Clock(dut.clk, 10, units="us")
177+
clock = Clock(dut.clk, 1, units="us")
49178
cocotb.start_soon(clock.start())
50179

51180
# Reset
52181
dut.ena.value = 1
53182
dut.ui_in.value = 0
54183
dut.uio_in.value = 0
55184
dut.rst_n.value = 0
56-
await ClockCycles(dut.clk, 5)
185+
await ClockCycles(dut.clk, 2)
57186
dut.rst_n.value = 1
58-
await ClockCycles(dut.clk, 5)
187+
await ClockCycles(dut.clk, 2)
59188

60189
A = [5, -6, 7, 8] # row-major
61190
B = [8, 9, 6, 8] # row-major: [B00, B01, B10, B11]
62191

63-
await load_matrix(dut, A, sel=0, transpose=0, relu=1)
64-
await load_matrix(dut, B, sel=1, transpose=0, relu=1)
192+
await load_matrix(dut, A, transpose=0, relu=1)
193+
await load_matrix(dut, B, transpose=0, relu=1)
65194

66195
expected = get_expected_matmul(A, B, transpose=False, relu=True)
67196
results = await read_signed_output(dut, transpose=0, relu=1)
@@ -74,8 +203,8 @@ async def test_relu_transpose(dut):
74203
A = [1, 2, 3, 4]
75204
B = [5, 6, 7, 8]
76205

77-
await load_matrix(dut, A, sel=0, transpose=1, relu=1)
78-
await load_matrix(dut, B, sel=1, transpose=1, relu=1)
206+
await load_matrix(dut, A, transpose=1, relu=1)
207+
await load_matrix(dut, B, transpose=1, relu=1)
79208

80209
expected = get_expected_matmul(A, B, transpose=True, relu=True)
81210
results = await read_signed_output(dut, transpose=1, relu=1)
@@ -88,23 +217,23 @@ async def test_relu_transpose(dut):
88217
@cocotb.test()
89218
async def test_numeric_limits(dut):
90219
dut._log.info("Start")
91-
clock = Clock(dut.clk, 10, units="us")
220+
clock = Clock(dut.clk, 1, units="us")
92221
cocotb.start_soon(clock.start())
93222

94223
# Reset
95224
dut.ena.value = 1
96225
dut.ui_in.value = 0
97226
dut.uio_in.value = 0
98227
dut.rst_n.value = 0
99-
await ClockCycles(dut.clk, 5)
228+
await ClockCycles(dut.clk, 2)
100229
dut.rst_n.value = 1
101-
await ClockCycles(dut.clk, 5)
230+
await ClockCycles(dut.clk, 2)
102231

103232
A = [5, -6, 7, 8] # row-major
104233
B = [8, 12, 9, -7] # row-major: [B00, B01, B10, B11]
105234

106-
await load_matrix(dut, A, sel=0)
107-
await load_matrix(dut, B, sel=1)
235+
await load_matrix(dut, A)
236+
await load_matrix(dut, B)
108237

109238
expected = get_expected_matmul(A, B)
110239
results = []
@@ -121,8 +250,8 @@ async def test_numeric_limits(dut):
121250
A = [5, -6, 7, 8] # row-major
122251
B = [8, -12, 9, -7] # row-major: [B00, B01, B10, B11]
123252

124-
await load_matrix(dut, A, sel=0)
125-
await load_matrix(dut, B, sel=1)
253+
await load_matrix(dut, A)
254+
await load_matrix(dut, B)
126255

127256
expected = get_expected_matmul(A, B)
128257
results = []
@@ -139,17 +268,17 @@ async def test_numeric_limits(dut):
139268
@cocotb.test()
140269
async def test_project(dut):
141270
dut._log.info("Start")
142-
clock = Clock(dut.clk, 10, units="us")
271+
clock = Clock(dut.clk, 1, units="us")
143272
cocotb.start_soon(clock.start())
144273

145274
# Reset
146275
dut.ena.value = 1
147276
dut.ui_in.value = 0
148277
dut.uio_in.value = 0
149278
dut.rst_n.value = 0
150-
await ClockCycles(dut.clk, 5)
279+
await ClockCycles(dut.clk, 2)
151280
dut.rst_n.value = 1
152-
await ClockCycles(dut.clk, 5)
281+
await ClockCycles(dut.clk, 2)
153282

154283
# ------------------------------
155284
# STEP 1: Load matrix A
@@ -163,8 +292,8 @@ async def test_project(dut):
163292
# [7, 8]]
164293
B = [5, 6, 7, 8] # row-major: [B00, B01, B10, B11]
165294

166-
await load_matrix(dut, A, sel=0)
167-
await load_matrix(dut, B, sel=1)
295+
await load_matrix(dut, A)
296+
await load_matrix(dut, B)
168297

169298
# ------------------------------
170299
# STEP 4: Read outputs
@@ -189,8 +318,8 @@ async def test_project(dut):
189318
A = [79, -10, 7, 8] # row-major
190319
B = [2, 6, 5, 8] # row-major: [B00, B01, B10, B11]
191320

192-
await load_matrix(dut, A, sel=0)
193-
await load_matrix(dut, B, sel=1)
321+
await load_matrix(dut, A)
322+
await load_matrix(dut, B)
194323

195324
# ------------------------------
196325
# STEP 4: Read outputs
@@ -212,8 +341,8 @@ async def test_project(dut):
212341
A = [5, -6, 7, 8] # row-major
213342
B = [1, 2, 3, -4] # row-major: [B00, B01, B10, B11]
214343

215-
await load_matrix(dut, A, sel=0)
216-
await load_matrix(dut, B, sel=1)
344+
await load_matrix(dut, A)
345+
await load_matrix(dut, B)
217346

218347
expected = get_expected_matmul(A, B)
219348
results = []
@@ -225,4 +354,39 @@ async def test_project(dut):
225354
for i in range(4):
226355
assert results[i] == expected[i], f"C[{i//2}][{i%2}] = {results[i]} != expected {expected[i]}"
227356

228-
dut._log.info("Test 3 passed!")
357+
dut._log.info("Test 3 passed!")
358+
359+
# ------------------------------
360+
# TEST RUN 4: Large Matrix Multiplication with Arbitrary Dimensions
361+
# User-specified size, all elements MUST FIT WITHIN INT8 RANGE
362+
363+
"""
364+
[[ 18 8 -6]
365+
[-13 0 18]
366+
[ -2 2 -10]
367+
[-10 3 15]
368+
[ 19 3 -18]]
369+
370+
[[ 1 -19 3 9 17 -19]
371+
[ 0 12 -9 1 4 6]
372+
[ 7 -5 -6 -18 16 -14]]
373+
374+
correct result:
375+
[[ -24 -128 18 127 127 -128]
376+
[ 113 127 -128 -128 67 -5]
377+
[ -72 112 36 127 -128 127]
378+
[ 95 127 -128 -128 82 -2]
379+
[-107 -128 127 127 47 -91]]
380+
"""
381+
np.random.seed(42)
382+
A_large = np.random.randint(-20, 20, size=(5, 3))
383+
B_large = np.random.randint(-20, 20, size=(3, 6))
384+
385+
# Perform matrix multiplication on DUT
386+
result = await matmul(dut, A_large, B_large)
387+
388+
# Check results against expected
389+
check_expected(A_large, B_large, result)
390+
391+
dut._log.info("Test 4 (Arbitrary Dimension Matrix) passed!")
392+

0 commit comments

Comments
 (0)