Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions src/algorithms/machinelearning/tensorflowpredict.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ void TensorflowPredict::configure() {
// Do not do anything if we did not get a non-empty model name.
if ((_savedModel.empty()) and (_graphFilename.empty())) return;

// Lazy initialization: Create TF objects on first real configure() call.
// This prevents GPU initialization when the algorithm is just registered.
if (!_status) {
_status = TF_NewStatus();
_options = TF_NewImportGraphDefOptions();
_sessionOptions = TF_NewSessionOptions();
}

_tags = parameter("tags").toVectorString();

_inputNames = parameter("inputs").toVectorString();
Expand All @@ -100,7 +108,8 @@ void TensorflowPredict::configure() {
_outputTensors.resize(_nOutputs);
_outputNodes.resize(_nOutputs);

TF_DeleteGraph(_graph);
// Clean up previous graph if reconfiguring
if (_graph) TF_DeleteGraph(_graph);
_graph = TF_NewGraph();

openGraph();
Expand Down Expand Up @@ -209,16 +218,20 @@ void TensorflowPredict::openGraph() {
void TensorflowPredict::reset() {
if (!_isConfigured) return;

TF_CloseSession(_session, _status);
if (TF_GetCode(_status) != TF_OK) {
throw EssentiaException("TensorflowPredict: Error closing session. ", TF_Message(_status));
}
// Close and delete existing session if present (reconfiguration case)
if (_session) {
TF_CloseSession(_session, _status);
if (TF_GetCode(_status) != TF_OK) {
throw EssentiaException("TensorflowPredict: Error closing session. ", TF_Message(_status));
}

TF_DeleteSession(_session, _status);
if (TF_GetCode(_status) != TF_OK) {
throw EssentiaException("TensorflowPredict: Error deleting session. ", TF_Message(_status));
TF_DeleteSession(_session, _status);
if (TF_GetCode(_status) != TF_OK) {
throw EssentiaException("TensorflowPredict: Error deleting session. ", TF_Message(_status));
}
}

// Create new session (this is where TF GPU initialization happens)
_session = TF_NewSession(_graph, _sessionOptions, _status);
if (TF_GetCode(_status) != TF_OK) {
throw EssentiaException("TensorflowPredict: Error creating new session after reset. ", TF_Message(_status));
Expand Down
25 changes: 15 additions & 10 deletions src/algorithms/machinelearning/tensorflowpredict.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,22 +79,27 @@ class TensorflowPredict : public Algorithm {
}

public:
TensorflowPredict() : _graph(TF_NewGraph()), _status(TF_NewStatus()),
_options(TF_NewImportGraphDefOptions()), _sessionOptions(TF_NewSessionOptions()),
_session(TF_NewSession(_graph, _sessionOptions, _status)), _runOptions(NULL),
// Lazy initialization: TF objects are created in configure(), not here.
// This prevents GPU initialization when the algorithm is just registered.
TensorflowPredict() : _graph(NULL), _status(NULL),
_options(NULL), _sessionOptions(NULL),
_session(NULL), _runOptions(NULL),
_isConfigured(false) {
declareInput(_poolIn, "poolIn", "the pool where to get the feature tensors");
declareOutput(_poolOut, "poolOut", "the pool where to store the output tensors");
}

~TensorflowPredict(){
TF_CloseSession(_session, _status);
TF_DeleteSessionOptions(_sessionOptions);
TF_DeleteSession(_session, _status);
TF_DeleteImportGraphDefOptions(_options);
TF_DeleteStatus(_status);
TF_DeleteGraph(_graph);
TF_DeleteBuffer(_runOptions);
// Guard against destruction before configure() was ever called
if (_session) {
TF_CloseSession(_session, _status);
TF_DeleteSession(_session, _status);
}
if (_sessionOptions) TF_DeleteSessionOptions(_sessionOptions);
if (_options) TF_DeleteImportGraphDefOptions(_options);
if (_status) TF_DeleteStatus(_status);
if (_graph) TF_DeleteGraph(_graph);
if (_runOptions) TF_DeleteBuffer(_runOptions);
}

void declareParameters() {
Expand Down