diff --git a/src/algorithms/machinelearning/tensorflowpredict.cpp b/src/algorithms/machinelearning/tensorflowpredict.cpp index 5b8a02ebc..a0b4d0c7e 100644 --- a/src/algorithms/machinelearning/tensorflowpredict.cpp +++ b/src/algorithms/machinelearning/tensorflowpredict.cpp @@ -80,6 +80,14 @@ void TensorflowPredict::configure() { // Do not do anything if we did not get a non-empty model name. if ((_savedModel.empty()) and (_graphFilename.empty())) return; + // Lazy initialization: Create TF objects on first real configure() call. + // This prevents GPU initialization when the algorithm is just registered. + if (!_status) { + _status = TF_NewStatus(); + _options = TF_NewImportGraphDefOptions(); + _sessionOptions = TF_NewSessionOptions(); + } + _tags = parameter("tags").toVectorString(); _inputNames = parameter("inputs").toVectorString(); @@ -100,7 +108,8 @@ void TensorflowPredict::configure() { _outputTensors.resize(_nOutputs); _outputNodes.resize(_nOutputs); - TF_DeleteGraph(_graph); + // Clean up previous graph if reconfiguring + if (_graph) TF_DeleteGraph(_graph); _graph = TF_NewGraph(); openGraph(); @@ -209,16 +218,20 @@ void TensorflowPredict::openGraph() { void TensorflowPredict::reset() { if (!_isConfigured) return; - TF_CloseSession(_session, _status); - if (TF_GetCode(_status) != TF_OK) { - throw EssentiaException("TensorflowPredict: Error closing session. ", TF_Message(_status)); - } + // Close and delete existing session if present (reconfiguration case) + if (_session) { + TF_CloseSession(_session, _status); + if (TF_GetCode(_status) != TF_OK) { + throw EssentiaException("TensorflowPredict: Error closing session. ", TF_Message(_status)); + } - TF_DeleteSession(_session, _status); - if (TF_GetCode(_status) != TF_OK) { - throw EssentiaException("TensorflowPredict: Error deleting session. ", TF_Message(_status)); + TF_DeleteSession(_session, _status); + if (TF_GetCode(_status) != TF_OK) { + throw EssentiaException("TensorflowPredict: Error deleting session. ", TF_Message(_status)); + } } + // Create new session (this is where TF GPU initialization happens) _session = TF_NewSession(_graph, _sessionOptions, _status); if (TF_GetCode(_status) != TF_OK) { throw EssentiaException("TensorflowPredict: Error creating new session after reset. ", TF_Message(_status)); diff --git a/src/algorithms/machinelearning/tensorflowpredict.h b/src/algorithms/machinelearning/tensorflowpredict.h index b3c3ac1cd..4decd73b3 100644 --- a/src/algorithms/machinelearning/tensorflowpredict.h +++ b/src/algorithms/machinelearning/tensorflowpredict.h @@ -79,22 +79,27 @@ class TensorflowPredict : public Algorithm { } public: - TensorflowPredict() : _graph(TF_NewGraph()), _status(TF_NewStatus()), - _options(TF_NewImportGraphDefOptions()), _sessionOptions(TF_NewSessionOptions()), - _session(TF_NewSession(_graph, _sessionOptions, _status)), _runOptions(NULL), + // Lazy initialization: TF objects are created in configure(), not here. + // This prevents GPU initialization when the algorithm is just registered. + TensorflowPredict() : _graph(NULL), _status(NULL), + _options(NULL), _sessionOptions(NULL), + _session(NULL), _runOptions(NULL), _isConfigured(false) { declareInput(_poolIn, "poolIn", "the pool where to get the feature tensors"); declareOutput(_poolOut, "poolOut", "the pool where to store the output tensors"); } ~TensorflowPredict(){ - TF_CloseSession(_session, _status); - TF_DeleteSessionOptions(_sessionOptions); - TF_DeleteSession(_session, _status); - TF_DeleteImportGraphDefOptions(_options); - TF_DeleteStatus(_status); - TF_DeleteGraph(_graph); - TF_DeleteBuffer(_runOptions); + // Guard against destruction before configure() was ever called + if (_session) { + TF_CloseSession(_session, _status); + TF_DeleteSession(_session, _status); + } + if (_sessionOptions) TF_DeleteSessionOptions(_sessionOptions); + if (_options) TF_DeleteImportGraphDefOptions(_options); + if (_status) TF_DeleteStatus(_status); + if (_graph) TF_DeleteGraph(_graph); + if (_runOptions) TF_DeleteBuffer(_runOptions); } void declareParameters() {