Skip to content

Commit 3994cf6

Browse files
committed
Add use_featurenames
1 parent fdfd4ac commit 3994cf6

1 file changed

Lines changed: 10 additions & 8 deletions

File tree

src/SoleDecisionTreeInterface.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,19 @@ using Sole: DecisionTree
99

1010
export solemodel
1111

12-
function solemodel(tree::DT.InfoNode, keep_condensed = false)
12+
function solemodel(tree::DT.InfoNode, keep_condensed = false, use_featurenames = true, kwargs...)
1313
# @show fieldnames(typeof(tree))
14+
use_featurenames = use_featurenames ? tree.info.featurenames : false
1415
root, info = begin
1516
if keep_condensed
16-
root = solemodel(tree.node)
17+
root = solemodel(tree.node; use_featurenames = use_featurenames, kwargs...)
1718
info = (;
1819
apply_preprocess=(y -> UInt32(findfirst(x -> x == y, tree.info.classlabels))),
1920
apply_postprocess=(y -> tree.info.classlabels[y]),
2021
)
2122
root, info
2223
else
23-
root = solemodel(tree.node, tree.info.classlabels)
24+
root = solemodel(tree.node; replace_classlabels = tree.info.classlabels, use_featurenames = use_featurenames, kwargs...)
2425
info = (;)
2526
root, info
2627
end
@@ -48,21 +49,22 @@ end
4849
# return DecisionTree(root, info)
4950
# end
5051

51-
function solemodel(tree::DT.Node, replace_classlabels = nothing)
52+
function solemodel(tree::DT.Node; replace_classlabels = nothing, use_featurenames = false)
5253
test_operator = (<)
5354
# @show fieldnames(typeof(tree))
54-
cond = ScalarCondition(Sole.VariableValue(tree.featid), test_operator, tree.featval)
55+
feature = (use_featurenames != false) ? Sole.VariableValue(use_featurenames[tree.featid]) : Sole.VariableValue(tree.featid)
56+
cond = ScalarCondition(feature, test_operator, tree.featval)
5557
antecedent = Atom(cond)
56-
lefttree = solemodel(tree.left, replace_classlabels)
57-
righttree = solemodel(tree.right, replace_classlabels)
58+
lefttree = solemodel(tree.left; replace_classlabels = replace_classlabels, use_featurenames = use_featurenames)
59+
righttree = solemodel(tree.right; replace_classlabels = replace_classlabels, use_featurenames = use_featurenames)
5860
info = (;
5961
supporting_predictions = [lefttree.info[:supporting_predictions]..., righttree.info[:supporting_predictions]...],
6062
supporting_labels = [lefttree.info[:supporting_labels]..., righttree.info[:supporting_labels]...],
6163
)
6264
return Branch(antecedent, lefttree, righttree, info)
6365
end
6466

65-
function solemodel(tree::DT.Leaf, replace_classlabels = nothing)
67+
function solemodel(tree::DT.Leaf; replace_classlabels = nothing, use_featurenames = false)
6668
# @show fieldnames(typeof(tree))
6769
prediction = tree.majority
6870
labels = tree.values

0 commit comments

Comments
 (0)