Skip to content

Commit 45808dd

Browse files
committed
Allocate the global stability variables with the context and catch exceptions in the thread
Signed-off-by: Neil R. Spruit <neil.r.spruit@intel.com>
1 parent 64d2e4c commit 45808dd

1 file changed

Lines changed: 93 additions & 49 deletions

File tree

source/lib/ze_lib.cpp

Lines changed: 93 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ namespace ze_lib
2828
}
2929
}
3030
bool delayContextDestruction = false;
31-
std::mutex stabilityMutex;
32-
std::promise<int> stabilityPromiseResult;
33-
std::future<int> resultFutureResult;
34-
std::atomic<int> stabilityCheckThreadStarted{0};
35-
std::thread stabilityThread;
31+
std::mutex *stabilityMutex = nullptr;
32+
std::promise<int> *stabilityPromiseResult = nullptr;
33+
std::future<int> *resultFutureResult = nullptr;
34+
std::atomic<int> *stabilityCheckThreadStarted = nullptr;
35+
std::thread *stabilityThread = nullptr;
3636
#endif
3737
bool destruction = false;
3838

@@ -49,14 +49,24 @@ namespace ze_lib
4949
if (loader) {
5050
FREE_DRIVER_LIBRARY( loader );
5151
}
52-
ze_lib::stabilityCheckThreadStarted = -1;
52+
ze_lib::stabilityCheckThreadStarted->store(-1);
5353
try {
54-
if (stabilityThread.joinable()) {
55-
stabilityThread.join();
54+
if (stabilityThread->joinable()) {
55+
stabilityThread->join();
5656
}
5757
} catch (...) {
5858
// Ignore any exceptions from thread join
5959
}
60+
delete stabilityThread;
61+
stabilityThread = nullptr;
62+
delete stabilityMutex;
63+
stabilityMutex = nullptr;
64+
delete stabilityPromiseResult;
65+
stabilityPromiseResult = nullptr;
66+
delete resultFutureResult;
67+
resultFutureResult = nullptr;
68+
delete stabilityCheckThreadStarted;
69+
stabilityCheckThreadStarted = nullptr;
6070
#endif
6171
ze_lib::destruction = true;
6272
};
@@ -163,6 +173,10 @@ namespace ze_lib
163173
std::string version_message = "Loader API Version to be requested is v" + std::to_string(ZE_MAJOR_VERSION(version)) + "." + std::to_string(ZE_MINOR_VERSION(version));
164174
debug_trace_message(version_message, "");
165175
loaderDriverGet = reinterpret_cast<ze_pfnDriverGet_t>(GET_FUNCTION_PTR(loader, "zeDriverGet"));
176+
stabilityMutex = new std::mutex();
177+
stabilityPromiseResult = new std::promise<int>();
178+
resultFutureResult = new std::future<int>(stabilityPromiseResult->get_future());
179+
stabilityCheckThreadStarted = new std::atomic<int>(0);
166180
#else
167181
result = zeLoaderInit();
168182
if( ZE_RESULT_SUCCESS == result ) {
@@ -455,28 +469,33 @@ zelSetDelayLoaderContextTeardown()
455469
* @exception This function catches all exceptions internally and does not throw.
456470
*/
457471
void stabilityCheck(std::promise<int> stabilityPromise) {
458-
if (!ze_lib::context->loaderDriverGet) {
459-
if (ze_lib::context->debugTraceEnabled) {
460-
std::string message = "LoaderDriverGet is a bad pointer. Exiting stability checker thread.";
461-
ze_lib::context->debug_trace_message(message, "");
472+
try {
473+
if (!ze_lib::context->loaderDriverGet) {
474+
if (ze_lib::context->debugTraceEnabled) {
475+
std::string message = "LoaderDriverGet is a bad pointer. Exiting stability checker.";
476+
ze_lib::context->debug_trace_message(message, "");
477+
}
478+
stabilityPromise.set_value(ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_NULL);
479+
return;
462480
}
463-
stabilityPromise.set_value(ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_NULL);
464-
return;
465-
}
466481

467-
uint32_t driverCount = 0;
468-
ze_result_t result = ZE_RESULT_ERROR_UNINITIALIZED;
469-
result = ze_lib::context->loaderDriverGet(&driverCount, nullptr);
470-
if (result != ZE_RESULT_SUCCESS || driverCount == 0) {
471-
if (ze_lib::context->debugTraceEnabled) {
472-
std::string message = "Loader stability check failed. Exiting stability checker thread.";
473-
ze_lib::context->debug_trace_message(message, "");
482+
uint32_t driverCount = 0;
483+
ze_result_t result = ZE_RESULT_ERROR_UNINITIALIZED;
484+
result = ze_lib::context->loaderDriverGet(&driverCount, nullptr);
485+
if (result != ZE_RESULT_SUCCESS || driverCount == 0) {
486+
if (ze_lib::context->debugTraceEnabled) {
487+
std::string message = "Loader stability check failed. Exiting stability checker.";
488+
ze_lib::context->debug_trace_message(message, "");
489+
}
490+
stabilityPromise.set_value(ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_FAILED);
491+
return;
474492
}
475-
stabilityPromise.set_value(ZEL_STABILITY_CHECK_RESULT_DRIVER_GET_FAILED);
493+
stabilityPromise.set_value(ZEL_STABILITY_CHECK_RESULT_SUCCESS);
494+
return;
495+
} catch (...) {
496+
stabilityPromise.set_value(ZEL_STABILITY_CHECK_RESULT_EXCEPTION);
476497
return;
477498
}
478-
stabilityPromise.set_value(ZEL_STABILITY_CHECK_RESULT_SUCCESS);
479-
return;
480499
}
481500
#endif
482501

@@ -509,45 +528,70 @@ zelCheckIsLoaderInTearDown() {
509528
try {
510529
// Launch the stability checker thread on the first call
511530
static std::once_flag stabilityThreadFlag;
512-
std::lock_guard<std::mutex> lock(ze_lib::stabilityMutex);
513-
ze_lib::stabilityPromiseResult = std::promise<int>();
514-
ze_lib::resultFutureResult = ze_lib::stabilityPromiseResult.get_future();
515-
ze_lib::stabilityCheckThreadStarted = 1;
531+
std::lock_guard<std::mutex> lock(*ze_lib::stabilityMutex);
532+
*ze_lib::stabilityPromiseResult = std::promise<int>();
533+
*ze_lib::resultFutureResult = ze_lib::stabilityPromiseResult->get_future();
534+
ze_lib::stabilityCheckThreadStarted->store(1);
516535
std::call_once(stabilityThreadFlag, []() {
517-
ze_lib::stabilityThread = std::thread([]() {
536+
ze_lib::stabilityThread = new std::thread([]() {
518537
while (true) {
519-
std::promise<int> stabilityPromise;
520-
std::future<int> resultFuture = stabilityPromise.get_future();
521-
while(ze_lib::stabilityCheckThreadStarted == 0) {
522-
std::this_thread::sleep_for(std::chrono::milliseconds(1));
523-
}
524-
if (ze_lib::stabilityCheckThreadStarted == -1) {
525-
break;
526-
}
527-
ze_lib::stabilityCheckThreadStarted = 0;
528-
stabilityCheck(std::move(stabilityPromise));
529-
int result = resultFuture.get();
530-
if (result != ZEL_STABILITY_CHECK_RESULT_SUCCESS) {
538+
try {
539+
std::promise<int> stabilityPromise;
540+
std::future<int> resultFuture = stabilityPromise.get_future();
541+
while(ze_lib::stabilityCheckThreadStarted && *ze_lib::stabilityCheckThreadStarted == 0) {
542+
std::this_thread::sleep_for(std::chrono::milliseconds(1));
543+
}
544+
if (!ze_lib::stabilityCheckThreadStarted) {
545+
break;
546+
}
547+
if (*ze_lib::stabilityCheckThreadStarted == -1) {
548+
break;
549+
}
550+
ze_lib::stabilityCheckThreadStarted->store(0);
551+
stabilityCheck(std::move(stabilityPromise));
552+
int result = resultFuture.get();
553+
if (result != ZEL_STABILITY_CHECK_RESULT_SUCCESS) {
554+
if (ze_lib::context->debugTraceEnabled) {
555+
std::string message = "Loader stability check thread failed with result: " + std::to_string(result);
556+
ze_lib::context->debug_trace_message(message, "");
557+
}
558+
if (ze_lib::stabilityPromiseResult) {
559+
ze_lib::stabilityPromiseResult->set_value(result);
560+
}
561+
break; // Exit the thread if stability check fails
562+
}
563+
if (ze_lib::stabilityPromiseResult) {
564+
ze_lib::stabilityPromiseResult->set_value(result);
565+
}
566+
} catch (const std::exception& e) {
567+
if (ze_lib::context->debugTraceEnabled) {
568+
std::string message = "Exception caught in stability check thread: " + std::string(e.what());
569+
ze_lib::context->debug_trace_message(message, "");
570+
if (ze_lib::stabilityPromiseResult) {
571+
ze_lib::stabilityPromiseResult->set_value(ZEL_STABILITY_CHECK_RESULT_EXCEPTION);
572+
}
573+
}
574+
} catch (...) {
531575
if (ze_lib::context->debugTraceEnabled) {
532-
std::string message = "Loader stability check thread failed with result: " + std::to_string(result);
576+
std::string message = "Unknown exception caught in stability check thread.";
533577
ze_lib::context->debug_trace_message(message, "");
578+
if (ze_lib::stabilityPromiseResult) {
579+
ze_lib::stabilityPromiseResult->set_value(ZEL_STABILITY_CHECK_RESULT_EXCEPTION);
580+
}
534581
}
535-
ze_lib::stabilityPromiseResult.set_value(result);
536-
break; // Exit the thread if stability check fails
537582
}
538-
ze_lib::stabilityPromiseResult.set_value(result);
539583
}
540584
});
541-
ze_lib::stabilityThread.detach();
585+
ze_lib::stabilityThread->detach();
542586
});
543-
if (ze_lib::resultFutureResult.wait_for(std::chrono::milliseconds(ZEL_STABILITY_CHECK_THREAD_TIMEOUT)) == std::future_status::timeout) {
587+
if (ze_lib::resultFutureResult->wait_for(std::chrono::milliseconds(ZEL_STABILITY_CHECK_THREAD_TIMEOUT)) == std::future_status::timeout) {
544588
if (ze_lib::context->debugTraceEnabled) {
545589
std::string message = "Stability Thread timeout, assuming thread has crashed";
546590
ze_lib::context->debug_trace_message(message, "");
547591
}
548592
threadResult = ZEL_STABILITY_CHECK_RESULT_EXCEPTION;
549593
} else {
550-
threadResult = ze_lib::resultFutureResult.get();
594+
threadResult = ze_lib::resultFutureResult->get();
551595
}
552596
} catch (const std::exception& e) {
553597
if (ze_lib::context->debugTraceEnabled) {

0 commit comments

Comments
 (0)