Skip to content

Commit 6c0cc00

Browse files
committed
Add Support for optional ddi loading and ULTs
Signed-off-by: Neil R. Spruit <neil.r.spruit@intel.com>
1 parent d6f2c9f commit 6c0cc00

8 files changed

Lines changed: 213 additions & 17 deletions

File tree

scripts/templates/ze_loader_internal.h.mako

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ namespace loader
7171
ze_driver_handle_t zerDriverHandle = nullptr;
7272
ze_api_version_t versionRequested = ZE_API_VERSION_CURRENT;
7373
bool ddiInitialized = false;
74+
bool customDriver = false;
7475
ze_result_t zeddiInitResult = ZE_RESULT_ERROR_UNINITIALIZED;
7576
ze_result_t zetddiInitResult = ZE_RESULT_ERROR_UNINITIALIZED;
7677
ze_result_t zesddiInitResult = ZE_RESULT_ERROR_UNINITIALIZED;

source/loader/driver_discovery.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@
1313

1414
namespace loader {
1515

16-
using DriverLibraryPath = std::string;
16+
struct DriverLibraryPath {
17+
std::string path;
18+
bool customDriver;
19+
20+
DriverLibraryPath(const std::string& p, bool isCustom = false)
21+
: path(p), customDriver(isCustom) {}
22+
};
1723

1824
std::vector<DriverLibraryPath> discoverEnabledDrivers();
1925

source/loader/linux/driver_discovery_lin.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,17 +113,19 @@ std::vector<DriverLibraryPath> discoverEnabledDrivers() {
113113
// ZE_ENABLE_ALT_DRIVERS is for development/debug only
114114
altDrivers = getenv("ZE_ENABLE_ALT_DRIVERS");
115115
if (altDrivers == nullptr) {
116+
// Standard drivers - not custom
116117
for (auto path : knownDriverNames) {
117118
if (libraryExistsInSearchPaths(path)) {
118-
enabledDrivers.emplace_back(path);
119+
enabledDrivers.emplace_back(path, false);
119120
}
120121
}
121122
} else {
123+
// Alternative drivers from environment variable - these are custom
122124
std::stringstream ss(altDrivers);
123125
while (ss.good()) {
124126
std::string substr;
125127
getline(ss, substr, ',');
126-
enabledDrivers.emplace_back(substr);
128+
enabledDrivers.emplace_back(substr, true);
127129
}
128130
}
129131
return enabledDrivers;

source/loader/windows/driver_discovery_win.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,18 @@ std::vector<DriverLibraryPath> discoverEnabledDrivers() {
3333
// ZE_ENABLE_ALT_DRIVERS is for development/debug only
3434
envBufferSize = GetEnvironmentVariable("ZE_ENABLE_ALT_DRIVERS", &altDrivers[0], envBufferSize);
3535
if (!envBufferSize) {
36+
// Standard drivers discovered from registry - not custom
3637
auto displayDrivers = discoverDriversBasedOnDisplayAdapters(GUID_DEVCLASS_DISPLAY);
3738
auto computeDrivers = discoverDriversBasedOnDisplayAdapters(GUID_DEVCLASS_COMPUTEACCELERATOR);
3839
enabledDrivers.insert(enabledDrivers.end(), displayDrivers.begin(), displayDrivers.end());
3940
enabledDrivers.insert(enabledDrivers.end(), computeDrivers.begin(), computeDrivers.end());
4041
} else {
42+
// Alternative drivers from environment variable - these are custom
4143
std::stringstream ss(altDrivers.c_str());
4244
while (ss.good()) {
4345
std::string substr;
4446
getline(ss, substr, ',');
45-
enabledDrivers.emplace_back(substr);
47+
enabledDrivers.emplace_back(substr, true);
4648
}
4749
}
4850

@@ -110,7 +112,7 @@ DriverLibraryPath readDriverPathForDisplayAdapter(DEVINST dnDevNode) {
110112

111113
if (CR_SUCCESS != configErr) {
112114
assert(false && "CM_Open_DevNode_Key failed");
113-
return "";
115+
return DriverLibraryPath("", false);
114116
}
115117

116118
DWORD regValueType = {};
@@ -133,7 +135,7 @@ DriverLibraryPath readDriverPathForDisplayAdapter(DEVINST dnDevNode) {
133135
regOpStatus = RegCloseKey(hkey);
134136
assert((ERROR_SUCCESS == regOpStatus) && "RegCloseKey failed");
135137

136-
return driverPath;
138+
return DriverLibraryPath(driverPath, false);
137139
}
138140

139141
std::wstring readDisplayAdaptersDeviceIdsList(const GUID rguid) {
@@ -193,11 +195,12 @@ std::vector<DriverLibraryPath> discoverDriversBasedOnDisplayAdapters(const GUID
193195

194196
auto driverPath = readDriverPathForDisplayAdapter(devinst);
195197

196-
if (driverPath.empty()) {
198+
if (driverPath.path.empty()) {
197199
continue;
198200
}
199201

200-
bool alreadyOnTheList = (enabledDrivers.end() != std::find(enabledDrivers.begin(), enabledDrivers.end(), driverPath));
202+
bool alreadyOnTheList = (enabledDrivers.end() != std::find_if(enabledDrivers.begin(), enabledDrivers.end(),
203+
[&driverPath](const DriverLibraryPath& d) { return d.path == driverPath.path; }));
201204
if (alreadyOnTheList) {
202205
continue;
203206
}

source/loader/ze_loader.cpp

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -510,19 +510,21 @@ namespace loader
510510
std::string message = "init driver " + driver.name + " found GPU Supported Driver.";
511511
debug_trace_message(message, "");
512512
}
513+
loadDriver = true;
513514
}
514-
loadDriver = true;
515515
}
516516
if ((!desc && (flags == 0 || flags & ZE_INIT_FLAG_VPU_ONLY)) || (desc && desc->flags & ZE_INIT_DRIVER_TYPE_FLAG_NPU)) {
517517
if (driver.name.find("vpu") != std::string::npos || driver.name.find("npu") != std::string::npos) {
518518
if (debugTraceEnabled) {
519519
std::string message = "init driver " + driver.name + " found VPU/NPU Supported Driver.";
520520
debug_trace_message(message, "");
521521
}
522+
loadDriver = true;
522523
}
523-
loadDriver = true;
524524
}
525525

526+
loadDriver = !driver.handle && driver.customDriver ? true : loadDriver;
527+
526528
if (loadDriver && !driver.handle) {
527529
auto handle = LOAD_DRIVER_LIBRARY( driver.name.c_str() );
528530
if( NULL != handle )
@@ -584,6 +586,14 @@ namespace loader
584586
driver.ddiInitialized = true;
585587
}
586588

589+
if (!driver.handle && !driver.ddiInitialized) {
590+
if (debugTraceEnabled) {
591+
std::string message = "init driver " + driver.name + " does not match the requested flags or desc, skipping driver.";
592+
debug_trace_message(message, "");
593+
}
594+
return ZE_RESULT_ERROR_UNINITIALIZED;
595+
}
596+
587597
return ZE_RESULT_SUCCESS;
588598
}
589599

@@ -682,14 +692,14 @@ namespace loader
682692
}
683693
}
684694

685-
for( auto name : discoveredDrivers )
695+
for( auto driverInfo : discoveredDrivers )
686696
{
687697
if (discoveredDrivers.size() == 1) {
688-
auto handle = LOAD_DRIVER_LIBRARY( name.c_str() );
698+
auto handle = LOAD_DRIVER_LIBRARY( driverInfo.path.c_str() );
689699
if( NULL != handle )
690700
{
691701
if (debugTraceEnabled) {
692-
std::string message = "Loading Driver " + name + " succeeded";
702+
std::string message = "Loading Driver " + driverInfo.path + " succeeded";
693703
#if !defined(_WIN32) && !defined(ANDROID)
694704
// TODO: implement same message for windows, move dlinfo to ze_util.h as a macro
695705
struct link_map *dlinfo_map;
@@ -701,17 +711,19 @@ namespace loader
701711
}
702712
allDrivers.emplace_back();
703713
allDrivers.rbegin()->handle = handle;
704-
allDrivers.rbegin()->name = name;
714+
allDrivers.rbegin()->name = driverInfo.path;
715+
allDrivers.rbegin()->customDriver = driverInfo.customDriver;
705716
} else if (debugTraceEnabled) {
706717
GET_LIBRARY_ERROR(loadLibraryErrorValue);
707-
std::string errorMessage = "Load Library of " + name + " failed with ";
718+
std::string errorMessage = "Load Library of " + driverInfo.path + " failed with ";
708719
debug_trace_message(errorMessage, loadLibraryErrorValue);
709720
loadLibraryErrorValue.clear();
710721
}
711722
} else {
712723
allDrivers.emplace_back();
713724
allDrivers.rbegin()->handle = nullptr;
714-
allDrivers.rbegin()->name = name;
725+
allDrivers.rbegin()->name = driverInfo.path;
726+
allDrivers.rbegin()->customDriver = driverInfo.customDriver;
715727
}
716728
}
717729
if(allDrivers.size()==0){

source/loader/ze_loader_internal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ namespace loader
6262
ze_driver_handle_t zerDriverHandle = nullptr;
6363
ze_api_version_t versionRequested = ZE_API_VERSION_CURRENT;
6464
bool ddiInitialized = false;
65+
bool customDriver = false;
6566
ze_result_t zeddiInitResult = ZE_RESULT_ERROR_UNINITIALIZED;
6667
ze_result_t zetddiInitResult = ZE_RESULT_ERROR_UNINITIALIZED;
6768
ze_result_t zesddiInitResult = ZE_RESULT_ERROR_UNINITIALIZED;

test/CMakeLists.txt

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ add_executable(
1111
# Only include driver_ordering_unit_tests for static builds or non-Windows platforms
1212
# as it requires internal loader symbols that are not exported in Windows DLLs
1313
if(BUILD_STATIC OR NOT WIN32)
14-
target_sources(tests PRIVATE driver_ordering_unit_tests.cpp)
14+
target_sources(tests PRIVATE driver_ordering_unit_tests.cpp init_driver_unit_tests.cpp)
1515
endif()
1616

1717
# For static builds, we need to include the loader source files directly
@@ -69,6 +69,37 @@ if(BUILD_STATIC)
6969
endif()
7070
endif()
7171

72+
# Create fake driver copies for init_driver_unit_tests
73+
if(BUILD_STATIC OR NOT WIN32)
74+
if(MSVC)
75+
add_custom_command(TARGET tests POST_BUILD
76+
COMMAND ${CMAKE_COMMAND} -E copy_if_different
77+
$<TARGET_FILE_DIR:tests>/ze_null.dll
78+
$<TARGET_FILE_DIR:tests>/ze_fake_gpu.dll
79+
COMMAND ${CMAKE_COMMAND} -E copy_if_different
80+
$<TARGET_FILE_DIR:tests>/ze_null_test1.dll
81+
$<TARGET_FILE_DIR:tests>/ze_fake_npu.dll
82+
COMMAND ${CMAKE_COMMAND} -E copy_if_different
83+
$<TARGET_FILE_DIR:tests>/ze_null_test2.dll
84+
$<TARGET_FILE_DIR:tests>/ze_fake_vpu.dll
85+
COMMENT "Copying null drivers to fake driver names for init_driver_unit_tests"
86+
)
87+
else()
88+
add_custom_command(TARGET tests POST_BUILD
89+
COMMAND ${CMAKE_COMMAND} -E copy_if_different
90+
${CMAKE_BINARY_DIR}/lib/libze_null.so.1
91+
${CMAKE_BINARY_DIR}/lib/libze_fake_gpu.so.1
92+
COMMAND ${CMAKE_COMMAND} -E copy_if_different
93+
${CMAKE_BINARY_DIR}/lib/libze_null_test1.so.1
94+
${CMAKE_BINARY_DIR}/lib/libze_fake_npu.so.1
95+
COMMAND ${CMAKE_COMMAND} -E copy_if_different
96+
${CMAKE_BINARY_DIR}/lib/libze_null_test2.so.1
97+
${CMAKE_BINARY_DIR}/lib/libze_fake_vpu.so.1
98+
COMMENT "Copying null drivers to fake driver names for init_driver_unit_tests"
99+
)
100+
endif()
101+
endif()
102+
72103
add_test(NAME tests_api COMMAND tests --gtest_filter=*GivenLevelZeroLoaderPresentWhenCallingzeGetLoaderVersionsAPIThenValidVersionIsReturned*)
73104
set_property(TEST tests_api PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1")
74105
add_test(NAME tests_init_gpu_all COMMAND tests --gtest_filter=*GivenLevelZeroLoaderPresentWhenCallingZeInitDriversWithGPUTypeThenExpectPassWithGPUorAllOnly*)
@@ -596,6 +627,9 @@ set_property(TEST driver_ordering_trim_function PROPERTY ENVIRONMENT "ZE_ENABLE_
596627
add_test(NAME driver_ordering_parse_driver_order COMMAND tests --gtest_filter=DriverOrderingHelperFunctionsTest.ParseDriverOrder_*)
597628
set_property(TEST driver_ordering_parse_driver_order PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1")
598629

630+
# Init Driver Unit Tests
631+
add_test(NAME init_driver_unit_tests COMMAND tests --gtest_filter=InitDriverUnitTest.*)
632+
set_property(TEST init_driver_unit_tests PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1")
599633

600634
# These tests are currently not supported on Windows. The reason is that the std::cerr is not being redirected to a pipe in Windows to be then checked against the expected output.
601635
if(NOT MSVC)

test/init_driver_unit_tests.cpp

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
/*
2+
* Copyright (C) 2025 Intel Corporation
3+
* SPDX-License-Identifier: MIT
4+
*/
5+
6+
7+
#include "gtest/gtest.h"
8+
#include "source/loader/ze_loader_internal.h"
9+
#include "ze_api.h"
10+
#include <vector>
11+
#include <string>
12+
13+
class InitDriverUnitTest : public ::testing::Test {
14+
protected:
15+
void SetUp() override {
16+
if (!loader::context) {
17+
loader::context = new loader::context_t();
18+
loader::context->debugTraceEnabled = false;
19+
}
20+
}
21+
};
22+
// Helper to create a mock null driver with a given name and type
23+
loader::driver_t createNullDriver(const std::string& name, loader::zel_driver_type_t type) {
24+
loader::driver_t driver;
25+
#ifdef _WIN32
26+
driver.name = name + ".dll";
27+
#else
28+
driver.name = "lib" + name + ".so.1";
29+
#endif
30+
driver.driverType = type;
31+
driver.handle = nullptr; // Simulate null driver
32+
driver.initStatus = ZE_RESULT_SUCCESS;
33+
driver.driverInuse = false;
34+
driver.ddiInitialized = false;
35+
return driver;
36+
}
37+
38+
39+
TEST_F(InitDriverUnitTest, InitWithSingleGPUDriver) {
40+
loader::driver_t gpuDriver = createNullDriver("ze_fake_gpu", loader::ZEL_DRIVER_TYPE_DISCRETE_GPU);
41+
ze_result_t result = loader::context->init_driver(gpuDriver, ZE_INIT_FLAG_GPU_ONLY, nullptr, nullptr, nullptr, false);
42+
EXPECT_EQ(result, ZE_RESULT_SUCCESS);
43+
EXPECT_TRUE(gpuDriver.ddiInitialized);
44+
}
45+
46+
TEST_F(InitDriverUnitTest, InitWithSingleNPUDriver) {
47+
loader::driver_t npuDriver = createNullDriver("ze_fake_npu", loader::ZEL_DRIVER_TYPE_NPU);
48+
ze_init_driver_type_desc_t desc = {};
49+
desc.flags = ZE_INIT_DRIVER_TYPE_FLAG_NPU;
50+
ze_result_t result = loader::context->init_driver(npuDriver, 0, &desc, nullptr, nullptr, false);
51+
EXPECT_EQ(result, ZE_RESULT_SUCCESS);
52+
EXPECT_TRUE(npuDriver.ddiInitialized);
53+
}
54+
55+
TEST_F(InitDriverUnitTest, InitWithSingleVPUDriver) {
56+
loader::driver_t vpuDriver = createNullDriver("ze_fake_vpu", loader::ZEL_DRIVER_TYPE_NPU);
57+
ze_result_t result = loader::context->init_driver(vpuDriver, ZE_INIT_FLAG_VPU_ONLY, nullptr, nullptr, nullptr, false);
58+
EXPECT_EQ(result, ZE_RESULT_SUCCESS);
59+
EXPECT_TRUE(vpuDriver.ddiInitialized);
60+
}
61+
62+
TEST_F(InitDriverUnitTest, zeInitWithMultipleDrivers) {
63+
std::vector<loader::driver_t> drivers = {
64+
createNullDriver("ze_fake_gpu", loader::ZEL_DRIVER_TYPE_DISCRETE_GPU),
65+
createNullDriver("ze_fake_npu", loader::ZEL_DRIVER_TYPE_NPU),
66+
createNullDriver("ze_fake_vpu", loader::ZEL_DRIVER_TYPE_NPU)
67+
};
68+
for (auto& driver : drivers) {
69+
ze_result_t result = loader::context->init_driver(driver, 0, nullptr, nullptr, nullptr, false);
70+
EXPECT_EQ(result, ZE_RESULT_SUCCESS);
71+
EXPECT_TRUE(driver.ddiInitialized);
72+
}
73+
}
74+
75+
TEST_F(InitDriverUnitTest, zeInitDriversWithMultipleDrivers) {
76+
std::vector<loader::driver_t> drivers = {
77+
createNullDriver("ze_fake_gpu", loader::ZEL_DRIVER_TYPE_DISCRETE_GPU),
78+
createNullDriver("ze_fake_npu", loader::ZEL_DRIVER_TYPE_NPU),
79+
createNullDriver("ze_fake_vpu", loader::ZEL_DRIVER_TYPE_NPU)
80+
};
81+
ze_init_driver_type_desc_t desc = {};
82+
desc.flags = UINT32_MAX; // Request all driver types
83+
for (auto& driver : drivers) {
84+
ze_result_t result = loader::context->init_driver(driver, 0, &desc, nullptr, nullptr, false);
85+
EXPECT_EQ(result, ZE_RESULT_SUCCESS);
86+
EXPECT_TRUE(driver.ddiInitialized);
87+
}
88+
}
89+
90+
TEST_F(InitDriverUnitTest, zeInitDriversWithMultipleDriversNPURequested) {
91+
std::vector<loader::driver_t> drivers = {
92+
createNullDriver("ze_fake_gpu", loader::ZEL_DRIVER_TYPE_DISCRETE_GPU),
93+
createNullDriver("ze_fake_npu", loader::ZEL_DRIVER_TYPE_NPU),
94+
createNullDriver("ze_fake_vpu", loader::ZEL_DRIVER_TYPE_NPU)
95+
};
96+
ze_init_driver_type_desc_t desc = {};
97+
desc.flags = ZE_INIT_DRIVER_TYPE_FLAG_NPU; // Request NPU driver types
98+
for (auto& driver : drivers) {
99+
if (driver.driverType == loader::ZEL_DRIVER_TYPE_NPU) {
100+
ze_result_t result = loader::context->init_driver(driver, 0, &desc, nullptr, nullptr, false);
101+
EXPECT_EQ(result, ZE_RESULT_SUCCESS);
102+
EXPECT_TRUE(driver.ddiInitialized);
103+
} else {
104+
ze_result_t result = loader::context->init_driver(driver, 0, &desc, nullptr, nullptr, false);
105+
EXPECT_NE(result, ZE_RESULT_SUCCESS);
106+
EXPECT_FALSE(driver.ddiInitialized);
107+
}
108+
}
109+
}
110+
111+
TEST_F(InitDriverUnitTest, zeInitDriversWithMultipleDriversGPURequested) {
112+
std::vector<loader::driver_t> drivers = {
113+
createNullDriver("ze_fake_gpu", loader::ZEL_DRIVER_TYPE_DISCRETE_GPU),
114+
createNullDriver("ze_fake_npu", loader::ZEL_DRIVER_TYPE_NPU),
115+
createNullDriver("ze_fake_vpu", loader::ZEL_DRIVER_TYPE_NPU)
116+
};
117+
ze_init_driver_type_desc_t desc = {};
118+
desc.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU; // Request GPU driver types
119+
for (auto& driver : drivers) {
120+
if (driver.driverType == loader::ZEL_DRIVER_TYPE_DISCRETE_GPU) {
121+
ze_result_t result = loader::context->init_driver(driver, 0, &desc, nullptr, nullptr, false);
122+
EXPECT_EQ(result, ZE_RESULT_SUCCESS);
123+
EXPECT_TRUE(driver.ddiInitialized);
124+
} else {
125+
ze_result_t result = loader::context->init_driver(driver, 0, &desc, nullptr, nullptr, false);
126+
EXPECT_NE(result, ZE_RESULT_SUCCESS);
127+
EXPECT_FALSE(driver.ddiInitialized);
128+
}
129+
}
130+
}
131+
132+
TEST_F(InitDriverUnitTest, InitWithUnsupportedNullDriverType) {
133+
loader::driver_t otherDriver = createNullDriver("ze_fake_other", loader::ZEL_DRIVER_TYPE_OTHER);
134+
ze_result_t result = loader::context->init_driver(otherDriver, 0, nullptr, nullptr, nullptr, false);
135+
EXPECT_NE(result, ZE_RESULT_SUCCESS);
136+
EXPECT_FALSE(otherDriver.ddiInitialized);
137+
}

0 commit comments

Comments
 (0)