-
Notifications
You must be signed in to change notification settings - Fork 117
Expand file tree
/
Copy pathsource_onemath_usm_gemm.cpp
More file actions
140 lines (108 loc) · 4 KB
/
source_onemath_usm_gemm.cpp
File metadata and controls
140 lines (108 loc) · 4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
/*
SYCL Academy (c)
SYCL Academy is licensed under a Creative Commons
Attribution-ShareAlike 4.0 International License.
You should have received a copy of the license along with this
work. If not, see <http://creativecommons.org/licenses/by-sa/4.0/>.
Quick Reference
~~~~~~~~~~~~~~~~~~~~
oneMath execution model:
https://oneapi-spec.uxlfoundation.org/specifications/oneapi/latest/elements/onemath/source/architecture/architecture
oneMath GEMM API:
https://oneapi-spec.uxlfoundation.org/specifications/oneapi/latest/elements/onemath/source/domains/blas/gemm
*/
#include <cmath>
#include <cstddef>
#include <iostream>
#include <limits>
#include <random>
#include <string>
#include <vector>

#include <oneapi/math.hpp>
#include <sycl/sycl.hpp>
// Matrix size constants.
//
// The three matrix dimensions are derived from SIZE by integer division, so
// SIZE has to be divisible by 8 for M, N and P to be exact fractions of it.
// The static_assert turns that requirement into a compile-time check instead
// of a comment that can silently go stale.
constexpr size_t SIZE = 4800;
static_assert(SIZE % 8 == 0, "SIZE must be a multiple of 8");
constexpr size_t M = SIZE / 8;  // rows of A and C
constexpr size_t N = SIZE / 4;  // columns of A, rows of B
constexpr size_t P = SIZE / 2;  // columns of B and C

// Element type used for every matrix and scalar in this file.
using T = double;
//////////////////////////////////////////////////////////////////////////////////////////
// Returns true when a and b differ by strictly less than a fixed absolute
// tolerance of 1.0e-08. A NaN on either side compares as "not the same".
bool ValueSame(T a, T b) {
  constexpr double kAbsTolerance = 1.0e-08;
  const auto difference = std::fabs(a - b);
  return difference < kAbsTolerance;
}
// Compares two row-major M x P matrices element by element.
//
// Every pair of elements rejected by ValueSame() is reported on stdout with
// its [row, column] position; the value from c_A is printed as "expected" and
// the value from c_B as "but got". After the full scan a one-line verdict is
// printed.
//
// Returns 0 when every element matches, -1 if at least one does not.
int VerifyResult(T* c_A, T* c_B) {
  bool all_match = true;
  for (size_t row = 0; row < M; ++row) {
    // Start of the current row in each matrix.
    T* const row_a = c_A + row * P;
    T* const row_b = c_B + row * P;
    for (size_t col = 0; col < P; ++col) {
      if (ValueSame(row_a[col], row_b[col])) {
        continue;
      }
      std::cout << "fail - The result is incorrect for element: [" << row
                << ", " << col << "], expected: " << row_a[col]
                << " , but got: " << row_b[col] << std::endl;
      all_match = false;
    }
  }

  if (all_match) {
    std::cout << "SUCCESS - The results are correct!" << std::endl;
    return 0;
  }
  std::cout << "FAIL - The results mis-match!" << std::endl;
  return -1;
}
//////////////////////////////////////////////////////////////////////////////////////////
// Prints one line to stdout describing the device that queue Q is bound to:
// its name, its version string and its driver version.
void print_device_info(sycl::queue& Q) {
  const sycl::device device = Q.get_device();

  const std::string device_name =
      device.get_info<sycl::info::device::name>();
  const std::string driver_version =
      device.get_info<sycl::info::device::driver_version>();
  const std::string device_version =
      device.get_info<sycl::info::device::version>();

  std::cout << "Running on " << device_name.c_str()
            << ", version: " << device_version.c_str()
            << ", driver version: " << driver_version.c_str() << std::endl;
}
//////////////////////////////////////////////////////////////////////////////////////////
// Exercise entry point.
//
// Fills A (M x N) and B (N x P) with random values, computes the reference
// product C_host = A * B with a serial triple loop on the host, creates a SYCL
// queue and prints information about its device. The oneMath GEMM steps are
// left as TODOs for the reader.
//
// Returns `result`, which stays 0 until the VerifyResult() call below is
// enabled.
int main() {
std::random_device
rd; // Will be used to obtain a seed for the random number engine
std::mt19937 gen(rd()); // Standard mersenne_twister_engine seeded with rd()
// Input values are drawn uniformly from [1.0, 2.0)
std::uniform_real_distribution<> dis(1.0, 2.0);
// Matrix data sizes: C is m x n, and k is the dimension shared by A (m x k)
// and B (k x n). These are not used yet; presumably they are the arguments for
// the GEMM call in the TODO below.
int m = M;
int n = P;
int k = N;
// Leading dimensions of the data: each equals the row length of the
// corresponding matrix, matching the row-major indexing used below
int ldA = k;
int ldB = n;
int ldC = n;
// Scalar fp values; presumably the alpha/beta scaling factors of the GEMM
// call (C = alpha * A * B + beta * C) -- unused until the TODO is implemented
T alpha = 1.0;
T beta = 0.0;
// Allocate memory on host. std::vector value-initializes its elements, so
// C_host starts out as all zeros, which the += accumulation below relies on.
std::vector<T> A(M * N);
std::vector<T> B(N * P);
std::vector<T> C_host(M * P);
std::cout << "Problem size: c(" << M << "," << P << ") = a(" << M << "," << N
<< ") * b(" << N << "," << P << ")" << std::endl;
// A(M, N), stored row-major
for (size_t i = 0; i < M; i++)
for (size_t j = 0; j < N; j++) A[i * N + j] = dis(gen);
// B(N, P), stored row-major
for (size_t i = 0; i < N; i++)
for (size_t j = 0; j < P; j++) B[i * P + j] = dis(gen);
// Reference result computed serially on the host: C_host = A * B
for (size_t i = 0; i < M; i++) {
for (size_t j = 0; j < P; j++) {
for (size_t d = 0; d < N; d++) {
C_host[i * P + j] += A[i * N + d] * B[d * P + j];
}
}
}
// Create a SYCL queue
sycl::queue Q;
// Prints some basic info related to the hardware
print_device_info(Q);
// TODO: Allocate memory on device, (using sycl::malloc_device APIs)
// TODO: Use oneMath GEMM USM API
// TODO: Copy the results from device to host for verification
// Verify results from oneMath
int result = 0;
std::cout << "Verify results between oneMath & serial: ";
// TODO: Uncomment the following line to verify the results.
// NOTE: VerifyResult() takes raw T* pointers and reads them on the host, while
// C_host is a std::vector<T>, so it has to be passed as C_host.data(); the
// first argument must likewise point to host-readable memory.
// result = VerifyResult(C_device, C_host);
// TODO: Free memory from device
return result;
}