From 2fbc407498d42052ad883033566c2b4e15cc8ba8 Mon Sep 17 00:00:00 2001 From: mfdel Date: Wed, 14 Jan 2026 21:21:37 +0100 Subject: [PATCH 1/5] Add Booster.compute_leaf_similarity() method Compute similarity between observations based on leaf node co-occurrence across trees. Similar to Random Forest proximity matrices. - Two weight types: 'gain' (default) and 'cover' - Returns similarity matrix with values in [0, 1] - Self-similarity is 1.0 Closes #11919 --- python-package/xgboost/core.py | 64 ++++++++++++++++++++++++++++ tests/python/test_leaf_similarity.py | 47 ++++++++++++++++++++ 2 files changed, 111 insertions(+) create mode 100644 tests/python/test_leaf_similarity.py diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 528b44c6fe25..3ffe681a4a07 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -2981,6 +2981,70 @@ def inplace_predict( "Data type:" + str(type(data)) + " not supported by inplace prediction." ) + def compute_leaf_similarity( + self, + data: DMatrix, + reference: DMatrix, + weight_type: str = "gain", + ) -> np.ndarray: + """Compute similarity between observations based on leaf node co-occurrence. + + Two samples are similar if they land in the same leaf nodes across trees. + This is similar to Random Forest proximity matrices. + + Parameters + ---------- + data : + Query dataset (m samples). + reference : + Reference dataset (n samples). + weight_type : + How to weight trees: "gain" (by loss improvement) or "cover" + (by hessian sum, approximately sample count for regression). + + Returns + ------- + similarity : ndarray of shape (m, n) + Similarity scores in [0, 1]. 
+ """ + if weight_type not in ("gain", "cover"): + raise ValueError( + f"weight_type must be 'gain' or 'cover', got '{weight_type}'" + ) + + query_leaves = self.predict(data, pred_leaf=True) + ref_leaves = self.predict(reference, pred_leaf=True) + + if query_leaves.ndim == 1: + query_leaves = query_leaves.reshape(-1, 1) + if ref_leaves.ndim == 1: + ref_leaves = ref_leaves.reshape(-1, 1) + + n_trees = query_leaves.shape[1] + + trees_df = self.trees_to_dataframe() + split_nodes = trees_df[trees_df["Feature"] != "Leaf"] + col = "Gain" if weight_type == "gain" else "Cover" + tree_weights = split_nodes.groupby("Tree")[col].sum() + + weights = np.zeros(n_trees, dtype=np.float32) + for tree_id, w in tree_weights.items(): + if tree_id < n_trees: + weights[int(tree_id)] = w + + if weights.sum() == 0: + weights = np.ones(n_trees, dtype=np.float32) + + total_weight = weights.sum() + m, n = len(query_leaves), len(ref_leaves) + + similarity = np.zeros((m, n), dtype=np.float32) + for i in range(m): + matches_i = query_leaves[i] == ref_leaves + similarity[i] = (matches_i * weights).sum(axis=1) / total_weight + + return similarity + def save_model(self, fname: PathLike) -> None: """Save the model to a file. 
diff --git a/tests/python/test_leaf_similarity.py b/tests/python/test_leaf_similarity.py new file mode 100644 index 000000000000..3dcbf9e804da --- /dev/null +++ b/tests/python/test_leaf_similarity.py @@ -0,0 +1,47 @@ +"""Tests for leaf similarity computation.""" + +import numpy as np +import pytest + +import xgboost as xgb +from xgboost import testing as tm + +rng = np.random.RandomState(1994) + + +class TestLeafSimilarity: + """Tests for Booster.compute_leaf_similarity()""" + + def test_leaf_similarity(self) -> None: + """Test basic leaf similarity computation.""" + dtrain, _ = tm.load_agaricus(__file__) + param = {"max_depth": 4, "eta": 0.3, "objective": "binary:logistic"} + bst = xgb.train(param, dtrain, num_boost_round=10) + + X = dtrain.get_data() + dm_query = xgb.DMatrix(X[:10]) + dm_ref = xgb.DMatrix(X[100:150]) + + # Test shape and range + similarity = bst.compute_leaf_similarity(dm_query, dm_ref) + assert similarity.shape == (10, 50) + assert similarity.min() >= 0.0 + assert similarity.max() <= 1.0 + + # Self-similarity diagonal should be 1.0 + dm_self = xgb.DMatrix(X[:20]) + self_sim = bst.compute_leaf_similarity(dm_self, dm_self) + np.testing.assert_allclose(np.diag(self_sim), 1.0, rtol=1e-5) + + # Test weight types + sim_gain = bst.compute_leaf_similarity(dm_query, dm_ref, weight_type="gain") + sim_cover = bst.compute_leaf_similarity(dm_query, dm_ref, weight_type="cover") + assert sim_gain.shape == sim_cover.shape + + # Default should be gain + sim_default = bst.compute_leaf_similarity(dm_query, dm_ref) + np.testing.assert_array_equal(sim_default, sim_gain) + + # Invalid weight_type + with pytest.raises(ValueError, match="weight_type must be"): + bst.compute_leaf_similarity(dm_query, dm_ref, weight_type="invalid") From 79a5791d3ac49cab32f560086e70cb42e87e0876 Mon Sep 17 00:00:00 2001 From: ZhuYizhou2333 <2135378845@qq.com> Date: Sun, 12 Apr 2026 12:20:08 +0800 Subject: [PATCH 2/5] refactor: optimize leaf similarity computation --- 
python-package/xgboost/core.py | 90 +++++++++++++++++++++++----------- 1 file changed, 62 insertions(+), 28 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 3ffe681a4a07..209e3972bc5f 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -2985,7 +2985,7 @@ def compute_leaf_similarity( self, data: DMatrix, reference: DMatrix, - weight_type: str = "gain", + weight_type: str = "uniform", ) -> np.ndarray: """Compute similarity between observations based on leaf node co-occurrence. @@ -2999,51 +2999,85 @@ def compute_leaf_similarity( reference : Reference dataset (n samples). weight_type : - How to weight trees: "gain" (by loss improvement) or "cover" - (by hessian sum, approximately sample count for regression). + How to weight trees: "uniform" (equal tree weights), "gain" + (by loss improvement), or "cover" (by hessian sum, approximately + sample count for regression). Returns ------- similarity : ndarray of shape (m, n) Similarity scores in [0, 1]. 
""" - if weight_type not in ("gain", "cover"): + if weight_type not in ("uniform", "gain", "cover"): raise ValueError( - f"weight_type must be 'gain' or 'cover', got '{weight_type}'" + "weight_type must be 'uniform', 'gain', or 'cover', " + f"got '{weight_type}'" ) - query_leaves = self.predict(data, pred_leaf=True) - ref_leaves = self.predict(reference, pred_leaf=True) + query_leaves = self.predict(data, pred_leaf=True, strict_shape=True) + ref_leaves = self.predict(reference, pred_leaf=True, strict_shape=True) - if query_leaves.ndim == 1: - query_leaves = query_leaves.reshape(-1, 1) - if ref_leaves.ndim == 1: - ref_leaves = ref_leaves.reshape(-1, 1) + query_leaves = np.asarray(query_leaves, dtype=np.int64).reshape( + query_leaves.shape[0], -1 + ) + ref_leaves = np.asarray(ref_leaves, dtype=np.int64).reshape( + ref_leaves.shape[0], -1 + ) + + m, n = query_leaves.shape[0], ref_leaves.shape[0] + if query_leaves.shape[1] != ref_leaves.shape[1]: + raise ValueError("Query and reference leaf predictions have different shapes.") n_trees = query_leaves.shape[1] + if m == 0 or n == 0 or n_trees == 0: + return np.zeros((m, n), dtype=np.float32) - trees_df = self.trees_to_dataframe() - split_nodes = trees_df[trees_df["Feature"] != "Leaf"] - col = "Gain" if weight_type == "gain" else "Cover" - tree_weights = split_nodes.groupby("Tree")[col].sum() + if weight_type == "uniform": + weights = np.ones(n_trees, dtype=np.float32) + else: + trees_df = self.trees_to_dataframe() + split_nodes = trees_df[trees_df["Feature"] != "Leaf"] + col = "Gain" if weight_type == "gain" else "Cover" + tree_weights = split_nodes.groupby("Tree")[col].sum() - weights = np.zeros(n_trees, dtype=np.float32) - for tree_id, w in tree_weights.items(): - if tree_id < n_trees: - weights[int(tree_id)] = w + weights = np.zeros(n_trees, dtype=np.float32) + for tree_id, w in tree_weights.items(): + if tree_id < n_trees: + weights[int(tree_id)] = w - if weights.sum() == 0: - weights = np.ones(n_trees, 
dtype=np.float32) + if weights.sum() == 0: + weights = np.ones(n_trees, dtype=np.float32) total_weight = weights.sum() - m, n = len(query_leaves), len(ref_leaves) - - similarity = np.zeros((m, n), dtype=np.float32) - for i in range(m): - matches_i = query_leaves[i] == ref_leaves - similarity[i] = (matches_i * weights).sum(axis=1) / total_weight + if total_weight == 0: + weights = np.ones(n_trees, dtype=np.float32) + total_weight = weights.sum() + + leaf_upper = np.maximum(query_leaves.max(axis=0), ref_leaves.max(axis=0)) + 1 + offsets = np.zeros(n_trees, dtype=np.int64) + if n_trees > 1: + offsets[1:] = np.cumsum(leaf_upper[:-1], dtype=np.int64) + + weight_values = np.sqrt(weights / total_weight, dtype=np.float32) + q_cols = (query_leaves + offsets).reshape(-1) + r_cols = (ref_leaves + offsets).reshape(-1) + q_rows = np.repeat(np.arange(m), n_trees) + r_rows = np.repeat(np.arange(n), n_trees) + feature_dim = int(offsets[-1] + leaf_upper[-1]) + + query_matrix = scipy.sparse.csr_matrix( + (np.tile(weight_values, m), (q_rows, q_cols)), + shape=(m, feature_dim), + dtype=np.float32, + ) + ref_matrix = scipy.sparse.csr_matrix( + (np.tile(weight_values, n), (r_rows, r_cols)), + shape=(n, feature_dim), + dtype=np.float32, + ) - return similarity + similarity = query_matrix @ ref_matrix.T + return similarity.toarray() def save_model(self, fname: PathLike) -> None: """Save the model to a file. 
From affe81bd49921cc1ffe8341f1f5b6a3bf7731611 Mon Sep 17 00:00:00 2001 From: ZhuYizhou2333 <2135378845@qq.com> Date: Sun, 12 Apr 2026 12:35:39 +0800 Subject: [PATCH 3/5] feat: expose leaf similarity weights in c api --- include/xgboost/c_api.h | 28 ++++++++++ include/xgboost/gbm.h | 5 ++ include/xgboost/learner.h | 5 ++ python-package/xgboost/core.py | 29 ++++++---- src/c_api/c_api.cc | 23 ++++++++ src/gbm/gblinear.cc | 10 ++++ src/gbm/gbtree.h | 39 ++++++++++++++ src/learner.cc | 10 ++++ tests/python/test_leaf_similarity.py | 79 ++++++++++++++++++++++++++-- 9 files changed, 216 insertions(+), 12 deletions(-) diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 6b7a879350c9..ba398b894cb3 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -1698,6 +1698,34 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config, bst_ulong *out_n_features, char const ***out_features, bst_ulong *out_dim, bst_ulong const **out_shape, float const **out_scores); + +/** + * @brief Get per-tree weights for leaf similarity computation. + * + * @param config A JSON string with the following format: + * + * { + * "weight_type": str, + * "iteration_begin": int, + * "iteration_end": int + * } + * + * - weight_type: A JSON string with following possible values: + * * 'uniform': assign equal weight to each tree. + * * 'gain': sum split gain for each tree. + * * 'cover': sum split cover for each tree. + * - iteration_begin: Beginning iteration used when extracting tree weights. + * - iteration_end: End iteration used when extracting tree weights. 0 means + * using all remaining iterations. + * + * @param out_len Length of output tree weight array. + * @param out_weights Pointer to the output tree weight array. 
+ * + * @return 0 when success, -1 when failure happens + */ +XGB_DLL int XGBoosterGetLeafSimilarityWeights(BoosterHandle handle, const char *config, + bst_ulong *out_len, + float const **out_weights); /**@}*/ // End of Booster /** diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h index 65940773ffee..5c5602759c60 100644 --- a/include/xgboost/gbm.h +++ b/include/xgboost/gbm.h @@ -151,6 +151,11 @@ class GradientBooster : public Model, public Configurable { common::Span trees, std::vector* features, std::vector* scores) const = 0; + + virtual void LeafSimilarityWeights(std::string const& weight_type, + bst_layer_t iteration_begin, + bst_layer_t iteration_end, + std::vector* weights) const = 0; /** * @brief Getter for categories. */ diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h index ffaddfbe6442..8f7ea2d25830 100644 --- a/include/xgboost/learner.h +++ b/include/xgboost/learner.h @@ -140,6 +140,11 @@ class Learner : public Model, public Configurable, public dmlc::Serializable { std::vector* features, std::vector* scores) = 0; + virtual void CalcLeafSimilarityWeights(std::string const& weight_type, + bst_layer_t iteration_begin, + bst_layer_t iteration_end, + std::vector* weights) = 0; + /* * \brief Get number of boosted rounds from gradient booster. 
*/ diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 209e3972bc5f..530da9c2f586 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -3035,15 +3035,26 @@ def compute_leaf_similarity( if weight_type == "uniform": weights = np.ones(n_trees, dtype=np.float32) else: - trees_df = self.trees_to_dataframe() - split_nodes = trees_df[trees_df["Feature"] != "Leaf"] - col = "Gain" if weight_type == "gain" else "Cover" - tree_weights = split_nodes.groupby("Tree")[col].sum() - - weights = np.zeros(n_trees, dtype=np.float32) - for tree_id, w in tree_weights.items(): - if tree_id < n_trees: - weights[int(tree_id)] = w + out_len = c_bst_ulong() + out_weights = ctypes.POINTER(ctypes.c_float)() + _check_call( + _LIB.XGBoosterGetLeafSimilarityWeights( + self.handle, + make_jcargs( + weight_type=weight_type, + iteration_begin=0, + iteration_end=0, + ), + ctypes.byref(out_len), + ctypes.byref(out_weights), + ) + ) + weights = ctypes2numpy(out_weights, out_len.value, np.float32) + if weights.shape[0] != n_trees: + raise ValueError( + "Tree weight count does not match leaf prediction shape: " + f"{weights.shape[0]} != {n_trees}" + ) if weights.sum() == 0: weights = np.ones(n_trees, dtype=np.float32) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 1afdfdb4f91a..8b115bed1c9f 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -2023,3 +2023,26 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *config, *out_features = dmlc::BeginPtr(feature_names_c); API_END(); } + +XGB_DLL int XGBoosterGetLeafSimilarityWeights(BoosterHandle handle, char const *config, + bst_ulong *out_len, + float const **out_weights) { + API_BEGIN(); + CHECK_HANDLE(); + xgboost_CHECK_C_ARG_PTR(config); + auto *learner = static_cast(handle); + auto jconfig = Json::Load(StringView{config}); + + auto weight_type = RequiredArg(jconfig, "weight_type", __func__); + auto iteration_begin = RequiredArg(jconfig, "iteration_begin", 
__func__); + auto iteration_end = RequiredArg(jconfig, "iteration_end", __func__); + + auto &weights = learner->GetThreadLocal().ret_vec_float; + learner->CalcLeafSimilarityWeights(weight_type, iteration_begin, iteration_end, &weights); + + xgboost_CHECK_C_ARG_PTR(out_len); + xgboost_CHECK_C_ARG_PTR(out_weights); + *out_len = weights.size(); + *out_weights = dmlc::BeginPtr(weights); + API_END(); +} diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc index 3bd03a3b4a41..0aef1e000152 100644 --- a/src/gbm/gblinear.cc +++ b/src/gbm/gblinear.cc @@ -237,6 +237,16 @@ class GBLinear : public GradientBooster { } } + void LeafSimilarityWeights(std::string const& weight_type, bst_layer_t iteration_begin, + bst_layer_t iteration_end, + std::vector* weights) const override { + (void)weight_type; + (void)iteration_begin; + (void)iteration_end; + (void)weights; + LOG(FATAL) << "Leaf similarity weights are not defined for gblinear booster."; + } + protected: void PredictBatchInternal(DMatrix *p_fmat, std::vector *out_preds) { diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index ec39e2748799..5b441d3b75e1 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -289,6 +289,45 @@ class GBTree : public GradientBooster { } } + void LeafSimilarityWeights(std::string const& weight_type, + bst_layer_t iteration_begin, + bst_layer_t iteration_end, + std::vector* weights) const override { + auto [tree_begin, tree_end] = detail::LayerToTree(model_, iteration_begin, iteration_end); + weights->clear(); + weights->reserve(tree_end - tree_begin); + + auto const get_weight = [&](RegTree const& tree) { + CHECK(!tree.IsMultiTarget()) << "Leaf similarity weights for multi-target tree " + << MTNotImplemented(); + tree::ScalarTreeView view{&tree}; + + if (weight_type == "uniform") { + return 1.0f; + } + + float weight = 0.0f; + for (bst_node_t nidx = 0; nidx < view.Size(); ++nidx) { + if (!view.IsLeaf(nidx)) { + if (weight_type == "gain") { + weight += view.LossChg(nidx); + } else if (weight_type 
== "cover") { + weight += view.SumHess(nidx); + } else { + LOG(FATAL) << "Unknown leaf similarity weight type, expected one of: " + << R"({"uniform", "gain", "cover"}, got: )" << weight_type; + } + } + } + return weight; + }; + + for (bst_tree_t tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { + auto const& tree = *model_.trees.at(tree_idx); + weights->push_back(get_weight(tree)); + } + } + [[nodiscard]] CatContainer const* Cats() const override { return this->model_.Cats(); } void PredictLeaf(DMatrix* p_fmat, diff --git a/src/learner.cc b/src/learner.cc index 26ebc1e6f9c4..ba80147d8f53 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -1266,6 +1266,16 @@ class LearnerImpl : public LearnerIO { gbm_->FeatureScore(importance_type, trees, features, scores); } + void CalcLeafSimilarityWeights(std::string const& weight_type, + bst_layer_t iteration_begin, + bst_layer_t iteration_end, + std::vector* weights) override { + this->Configure(); + this->CheckModelInitialized(); + + gbm_->LeafSimilarityWeights(weight_type, iteration_begin, iteration_end, weights); + } + const std::map& GetConfigurationArguments() const override { return cfg_; } diff --git a/tests/python/test_leaf_similarity.py b/tests/python/test_leaf_similarity.py index 3dcbf9e804da..6e5c5f669e3b 100644 --- a/tests/python/test_leaf_similarity.py +++ b/tests/python/test_leaf_similarity.py @@ -1,9 +1,18 @@ """Tests for leaf similarity computation.""" +import ctypes + import numpy as np import pytest import xgboost as xgb +from xgboost.core import ( + _LIB, + _check_call, + c_bst_ulong, + ctypes2numpy, + from_pystr_to_cstr, +) from xgboost import testing as tm rng = np.random.RandomState(1994) @@ -12,6 +21,69 @@ class TestLeafSimilarity: """Tests for Booster.compute_leaf_similarity()""" + @pytest.mark.parametrize( + ("param", "num_boost_round"), + [ + ({"max_depth": 4, "eta": 0.3, "objective": "binary:logistic"}, 8), + ( + { + "max_depth": 3, + "eta": 0.3, + "objective": "multi:softprob", + 
"num_class": 3, + }, + 6, + ), + ( + { + "max_depth": 4, + "eta": 1.0, + "objective": "binary:logistic", + "num_parallel_tree": 3, + }, + 5, + ), + ], + ) + @pytest.mark.parametrize(("weight_type", "column"), [("gain", "Gain"), ("cover", "Cover")]) + def test_leaf_similarity_weight_api( + self, param: dict, num_boost_round: int, weight_type: str, column: str + ) -> None: + """Test the low-level tree weight API shape and order contract.""" + dtrain, _ = tm.load_agaricus(__file__) + bst = xgb.train(param, dtrain, num_boost_round=num_boost_round) + + leaves = bst.predict(dtrain, pred_leaf=True, strict_shape=True) + expected_len = int(np.prod(leaves.shape[1:])) + expected_weights = np.zeros(expected_len, dtype=np.float32) + trees_df = bst.trees_to_dataframe() + split_nodes = trees_df[trees_df["Feature"] != "Leaf"] + tree_weights = split_nodes.groupby("Tree")[column].sum() + for tree_id, weight in tree_weights.items(): + expected_weights[int(tree_id)] = weight + + config = from_pystr_to_cstr( + ( + "{" + f'"weight_type":"{weight_type}",' + '"iteration_begin":0,' + '"iteration_end":0' + "}" + ) + ) + out_len = c_bst_ulong() + out_weights = ctypes.POINTER(ctypes.c_float)() + + _check_call( + _LIB.XGBoosterGetLeafSimilarityWeights( + bst.handle, config, ctypes.byref(out_len), ctypes.byref(out_weights) + ) + ) + + assert out_len.value == expected_len + weights = ctypes2numpy(out_weights, out_len.value, np.float32) + np.testing.assert_allclose(weights, expected_weights, rtol=1e-6, atol=1e-3) + def test_leaf_similarity(self) -> None: """Test basic leaf similarity computation.""" dtrain, _ = tm.load_agaricus(__file__) @@ -26,7 +98,7 @@ def test_leaf_similarity(self) -> None: similarity = bst.compute_leaf_similarity(dm_query, dm_ref) assert similarity.shape == (10, 50) assert similarity.min() >= 0.0 - assert similarity.max() <= 1.0 + assert similarity.max() <= 1.0 + 1e-6 # Self-similarity diagonal should be 1.0 dm_self = xgb.DMatrix(X[:20]) @@ -38,9 +110,10 @@ def 
test_leaf_similarity(self) -> None: sim_cover = bst.compute_leaf_similarity(dm_query, dm_ref, weight_type="cover") assert sim_gain.shape == sim_cover.shape - # Default should be gain + # Default should be uniform + sim_uniform = bst.compute_leaf_similarity(dm_query, dm_ref, weight_type="uniform") sim_default = bst.compute_leaf_similarity(dm_query, dm_ref) - np.testing.assert_array_equal(sim_default, sim_gain) + np.testing.assert_array_equal(sim_default, sim_uniform) # Invalid weight_type with pytest.raises(ValueError, match="weight_type must be"): From 4260822db387e1066518a72340e1473894908854 Mon Sep 17 00:00:00 2001 From: ZhuYizhou2333 <2135378845@qq.com> Date: Sun, 12 Apr 2026 12:54:47 +0800 Subject: [PATCH 4/5] test: cover leaf similarity compatibility modes --- python-package/xgboost/core.py | 12 +++ tests/python/test_leaf_similarity.py | 115 +++++++++++++++++++++++++++ 2 files changed, 127 insertions(+) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 530da9c2f586..da99d9c47006 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -3014,6 +3014,18 @@ def compute_leaf_similarity( f"got '{weight_type}'" ) + config = json.loads(self.save_config())["learner"] + booster = config["gradient_booster"]["name"] + if booster == "gblinear": + raise XGBoostError( + "Leaf similarity is only defined for tree boosters, got gblinear." + ) + + if config["learner_train_param"]["multi_strategy"] == "multi_output_tree": + raise XGBoostError( + "Leaf similarity does not support multi_output_tree." 
+ ) + query_leaves = self.predict(data, pred_leaf=True, strict_shape=True) ref_leaves = self.predict(reference, pred_leaf=True, strict_shape=True) diff --git a/tests/python/test_leaf_similarity.py b/tests/python/test_leaf_similarity.py index 6e5c5f669e3b..9a0ae8743546 100644 --- a/tests/python/test_leaf_similarity.py +++ b/tests/python/test_leaf_similarity.py @@ -4,6 +4,7 @@ import numpy as np import pytest +from sklearn.datasets import load_diabetes, load_iris import xgboost as xgb from xgboost.core import ( @@ -118,3 +119,117 @@ def test_leaf_similarity(self) -> None: # Invalid weight_type with pytest.raises(ValueError, match="weight_type must be"): bst.compute_leaf_similarity(dm_query, dm_ref, weight_type="invalid") + + @pytest.mark.parametrize( + "param", + [ + { + "max_depth": 3, + "eta": 0.3, + "objective": "multi:softprob", + "num_class": 3, + }, + { + "booster": "dart", + "max_depth": 4, + "eta": 0.3, + "objective": "binary:logistic", + }, + { + "max_depth": 4, + "eta": 1.0, + "objective": "binary:logistic", + "num_parallel_tree": 3, + }, + ], + ) + @pytest.mark.parametrize("weight_type", ["uniform", "gain", "cover"]) + def test_leaf_similarity_supported_tree_modes( + self, param: dict, weight_type: str + ) -> None: + """Test supported tree model modes.""" + if param.get("objective") == "multi:softprob": + X, y = load_iris(return_X_y=True) + dtrain = xgb.DMatrix(X, label=y) + dm_query = xgb.DMatrix(X[:5]) + dm_ref = xgb.DMatrix(X[10:20]) + rounds = 8 + else: + dtrain, _ = tm.load_agaricus(__file__) + X = dtrain.get_data() + dm_query = xgb.DMatrix(X[:10]) + dm_ref = xgb.DMatrix(X[100:130]) + rounds = 8 if param.get("booster") == "dart" else 5 + + bst = xgb.train(param, dtrain, num_boost_round=rounds) + similarity = bst.compute_leaf_similarity(dm_query, dm_ref, weight_type=weight_type) + assert similarity.shape == (dm_query.num_row(), dm_ref.num_row()) + assert similarity.min() >= 0.0 + assert similarity.max() <= 1.0 + 1e-6 + + 
@pytest.mark.parametrize("weight_type", ["uniform", "gain", "cover"]) + def test_leaf_similarity_one_output_per_tree_multi_target( + self, weight_type: str + ) -> None: + """Test multi-target model with one output per tree.""" + X, y = load_diabetes(return_X_y=True) + y = np.column_stack([y, y * 0.5]) + dtrain = xgb.DMatrix(X, label=y) + bst = xgb.train( + { + "max_depth": 3, + "eta": 0.3, + "tree_method": "hist", + "objective": "reg:squarederror", + "multi_strategy": "one_output_per_tree", + "num_target": 2, + }, + dtrain, + num_boost_round=6, + ) + + similarity = bst.compute_leaf_similarity( + xgb.DMatrix(X[:5]), xgb.DMatrix(X[10:20]), weight_type=weight_type + ) + assert similarity.shape == (5, 10) + assert similarity.min() >= 0.0 + assert similarity.max() <= 1.0 + 1e-6 + + @pytest.mark.parametrize("weight_type", ["uniform", "gain", "cover"]) + def test_leaf_similarity_gblinear_error(self, weight_type: str) -> None: + """Test unsupported gblinear booster with stable error.""" + dtrain, _ = tm.load_agaricus(__file__) + bst = xgb.train({"booster": "gblinear", "objective": "binary:logistic"}, dtrain) + X = dtrain.get_data() + + with pytest.raises(xgb.core.XGBoostError, match="Leaf similarity is only defined"): + bst.compute_leaf_similarity( + xgb.DMatrix(X[:5]), xgb.DMatrix(X[10:20]), weight_type=weight_type + ) + + @pytest.mark.parametrize("weight_type", ["uniform", "gain", "cover"]) + def test_leaf_similarity_multi_output_tree_error(self, weight_type: str) -> None: + """Test unsupported multi-output tree with stable error.""" + X, y = load_diabetes(return_X_y=True) + y = np.column_stack([y, y * 0.5]) + dtrain = xgb.DMatrix(X, label=y) + bst = xgb.train( + { + "max_depth": 3, + "eta": 0.3, + "tree_method": "hist", + "objective": "reg:squarederror", + "multi_strategy": "multi_output_tree", + "num_target": 2, + }, + dtrain, + num_boost_round=6, + ) + + with pytest.raises( + xgb.core.XGBoostError, + match="Leaf similarity does not support multi_output_tree", + ): + 
bst.compute_leaf_similarity( + xgb.DMatrix(X[:5]), xgb.DMatrix(X[10:20]), weight_type=weight_type + ) From 7d2e9bdcba352b507ada170a923c7e9c4433b4a9 Mon Sep 17 00:00:00 2001 From: ZhuYizhou2333 <2135378845@qq.com> Date: Tue, 14 Apr 2026 20:45:44 +0800 Subject: [PATCH 5/5] fix: address pre-commit issues in leaf similarity follow-up --- include/xgboost/c_api.h | 234 ++++++++----------- include/xgboost/gbm.h | 17 +- include/xgboost/learner.h | 18 +- python-package/xgboost/core.py | 14 +- src/c_api/c_api.cc | 328 +++++++++++---------------- src/gbm/gblinear.cc | 51 ++--- src/gbm/gbtree.h | 39 ++-- src/learner.cc | 135 +++++------ tests/python/test_leaf_similarity.py | 32 +-- 9 files changed, 356 insertions(+), 512 deletions(-) diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index ba398b894cb3..071e8af71590 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -9,19 +9,19 @@ #ifdef __cplusplus #define XGB_EXTERN_C extern "C" #include -#include #include +#include #else #define XGB_EXTERN_C #include -#include #include +#include #endif // __cplusplus #if defined(_MSC_VER) || defined(_WIN32) #define XGB_DLL XGB_EXTERN_C __declspec(dllexport) #else -#define XGB_DLL XGB_EXTERN_C __attribute__ ((visibility ("default"))) +#define XGB_DLL XGB_EXTERN_C __attribute__((visibility("default"))) #endif // defined(_MSC_VER) || defined(_WIN32) // manually define unsigned long @@ -65,7 +65,7 @@ typedef void *CategoriesHandle; // NOLINT(*) * @param minor Store the minor version number. * @param patch Store the patch (revision) number. */ -XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch); +XGB_DLL void XGBoostVersion(int *major, int *minor, int *patch); /** * @brief Get compile information of the shared XGBoost library. 
@@ -98,7 +98,7 @@ XGB_DLL const char *XGBGetLastError(); * * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*)); +XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char *)); /** * @brief Set global configuration (collection of parameters that apply globally). This function @@ -112,7 +112,8 @@ XGB_DLL int XGBSetGlobalConfig(char const *config); /** * @brief Get current global configuration (collection of parameters that apply globally). - * @param out_config pointer to received returned global configuration, represented as a JSON string. + * @param out_config pointer to received returned global configuration, represented as a JSON + * string. * @return 0 when success, -1 when failure happens */ XGB_DLL int XGBGetGlobalConfig(char const **out_config); @@ -149,12 +150,14 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle * @brief load a data matrix * * @param config JSON encoded parameters for DMatrix construction. Accepted fields are: - * - uri: The URI of the input file. The URI parameter `format` is required when loading text data. + * - uri: The URI of the input file. The URI parameter `format` is required when loading text + * data. * @verbatim embed:rst:leading-asterisk * See :doc:`/tutorials/input_format` for more info. * @endverbatim * - silent (optional): Whether to print message during loading. Default to true. - * - data_split_mode (optional): Whether the file was split by row or column beforehand for distributed computing. Default to row. + * - data_split_mode (optional): Whether the file was split by row or column beforehand for + * distributed computing. Default to row. * @param out a loaded data matrix * @return 0 when success, -1 when failure happens */ @@ -243,7 +246,8 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr, char const *indices, char * @param config JSON encoded configuration. 
Required values are: * - missing: Which value to represent missing value. * - nthread (optional): Number of threads used for initializing DMatrix. - * - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row. + * - data_split_mode (optional): Whether the data was split by row or column beforehand. Default + * to row. * @param out The created DMatrix * * @return 0 when success, -1 when failure happens @@ -265,7 +269,6 @@ XGB_DLL int XGDMatrixCreateFromDense(char const *data, char const *config, DMatr XGB_DLL int XGDMatrixCreateFromCSC(char const *indptr, char const *indices, char const *data, bst_ulong nrow, char const *config, DMatrixHandle *out); - /** * @brief create matrix content from dense matrix * @param data pointer to the data space @@ -275,10 +278,7 @@ XGB_DLL int XGDMatrixCreateFromCSC(char const *indptr, char const *indices, char * @param out created dmatrix * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGDMatrixCreateFromMat(const float *data, - bst_ulong nrow, - bst_ulong ncol, - float missing, +XGB_DLL int XGDMatrixCreateFromMat(const float *data, bst_ulong nrow, bst_ulong ncol, float missing, DMatrixHandle *out); /** * @brief create matrix content from dense matrix @@ -291,9 +291,8 @@ XGB_DLL int XGDMatrixCreateFromMat(const float *data, * @return 0 when success, -1 when failure happens */ XGB_DLL int XGDMatrixCreateFromMat_omp(const float *data, // NOLINT - bst_ulong nrow, bst_ulong ncol, - float missing, DMatrixHandle *out, - int nthread); + bst_ulong nrow, bst_ulong ncol, float missing, + DMatrixHandle *out, int nthread); /** * @brief Create DMatrix from CUDA columnar format. (cuDF) @@ -315,7 +314,8 @@ XGB_DLL int XGDMatrixCreateFromCudaColumnar(char const *data, char const *config * @param config JSON encoded configuration. Required values are: * - missing: Which value to represent missing value. * - nthread (optional): Number of threads used for initializing DMatrix. 
- * - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row. + * - data_split_mode (optional): Whether the data was split by row or column beforehand. Default + * to row. * @param out created dmatrix * @return 0 when success, -1 when failure happens */ @@ -380,7 +380,6 @@ typedef void *DataIterHandle; // NOLINT(*) /** @brief handle to an internal data holder. */ typedef void *DataHolderHandle; // NOLINT(*) - /** @brief Mini batch used in XGBoost Data Iteration */ typedef struct { // NOLINT(*) /** @brief number of rows in the minibatch */ @@ -391,18 +390,18 @@ typedef struct { // NOLINT(*) #ifdef __APPLE__ /* Necessary as Java on MacOS defines jlong as long int * and gcc defines int64_t as long long int. */ - long* offset; // NOLINT(*) + long *offset; // NOLINT(*) #else - int64_t* offset; // NOLINT(*) + int64_t *offset; // NOLINT(*) #endif // __APPLE__ /** @brief labels of each instance */ - float* label; + float *label; /** @brief weight of each instance, can be NULL */ - float* weight; + float *weight; /** @brief feature index */ - int* index; + int *index; /** @brief feature values */ - float* value; + float *value; } XGBoostBatchCSR; /** @@ -437,12 +436,9 @@ XGB_EXTERN_C typedef int XGBCallbackDataIterNext( // NOLINT(*) * @param out The created DMatrix * @return 0 when success, -1 when failure happens. */ -XGB_DLL int XGDMatrixCreateFromDataIter( - DataIterHandle data_handle, - XGBCallbackDataIterNext* callback, - const char* cache_info, - float missing, - DMatrixHandle *out); +XGB_DLL int XGDMatrixCreateFromDataIter(DataIterHandle data_handle, + XGBCallbackDataIterNext *callback, const char *cache_info, + float missing, DMatrixHandle *out); /** * Second set of callback functions, used by constructing Quantile DMatrix or external @@ -469,7 +465,7 @@ XGB_DLL int XGDMatrixCreateFromDataIter( * * @return 0 when success, -1 when failure happens. 
*/ -XGB_DLL int XGProxyDMatrixCreate(DMatrixHandle* out); +XGB_DLL int XGProxyDMatrixCreate(DMatrixHandle *out); /** * @brief Callback function prototype for getting next batch of data. @@ -483,8 +479,7 @@ XGB_EXTERN_C typedef int XGDMatrixCallbackNext(DataIterHandle iter); // NOLINT( /** * @brief Callback function prototype for resetting the external iterator. */ -XGB_EXTERN_C typedef void DataIterResetCallback(DataIterHandle handle); // NOLINT(*) - +XGB_EXTERN_C typedef void DataIterResetCallback(DataIterHandle handle); // NOLINT(*) /** * @brief Create an external memory DMatrix with data iterator. @@ -503,7 +498,8 @@ XGB_EXTERN_C typedef void DataIterResetCallback(DataIterHandle handle); // NOLIN * @param next Callback function yielding the next batch of data. * @param config JSON encoded parameters for DMatrix construction. Accepted fields are: * - missing: Which value to represent missing value - * - cache_prefix: The path of cache file, caller must initialize all the directories in this path. + * - cache_prefix: The path of cache file, caller must initialize all the directories in this + * path. * - nthread (optional): Number of threads used for initializing DMatrix. * @param[out] out The created external memory DMatrix * @@ -572,7 +568,8 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand * @param next Callback function yielding the next batch of data. * @param config JSON encoded parameters for DMatrix construction. Accepted fields are: * - missing: Which value to represent missing value - * - cache_prefix: The path of cache file, caller must initialize all the directories in this path. + * - cache_prefix: The path of cache file, caller must initialize all the directories in this + * path. * - nthread (optional): Number of threads used for initializing DMatrix. * - max_bin (optional): Maximum number of bins for building histogram. Must be consistent with * the corresponding booster training parameter. 
@@ -653,9 +650,8 @@ XGB_DLL int XGProxyDMatrixSetDataDense(DMatrixHandle handle, char const *data); * * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGProxyDMatrixSetDataCSR(DMatrixHandle handle, char const *indptr, - char const *indices, char const *data, - bst_ulong ncol); +XGB_DLL int XGProxyDMatrixSetDataCSR(DMatrixHandle handle, char const *indptr, char const *indices, + char const *data, bst_ulong ncol); /** @} */ // End of Streaming @@ -667,9 +663,7 @@ XGB_DLL int XGProxyDMatrixSetDataCSR(DMatrixHandle handle, char const *indptr, * @param out a sliced new matrix * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle, - const int *idxset, - bst_ulong len, +XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle, const int *idxset, bst_ulong len, DMatrixHandle *out); /** * @brief create a new dmatrix from sliced content of existing matrix @@ -680,11 +674,8 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle, * @param allow_groups allow slicing of an array with groups * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle, - const int *idxset, - bst_ulong len, - DMatrixHandle *out, - int allow_groups); +XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle, const int *idxset, bst_ulong len, + DMatrixHandle *out, int allow_groups); /** * @brief Free a DMatrix object. * @@ -705,8 +696,7 @@ XGB_DLL int XGDMatrixFree(DMatrixHandle handle); * * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle, - const char *fname, int silent); +XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle, const char *fname, int silent); /** * @brief Set content in array interface to a content in info. 
@@ -765,8 +755,7 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const * @endcode */ XGB_DLL int XGDMatrixSetStrFeatureInfo(DMatrixHandle handle, const char *field, - const char **features, - const bst_ulong size); + const char **features, const bst_ulong size); /** * @brief Get string encoded information of all features. @@ -803,8 +792,7 @@ XGB_DLL int XGDMatrixSetStrFeatureInfo(DMatrixHandle handle, const char *field, * * @endcode */ -XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field, - bst_ulong *size, +XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field, bst_ulong *size, const char ***out_features); /** @@ -901,9 +889,7 @@ XGB_DLL int XGDMatrixGetFloatInfo(const DMatrixHandle handle, const char *field, * @param out_dptr pointer to the result * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle, - const char *field, - bst_ulong* out_len, +XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle, const char *field, bst_ulong *out_len, const unsigned **out_dptr); /** * @brief Get the number of rows from a DMatrix. @@ -981,7 +967,7 @@ XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config * __(cuda_)array_interface__. */ XGB_DLL int XGDMatrixGetQuantileCut(DMatrixHandle const handle, char const *config, - char const **out_indptr, char const **out_data); + char const **out_indptr, char const **out_data); /** @} */ // End of DMatrix @@ -1046,8 +1032,7 @@ XGB_DLL int XGBoosterReset(BoosterHandle handle); * * @return 0 when success, -1 when failure happens, -2 when index is out of bound. 
*/ -XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer, - int end_layer, int step, +XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer, int end_layer, int step, BoosterHandle *out); /** @@ -1057,7 +1042,7 @@ XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer, * @param out Pointer to output integer. * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGBoosterBoostedRounds(BoosterHandle handle, int* out); +XGB_DLL int XGBoosterBoostedRounds(BoosterHandle handle, int *out); /** * @brief set parameters @@ -1066,9 +1051,7 @@ XGB_DLL int XGBoosterBoostedRounds(BoosterHandle handle, int* out); * @param value value of parameter * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGBoosterSetParam(BoosterHandle handle, - const char *name, - const char *value); +XGB_DLL int XGBoosterSetParam(BoosterHandle handle, const char *name, const char *value); /** * @example c-api-demo.c */ @@ -1157,26 +1140,22 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle, int iter, DMatrixHandle d * 1:output margin instead of transformed value * 2:output leaf index of trees instead of leaf value, note leaf index is unique per tree * 4:output feature contributions to individual predictions - * @param ntree_limit limit number of trees used for prediction, this is only valid for boosted trees - * when the parameter is set to 0, we will use all the trees + * @param ntree_limit limit number of trees used for prediction, this is only valid for boosted + * trees when the parameter is set to 0, we will use all the trees * @param training Whether the prediction function is used as part of a training loop. * Prediction can be run in 2 scenarios: * 1. Given data matrix X, obtain prediction y_pred from the model. * 2. Obtain the prediction for computing gradients. 
For example, DART booster performs dropout - * during training, and the prediction result will be different from the one obtained by normal - * inference step due to dropped trees. - * Set training=false for the first scenario. Set training=true for the second scenario. - * The second scenario applies when you are defining a custom objective function. + * during training, and the prediction result will be different from the one obtained by + * normal inference step due to dropped trees. Set training=false for the first scenario. Set + * training=true for the second scenario. The second scenario applies when you are defining a custom + * objective function. * @param out_len used to store length of returning result * @param out_result used to set a pointer to array * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGBoosterPredict(BoosterHandle handle, - DMatrixHandle dmat, - int option_mask, - unsigned ntree_limit, - int training, - bst_ulong *out_len, +XGB_DLL int XGBoosterPredict(BoosterHandle handle, DMatrixHandle dmat, int option_mask, + unsigned ntree_limit, int training, bst_ulong *out_len, const float **out_result); /** @@ -1201,21 +1180,17 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle, * * Prediction can be run in 2 scenarios: * 1. Given data matrix X, obtain prediction y_pred from the model. - * 2. Obtain the prediction for computing gradients. For example, DART booster performs dropout - * during training, and the prediction result will be different from the one obtained by normal - * inference step due to dropped trees. - * Set training=false for the first scenario. Set training=true for the second - * scenario. The second scenario applies when you are defining a custom objective - * function. - * "iteration_begin": int - * Beginning iteration of prediction. + * 2. Obtain the prediction for computing gradients. 
For example, DART booster performs
+ * dropout during training, and the prediction result will be different from the one obtained by
+ * normal inference step due to dropped trees. Set training=false for the first scenario. Set
+ * training=true for the second scenario. The second scenario applies when you are defining a
+ * custom objective function. "iteration_begin" (int): Beginning iteration of prediction.
 * "iteration_end": int
- * End iteration of prediction. Set to 0 this will become the size of tree model (all the trees).
- * "strict_shape": bool
- * Whether should we reshape the output with stricter rules. If set to true,
- * normal/margin/contrib/interaction predict will output consistent shape
- * disregarding the use of multi-class model, and leaf prediction will output 4-dim
- * array representing: (n_samples, n_iterations, n_classes, n_trees_in_forest)
+ * End iteration of prediction. Set to 0 this will become the size of tree model (all the
+ * trees). "strict_shape" (bool): Whether we should reshape the output with stricter rules. If set to
+ * true, normal/margin/contrib/interaction predict will output consistent shape disregarding the use
+ * of multi-class model, and leaf prediction will output 4-dim array representing: (n_samples,
+ * n_iterations, n_classes, n_trees_in_forest)
 *
 * Example JSON input for running a normal prediction with strict output shape, 2 dim
 * for softprob , 1 dim for others. 
@@ -1235,7 +1210,8 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle, * * @return 0 when success, -1 when failure happens * - * @see XGBoosterPredictFromDense XGBoosterPredictFromCSR XGBoosterPredictFromCudaArray XGBoosterPredictFromCudaColumnar + * @see XGBoosterPredictFromDense XGBoosterPredictFromCSR XGBoosterPredictFromCudaArray + * XGBoosterPredictFromCudaColumnar */ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle, DMatrixHandle dmat, char const *config, bst_ulong const **out_shape, @@ -1374,7 +1350,6 @@ XGB_DLL int XGBoosterPredictFromCudaColumnar(BoosterHandle handle, char const *d /**@}*/ // End of Prediction - /** * @defgroup Serialization Serialization * @ingroup Booster @@ -1409,8 +1384,7 @@ XGB_DLL int XGBoosterPredictFromCudaColumnar(BoosterHandle handle, char const *d * * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, - const char *fname); +XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char *fname); /** * @brief Save the model into an existing file * @@ -1419,8 +1393,7 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, * * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, - const char *fname); +XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char *fname); /** * @brief load model from in memory buffer * @@ -1429,9 +1402,7 @@ XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, * @param len the length of the buffer * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle, - const void *buf, - bst_ulong len); +XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle, const void *buf, bst_ulong len); /** * @brief Save model into raw bytes, return header of the array. 
User must copy the @@ -1472,8 +1443,7 @@ XGB_DLL int XGBoosterSerializeToBuffer(BoosterHandle handle, bst_ulong *out_len, * @param len the length of the buffer * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle, - const void *buf, bst_ulong len); +XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle, const void *buf, bst_ulong len); /** * @brief Save XGBoost's internal configuration into a JSON document. Currently the @@ -1487,8 +1457,7 @@ XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle, * be managed by caller. * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGBoosterSaveJsonConfig(BoosterHandle handle, bst_ulong *out_len, - char const **out_str); +XGB_DLL int XGBoosterSaveJsonConfig(BoosterHandle handle, bst_ulong *out_len, char const **out_str); /** * @brief Load XGBoost's internal configuration from a JSON document. Currently the * support is experimental, function signature may change in the future without @@ -1510,11 +1479,8 @@ XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle, char const *config); * @param out_dump_array pointer to hold representing dump of each model * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGBoosterDumpModel(BoosterHandle handle, - const char *fmap, - int with_stats, - bst_ulong *out_len, - const char ***out_dump_array); +XGB_DLL int XGBoosterDumpModel(BoosterHandle handle, const char *fmap, int with_stats, + bst_ulong *out_len, const char ***out_dump_array); /** * @brief dump model, return array of strings representing model dump @@ -1526,11 +1492,8 @@ XGB_DLL int XGBoosterDumpModel(BoosterHandle handle, * @param out_dump_array pointer to hold representing dump of each model * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle, - const char *fmap, - int with_stats, - const char *format, - bst_ulong *out_len, +XGB_DLL int 
XGBoosterDumpModelEx(BoosterHandle handle, const char *fmap, int with_stats, + const char *format, bst_ulong *out_len, const char ***out_dump_array); /** @@ -1544,12 +1507,8 @@ XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle, * @param out_models pointer to hold representing dump of each model * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle, - int fnum, - const char **fname, - const char **ftype, - int with_stats, - bst_ulong *out_len, +XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle, int fnum, const char **fname, + const char **ftype, int with_stats, bst_ulong *out_len, const char ***out_models); /** @@ -1564,14 +1523,9 @@ XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle, * @param out_models pointer to hold representing dump of each model * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle, - int fnum, - const char **fname, - const char **ftype, - int with_stats, - const char *format, - bst_ulong *out_len, - const char ***out_models); +XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle, int fnum, const char **fname, + const char **ftype, int with_stats, const char *format, + bst_ulong *out_len, const char ***out_models); /** * See @ref XGDMatrixGetCategories @@ -1600,10 +1554,7 @@ XGB_DLL int XGBoosterGetCategoriesExportToArrow(BoosterHandle handle, char const * @param success Whether the result is contained in out. * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGBoosterGetAttr(BoosterHandle handle, - const char* key, - const char** out, - int *success); +XGB_DLL int XGBoosterGetAttr(BoosterHandle handle, const char *key, const char **out, int *success); /** * @brief Set or delete string attribute. * @@ -1613,9 +1564,7 @@ XGB_DLL int XGBoosterGetAttr(BoosterHandle handle, * If nullptr, the attribute would be deleted. 
* @return 0 when success, -1 when failure happens */ -XGB_DLL int XGBoosterSetAttr(BoosterHandle handle, - const char* key, - const char* value); +XGB_DLL int XGBoosterSetAttr(BoosterHandle handle, const char *key, const char *value); /** * @brief Get the names of all attribute from Booster. * @param handle handle @@ -1623,9 +1572,7 @@ XGB_DLL int XGBoosterSetAttr(BoosterHandle handle, * @param out pointer to hold the output attribute stings * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle, - bst_ulong* out_len, - const char*** out); +XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle, bst_ulong *out_len, const char ***out); /** * @brief Set string encoded feature info in Booster, similar to the feature @@ -1643,8 +1590,7 @@ XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle, * @return 0 when success, -1 when failure happens */ XGB_DLL int XGBoosterSetStrFeatureInfo(BoosterHandle handle, const char *field, - const char **features, - const bst_ulong size); + const char **features, const bst_ulong size); /** * @brief Get string encoded feature info from Booster, similar to the feature info @@ -1665,8 +1611,7 @@ XGB_DLL int XGBoosterSetStrFeatureInfo(BoosterHandle handle, const char *field, * * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGBoosterGetStrFeatureInfo(BoosterHandle handle, const char *field, - bst_ulong *len, +XGB_DLL int XGBoosterGetStrFeatureInfo(BoosterHandle handle, const char *field, bst_ulong *len, const char ***out_features); /** @@ -1724,8 +1669,7 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config, * @return 0 when success, -1 when failure happens */ XGB_DLL int XGBoosterGetLeafSimilarityWeights(BoosterHandle handle, const char *config, - bst_ulong *out_len, - float const **out_weights); + bst_ulong *out_len, float const **out_weights); /**@}*/ // End of Booster /** @@ -1878,7 +1822,7 @@ XGB_DLL int XGTrackerFree(TrackerHandle 
handle); * * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGCommunicatorInit(char const* config); +XGB_DLL int XGCommunicatorInit(char const *config); /** * @brief Finalize the collective communicator. @@ -1927,7 +1871,7 @@ XGB_DLL int XGCommunicatorPrint(char const *message); * @param name_str Pointer to received returned processor name. * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGCommunicatorGetProcessorName(const char** name_str); +XGB_DLL int XGCommunicatorGetProcessorName(const char **name_str); /** * @brief Broadcast a memory region to all others from root. This function is NOT @@ -1971,4 +1915,4 @@ XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int data_type, int op); /**@}*/ // End of Collective -#endif // XGBOOST_C_API_H_ +#endif // XGBOOST_C_API_H_ diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h index 5c5602759c60..b979a070f868 100644 --- a/include/xgboost/gbm.h +++ b/include/xgboost/gbm.h @@ -116,8 +116,7 @@ class GradientBooster : public Model, public Configurable { * \param layer_begin Beginning of boosted tree layer used for prediction. * \param layer_end End of booster layer. 0 means do not limit trees. */ - virtual void PredictLeaf(DMatrix *dmat, - HostDeviceVector *out_preds, + virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector* out_preds, unsigned layer_begin, unsigned layer_end) = 0; /*! 
@@ -147,13 +146,11 @@ class GradientBooster : public Model, public Configurable { [[nodiscard]] virtual std::vector DumpModel(const FeatureMap& fmap, bool with_stats, std::string format) const = 0; - virtual void FeatureScore(std::string const& importance_type, - common::Span trees, + virtual void FeatureScore(std::string const& importance_type, common::Span trees, std::vector* features, std::vector* scores) const = 0; - virtual void LeafSimilarityWeights(std::string const& weight_type, - bst_layer_t iteration_begin, + virtual void LeafSimilarityWeights(std::string const& weight_type, bst_layer_t iteration_begin, bst_layer_t iteration_end, std::vector* weights) const = 0; /** @@ -195,10 +192,10 @@ struct GradientBoosterReg * }); * \endcode */ -#define XGBOOST_REGISTER_GBM(UniqueId, Name) \ - static DMLC_ATTRIBUTE_UNUSED ::xgboost::GradientBoosterReg & \ - __make_ ## GradientBoosterReg ## _ ## UniqueId ## __ = \ - ::dmlc::Registry< ::xgboost::GradientBoosterReg>::Get()->__REGISTER__(Name) +#define XGBOOST_REGISTER_GBM(UniqueId, Name) \ + static DMLC_ATTRIBUTE_UNUSED ::xgboost::GradientBoosterReg& \ + __make_##GradientBoosterReg##_##UniqueId##__ = \ + ::dmlc::Registry< ::xgboost::GradientBoosterReg>::Get()->__REGISTER__(Name) } // namespace xgboost #endif // XGBOOST_GBM_H_ diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h index 8f7ea2d25830..c45992c33f34 100644 --- a/include/xgboost/learner.h +++ b/include/xgboost/learner.h @@ -96,8 +96,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable { * \param data_names name of each dataset * \return a string corresponding to the evaluation result */ - virtual std::string EvalOneIter(int iter, - const std::vector>& data_sets, + virtual std::string EvalOneIter(int iter, const std::vector>& data_sets, const std::vector& data_names) = 0; /*! * \brief get prediction given the model. 
@@ -107,7 +106,8 @@ class Learner : public Model, public Configurable, public dmlc::Serializable { * \param layer_begin Beginning of boosted tree layer used for prediction. * \param layer_end End of booster layer. 0 means do not limit trees. * \param training Whether the prediction result is used for training - * \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor + * \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree + * predictor * \param pred_contribs whether to only predict the feature contributions * \param approx_contribs whether to approximate the feature contributions for speed * \param pred_interactions whether to compute the feature pair contributions @@ -141,8 +141,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable { std::vector* scores) = 0; virtual void CalcLeafSimilarityWeights(std::string const& weight_type, - bst_layer_t iteration_begin, - bst_layer_t iteration_end, + bst_layer_t iteration_begin, bst_layer_t iteration_end, std::vector* weights) = 0; /* @@ -211,7 +210,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable { * \brief Set the feature names for current booster. * \param fn Input feature names */ - virtual void SetFeatureNames(std::vector const& fn) = 0; + virtual void SetFeatureNames(std::vector const& fn) = 0; /*! * \brief Get the feature names for current booster. * \param fn Output feature names @@ -250,8 +249,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable { * \param format the format to dump the model in * \return a vector of dump for boosters. 
*/ - virtual std::vector DumpModel(const FeatureMap& fmap, - bool with_stats, + virtual std::vector DumpModel(const FeatureMap& fmap, bool with_stats, std::string format) = 0; virtual XGBAPIThreadLocalEntry& GetThreadLocal() const = 0; @@ -264,7 +262,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable { * \param cache_data The matrix to cache the prediction. * \return Created learner. */ - static Learner* Create(const std::vector >& cache_data); + static Learner* Create(const std::vector>& cache_data); /** * \brief Return the context object of this Booster. */ @@ -281,7 +279,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable { /*! \brief The gradient booster used by the model*/ std::unique_ptr gbm_; /*! \brief The evaluation metrics used to evaluate the model. */ - std::vector > metrics_; + std::vector> metrics_; /*! \brief Training parameter. */ Context ctx_; }; diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index da99d9c47006..51f8f02443f1 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -806,9 +806,9 @@ def _get_categories( values = from_array_interface(jvalues) pa_offsets = pa.array(offsets).buffers() pa_values = pa.array(values).buffers() - assert ( - pa_offsets[0] is None and pa_values[0] is None - ), "Should not have null mask." + assert pa_offsets[0] is None and pa_values[0] is None, ( + "Should not have null mask." + ) pa_dict = pa.StringArray.from_buffers( len(offsets) - 1, pa_offsets[1], pa_values[1] ) @@ -3022,9 +3022,7 @@ def compute_leaf_similarity( ) if config["learner_train_param"]["multi_strategy"] == "multi_output_tree": - raise XGBoostError( - "Leaf similarity does not support multi_output_tree." 
- ) + raise XGBoostError("Leaf similarity does not support multi_output_tree.") query_leaves = self.predict(data, pred_leaf=True, strict_shape=True) ref_leaves = self.predict(reference, pred_leaf=True, strict_shape=True) @@ -3038,7 +3036,9 @@ def compute_leaf_similarity( m, n = query_leaves.shape[0], ref_leaves.shape[0] if query_leaves.shape[1] != ref_leaves.shape[1]: - raise ValueError("Query and reference leaf predictions have different shapes.") + raise ValueError( + "Query and reference leaf predictions have different shapes." + ) n_trees = query_leaves.shape[1] if m == 0 or n == 0 or n_trees == 0: diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 8b115bed1c9f..792ad68bb193 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -49,9 +49,9 @@ #include "xgboost/string_view.h" // for StringView, operator<< #include "xgboost/version_config.h" // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR, XGBOOS... -using namespace xgboost; // NOLINT(*); +using namespace xgboost; // NOLINT(*); -XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) { +XGB_DLL void XGBoostVersion(int *major, int *minor, int *patch) { if (major) { *major = XGBOOST_VER_MAJOR; } @@ -142,9 +142,9 @@ XGB_DLL int XGBuildInfo(char const **out) { API_END(); } -XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*)) { +XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char *)) { API_BEGIN_UNGUARD() - LogCallbackRegistry* registry = LogCallbackRegistryStore::Get(); + LogCallbackRegistry *registry = LogCallbackRegistryStore::Get(); registry->Register(callback); API_END(); } @@ -220,22 +220,22 @@ XGB_DLL int XGBSetGlobalConfig(const char *json_str) { API_END(); } -XGB_DLL int XGBGetGlobalConfig(const char** json_str) { +XGB_DLL int XGBGetGlobalConfig(const char **json_str) { API_BEGIN_UNGUARD() - auto const& global_config = *GlobalConfigThreadLocalStore::Get(); - Json config {ToJson(global_config)}; - auto const* mgr = global_config.__MANAGER__(); + auto const 
&global_config = *GlobalConfigThreadLocalStore::Get(); + Json config{ToJson(global_config)}; + auto const *mgr = global_config.__MANAGER__(); - for (auto& item : get(config)) { + for (auto &item : get(config)) { auto const &str = get(item.second); auto const &name = item.first; auto e = mgr->Find(name); CHECK(e); - if (dynamic_cast const*>(e) || - dynamic_cast const*>(e) || - dynamic_cast const*>(e) || - dynamic_cast const*>(e)) { + if (dynamic_cast const *>(e) || + dynamic_cast const *>(e) || + dynamic_cast const *>(e) || + dynamic_cast const *>(e)) { auto i = std::strtoimax(str.data(), nullptr, 10); CHECK_LE(i, static_cast(std::numeric_limits::max())); item.second = Integer(static_cast(i)); @@ -251,7 +251,7 @@ XGB_DLL int XGBGetGlobalConfig(const char** json_str) { } config["nthread"] = GlobalConfigThreadLocalStore::Get()->nthread; - auto& local = *GlobalConfigAPIThreadLocalStore::Get(); + auto &local = *GlobalConfigAPIThreadLocalStore::Get(); Json::Dump(config, &local.ret_str); xgboost_CHECK_C_ARG_PTR(json_str); @@ -291,9 +291,7 @@ XGB_DLL int XGDMatrixCreateFromURI(const char *config, DMatrixHandle *out) { XGB_DLL int XGDMatrixCreateFromDataIter( void *data_handle, // a Java iterator XGBCallbackDataIterNext *callback, // C++ callback defined in xgboost4j.cpp - const char *cache_info, - float missing, - DMatrixHandle *out) { + const char *cache_info, float missing, DMatrixHandle *out) { API_BEGIN(); std::string scache; @@ -303,9 +301,7 @@ XGB_DLL int XGDMatrixCreateFromDataIter( xgboost::data::IteratorAdapter adapter( data_handle, callback); xgboost_CHECK_C_ARG_PTR(out); - *out = new std::shared_ptr { - DMatrix::Create(&adapter, missing, 1, scache) - }; + *out = new std::shared_ptr{DMatrix::Create(&adapter, missing, 1, scache)}; API_END(); } @@ -526,8 +522,7 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr, char const *indices, char API_END(); } -XGB_DLL int XGDMatrixCreateFromDense(char const *data, - char const *c_json_config, +XGB_DLL int 
XGDMatrixCreateFromDense(char const *data, char const *c_json_config, DMatrixHandle *out) { API_BEGIN(); xgboost_CHECK_C_ARG_PTR(data); @@ -566,10 +561,8 @@ XGB_DLL int XGDMatrixCreateFromCSC(char const *indptr, char const *indices, char API_END(); } -XGB_DLL int XGDMatrixCreateFromMat(const bst_float* data, - xgboost::bst_ulong nrow, - xgboost::bst_ulong ncol, bst_float missing, - DMatrixHandle* out) { +XGB_DLL int XGDMatrixCreateFromMat(const bst_float *data, xgboost::bst_ulong nrow, + xgboost::bst_ulong ncol, bst_float missing, DMatrixHandle *out) { API_BEGIN(); data::DenseAdapter adapter(data, nrow, ncol); xgboost_CHECK_C_ARG_PTR(out); @@ -577,11 +570,9 @@ XGB_DLL int XGDMatrixCreateFromMat(const bst_float* data, API_END(); } -XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT - xgboost::bst_ulong nrow, - xgboost::bst_ulong ncol, - bst_float missing, DMatrixHandle* out, - int nthread) { +XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float *data, // NOLINT + xgboost::bst_ulong nrow, xgboost::bst_ulong ncol, + bst_float missing, DMatrixHandle *out, int nthread) { API_BEGIN(); data::DenseAdapter adapter(data, nrow, ncol); xgboost_CHECK_C_ARG_PTR(out); @@ -595,41 +586,32 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle, const int *idxset, xgboo return XGDMatrixSliceDMatrixEx(handle, idxset, len, out, 0); } -XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle, - const int* idxset, - xgboost::bst_ulong len, - DMatrixHandle* out, - int allow_groups) { +XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle, const int *idxset, xgboost::bst_ulong len, + DMatrixHandle *out, int allow_groups) { API_BEGIN(); CHECK_HANDLE(); if (!allow_groups) { - CHECK_EQ(static_cast*>(handle) - ->get() - ->Info() - .group_ptr_.size(), - 0U) + CHECK_EQ(static_cast *>(handle)->get()->Info().group_ptr_.size(), 0U) << "slice does not support group structure"; } - DMatrix* dmat = static_cast*>(handle)->get(); - *out = new std::shared_ptr( - 
dmat->Slice({idxset, static_cast(len)})); + DMatrix *dmat = static_cast *>(handle)->get(); + *out = new std::shared_ptr(dmat->Slice({idxset, static_cast(len)})); API_END(); } XGB_DLL int XGDMatrixFree(DMatrixHandle handle) { API_BEGIN(); CHECK_HANDLE(); - delete static_cast*>(handle); + delete static_cast *>(handle); API_END(); } -XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle, const char* fname, - int) { +XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle, const char *fname, int) { API_BEGIN(); CHECK_HANDLE(); - auto dmat = static_cast*>(handle)->get(); + auto dmat = static_cast *>(handle)->get(); xgboost_CHECK_C_ARG_PTR(fname); - if (data::SimpleDMatrix* derived = dynamic_cast(dmat)) { + if (data::SimpleDMatrix *derived = dynamic_cast(dmat)) { derived->SaveToLocalFile(fname); } else { LOG(FATAL) << "binary saving only supported by SimpleDMatrix"; @@ -668,8 +650,7 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const API_END(); } -XGB_DLL int XGDMatrixSetStrFeatureInfo(DMatrixHandle handle, const char *field, - const char **c_info, +XGB_DLL int XGDMatrixSetStrFeatureInfo(DMatrixHandle handle, const char *field, const char **c_info, const xgboost::bst_ulong size) { API_BEGIN(); CHECK_HANDLE(); @@ -680,11 +661,10 @@ XGB_DLL int XGDMatrixSetStrFeatureInfo(DMatrixHandle handle, const char *field, } XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field, - xgboost::bst_ulong *len, - const char ***out_features) { + xgboost::bst_ulong *len, const char ***out_features) { API_BEGIN(); CHECK_HANDLE(); - auto m = *static_cast*>(handle); + auto m = *static_cast *>(handle); auto &info = static_cast *>(handle)->get()->Info(); std::vector &charp_vecs = m->GetThreadLocal().ret_vec_charp; @@ -874,31 +854,27 @@ XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void API_END(); } -XGB_DLL int XGDMatrixGetFloatInfo(const DMatrixHandle handle, - const char* field, - xgboost::bst_ulong* out_len, - 
const bst_float** out_dptr) { +XGB_DLL int XGDMatrixGetFloatInfo(const DMatrixHandle handle, const char *field, + xgboost::bst_ulong *out_len, const bst_float **out_dptr) { API_BEGIN(); CHECK_HANDLE(); xgboost_CHECK_C_ARG_PTR(field); - const MetaInfo& info = static_cast*>(handle)->get()->Info(); + const MetaInfo &info = static_cast *>(handle)->get()->Info(); xgboost_CHECK_C_ARG_PTR(out_len); xgboost_CHECK_C_ARG_PTR(out_dptr); - info.GetInfo(field, out_len, DataType::kFloat32, reinterpret_cast(out_dptr)); + info.GetInfo(field, out_len, DataType::kFloat32, reinterpret_cast(out_dptr)); API_END(); } -XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle, - const char *field, - xgboost::bst_ulong *out_len, - const unsigned **out_dptr) { +XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle, const char *field, + xgboost::bst_ulong *out_len, const unsigned **out_dptr) { API_BEGIN(); CHECK_HANDLE(); xgboost_CHECK_C_ARG_PTR(field); - const MetaInfo& info = static_cast*>(handle)->get()->Info(); + const MetaInfo &info = static_cast *>(handle)->get()->Info(); xgboost_CHECK_C_ARG_PTR(out_len); xgboost_CHECK_C_ARG_PTR(out_dptr); - info.GetInfo(field, out_len, DataType::kUInt32, reinterpret_cast(out_dptr)); + info.GetInfo(field, out_len, DataType::kUInt32, reinterpret_cast(out_dptr)); API_END(); } @@ -1073,14 +1049,13 @@ XGB_DLL int XGDMatrixGetQuantileCut(DMatrixHandle const handle, char const *conf } // xgboost implementation -XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[], - xgboost::bst_ulong len, +XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[], xgboost::bst_ulong len, BoosterHandle *out) { API_BEGIN(); - std::vector > mats; + std::vector> mats; for (xgboost::bst_ulong i = 0; i < len; ++i) { xgboost_CHECK_C_ARG_PTR(dmats); - mats.push_back(*static_cast*>(dmats[i])); + mats.push_back(*static_cast *>(dmats[i])); } xgboost_CHECK_C_ARG_PTR(out); *out = Learner::Create(mats); @@ -1090,7 +1065,7 @@ XGB_DLL int XGBoosterCreate(const DMatrixHandle 
dmats[], XGB_DLL int XGBoosterFree(BoosterHandle handle) { API_BEGIN(); CHECK_HANDLE(); - delete static_cast(handle); + delete static_cast(handle); API_END(); } @@ -1101,53 +1076,49 @@ XGB_DLL int XGBoosterReset(BoosterHandle handle) { API_END(); } -XGB_DLL int XGBoosterSetParam(BoosterHandle handle, - const char *name, - const char *value) { +XGB_DLL int XGBoosterSetParam(BoosterHandle handle, const char *name, const char *value) { API_BEGIN(); CHECK_HANDLE(); - static_cast(handle)->SetParam(name, value); + static_cast(handle)->SetParam(name, value); API_END(); } -XGB_DLL int XGBoosterGetNumFeature(BoosterHandle handle, - xgboost::bst_ulong *out) { +XGB_DLL int XGBoosterGetNumFeature(BoosterHandle handle, xgboost::bst_ulong *out) { API_BEGIN(); CHECK_HANDLE(); - static_cast(handle)->Configure(); + static_cast(handle)->Configure(); xgboost_CHECK_C_ARG_PTR(out); - *out = static_cast(handle)->GetNumFeature(); + *out = static_cast(handle)->GetNumFeature(); API_END(); } -XGB_DLL int XGBoosterBoostedRounds(BoosterHandle handle, int* out) { +XGB_DLL int XGBoosterBoostedRounds(BoosterHandle handle, int *out) { API_BEGIN(); CHECK_HANDLE(); - static_cast(handle)->Configure(); + static_cast(handle)->Configure(); xgboost_CHECK_C_ARG_PTR(out); - *out = static_cast(handle)->BoostedRounds(); + *out = static_cast(handle)->BoostedRounds(); API_END(); } -XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle, char const* json_parameters) { +XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle, char const *json_parameters) { API_BEGIN(); CHECK_HANDLE(); xgboost_CHECK_C_ARG_PTR(json_parameters); - Json config { Json::Load(StringView{json_parameters}) }; - static_cast(handle)->LoadConfig(config); + Json config{Json::Load(StringView{json_parameters})}; + static_cast(handle)->LoadConfig(config); API_END(); } -XGB_DLL int XGBoosterSaveJsonConfig(BoosterHandle handle, - xgboost::bst_ulong *out_len, - char const** out_str) { +XGB_DLL int XGBoosterSaveJsonConfig(BoosterHandle handle, 
xgboost::bst_ulong *out_len, + char const **out_str) { API_BEGIN(); CHECK_HANDLE(); - Json config { Object() }; - auto* learner = static_cast(handle); + Json config{Object()}; + auto *learner = static_cast(handle); learner->Configure(); learner->SaveConfig(&config); - std::string& raw_str = learner->GetThreadLocal().ret_str; + std::string &raw_str = learner->GetThreadLocal().ret_str; Json::Dump(config, &raw_str); xgboost_CHECK_C_ARG_PTR(out_str); @@ -1158,12 +1129,10 @@ XGB_DLL int XGBoosterSaveJsonConfig(BoosterHandle handle, API_END(); } -XGB_DLL int XGBoosterUpdateOneIter(BoosterHandle handle, - int iter, - DMatrixHandle dtrain) { +XGB_DLL int XGBoosterUpdateOneIter(BoosterHandle handle, int iter, DMatrixHandle dtrain) { API_BEGIN(); CHECK_HANDLE(); - auto* bst = static_cast(handle); + auto *bst = static_cast(handle); xgboost_CHECK_C_ARG_PTR(dtrain); auto *dtr = static_cast *>(dtrain); CHECK(dtr); @@ -1198,7 +1167,7 @@ void CopyGradientFromCudaArrays(Context const *, ArrayInterface<2, false> const common::AssertGPUSupport(); } #else -; // NOLINT + ; // NOLINT #endif } // namespace xgboost @@ -1279,23 +1248,20 @@ XGB_DLL int XGBoosterTrainOneIterWithSplitGrad(BoosterHandle handle, DMatrixHand API_END(); } -XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle, - int iter, - DMatrixHandle dmats[], - const char* evnames[], - xgboost::bst_ulong len, - const char** out_str) { +XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle, int iter, DMatrixHandle dmats[], + const char *evnames[], xgboost::bst_ulong len, + const char **out_str) { API_BEGIN(); CHECK_HANDLE(); - auto* bst = static_cast(handle); - std::string& eval_str = bst->GetThreadLocal().ret_str; + auto *bst = static_cast(handle); + std::string &eval_str = bst->GetThreadLocal().ret_str; std::vector> data_sets; std::vector data_names; for (xgboost::bst_ulong i = 0; i < len; ++i) { xgboost_CHECK_C_ARG_PTR(dmats); - data_sets.push_back(*static_cast*>(dmats[i])); + data_sets.push_back(*static_cast *>(dmats[i])); 
xgboost_CHECK_C_ARG_PTR(evnames); data_names.emplace_back(evnames[i]); } @@ -1306,22 +1272,17 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle, API_END(); } -XGB_DLL int XGBoosterPredict(BoosterHandle handle, - DMatrixHandle dmat, - int option_mask, - unsigned ntree_limit, - int training, - xgboost::bst_ulong *len, +XGB_DLL int XGBoosterPredict(BoosterHandle handle, DMatrixHandle dmat, int option_mask, + unsigned ntree_limit, int training, xgboost::bst_ulong *len, const bst_float **out_result) { API_BEGIN(); CHECK_HANDLE(); - auto *learner = static_cast(handle); - auto& entry = learner->GetThreadLocal().prediction_entry; + auto *learner = static_cast(handle); + auto &entry = learner->GetThreadLocal().prediction_entry; auto iteration_end = GetIterationFromTreeLimit(ntree_limit, learner); - learner->Predict(*static_cast *>(dmat), - (option_mask & 1) != 0, &entry.predictions, 0, iteration_end, - static_cast(training), (option_mask & 2) != 0, - (option_mask & 4) != 0, (option_mask & 8) != 0, + learner->Predict(*static_cast *>(dmat), (option_mask & 1) != 0, + &entry.predictions, 0, iteration_end, static_cast(training), + (option_mask & 2) != 0, (option_mask & 4) != 0, (option_mask & 8) != 0, (option_mask & 16) != 0); xgboost_CHECK_C_ARG_PTR(len); @@ -1332,12 +1293,10 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle, API_END(); } -XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle, - DMatrixHandle dmat, - char const* c_json_config, +XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle, DMatrixHandle dmat, + char const *c_json_config, xgboost::bst_ulong const **out_shape, - xgboost::bst_ulong *out_dim, - bst_float const **out_result) { + xgboost::bst_ulong *out_dim, bst_float const **out_result) { API_BEGIN(); if (handle == nullptr) { LOG(FATAL) << "Booster has not been initialized or has already been disposed."; @@ -1348,34 +1307,33 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle, xgboost_CHECK_C_ARG_PTR(c_json_config); auto 
config = Json::Load(StringView{c_json_config}); - auto *learner = static_cast(handle); - auto& entry = learner->GetThreadLocal().prediction_entry; + auto *learner = static_cast(handle); + auto &entry = learner->GetThreadLocal().prediction_entry; auto p_m = *static_cast *>(dmat); auto type = PredictionType(RequiredArg(config, "type", __func__)); auto iteration_begin = RequiredArg(config, "iteration_begin", __func__); auto iteration_end = RequiredArg(config, "iteration_end", __func__); - auto const& j_config = get(config); + auto const &j_config = get(config); auto ntree_limit_it = j_config.find("ntree_limit"); if (ntree_limit_it != j_config.cend() && !IsA(ntree_limit_it->second) && get(ntree_limit_it->second) != 0) { - CHECK(iteration_end == 0) << - "Only one of the `ntree_limit` or `iteration_range` can be specified."; + CHECK(iteration_end == 0) + << "Only one of the `ntree_limit` or `iteration_range` can be specified."; LOG(WARNING) << "`ntree_limit` is deprecated, use `iteration_range` instead."; iteration_end = GetIterationFromTreeLimit(get(ntree_limit_it->second), learner); } - bool approximate = type == PredictionType::kApproxContribution || - type == PredictionType::kApproxInteraction; - bool contribs = type == PredictionType::kContribution || - type == PredictionType::kApproxContribution; - bool interactions = type == PredictionType::kInteraction || - type == PredictionType::kApproxInteraction; + bool approximate = + type == PredictionType::kApproxContribution || type == PredictionType::kApproxInteraction; + bool contribs = + type == PredictionType::kContribution || type == PredictionType::kApproxContribution; + bool interactions = + type == PredictionType::kInteraction || type == PredictionType::kApproxInteraction; bool training = RequiredArg(config, "training", __func__); - learner->Predict(p_m, type == PredictionType::kMargin, &entry.predictions, - iteration_begin, iteration_end, training, - type == PredictionType::kLeaf, contribs, approximate, + 
learner->Predict(p_m, type == PredictionType::kMargin, &entry.predictions, iteration_begin, + iteration_end, training, type == PredictionType::kLeaf, contribs, approximate, interactions); xgboost_CHECK_C_ARG_PTR(out_result); @@ -1391,9 +1349,8 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle, xgboost_CHECK_C_ARG_PTR(out_dim); xgboost_CHECK_C_ARG_PTR(out_shape); - CalcPredictShape(strict_shape, type, p_m->Info().num_row_, - p_m->Info().num_col_, chunksize, learner->Groups(), rounds, - &shape, out_dim); + CalcPredictShape(strict_shape, type, p_m->Info().num_row_, p_m->Info().num_col_, chunksize, + learner->Groups(), rounds, &shape, out_dim); *out_shape = dmlc::BeginPtr(shape); API_END(); } @@ -1680,15 +1637,14 @@ XGB_DLL int XGBoosterSerializeToBuffer(BoosterHandle handle, xgboost::bst_ulong API_END(); } -XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle, - const void *buf, +XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle, const void *buf, xgboost::bst_ulong len) { API_BEGIN(); CHECK_HANDLE(); xgboost_CHECK_C_ARG_PTR(buf); - common::MemoryFixSizeBuffer fs((void*)buf, len); // NOLINT(*) - static_cast(handle)->Load(&fs); + common::MemoryFixSizeBuffer fs((void *)buf, len); // NOLINT(*) + static_cast(handle)->Load(&fs); API_END(); } @@ -1709,16 +1665,15 @@ XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer, int end_layer, API_END(); } -inline void XGBoostDumpModelImpl(BoosterHandle handle, FeatureMap* fmap, - int with_stats, const char *format, - xgboost::bst_ulong *len, +inline void XGBoostDumpModelImpl(BoosterHandle handle, FeatureMap *fmap, int with_stats, + const char *format, xgboost::bst_ulong *len, const char ***out_models) { - auto *bst = static_cast(handle); + auto *bst = static_cast(handle); bst->Configure(); GenerateFeatureMap(bst, {}, bst->GetNumFeature(), fmap); - std::vector& str_vecs = bst->GetThreadLocal().ret_vec_str; - std::vector& charp_vecs = bst->GetThreadLocal().ret_vec_charp; + std::vector 
&str_vecs = bst->GetThreadLocal().ret_vec_str; + std::vector &charp_vecs = bst->GetThreadLocal().ret_vec_charp; str_vecs = bst->DumpModel(*fmap, with_stats != 0, format); charp_vecs.resize(str_vecs.size()); for (size_t i = 0; i < str_vecs.size(); ++i) { @@ -1732,23 +1687,17 @@ inline void XGBoostDumpModelImpl(BoosterHandle handle, FeatureMap* fmap, *len = static_cast(charp_vecs.size()); } -XGB_DLL int XGBoosterDumpModel(BoosterHandle handle, - const char* fmap, - int with_stats, - xgboost::bst_ulong* len, - const char*** out_models) { +XGB_DLL int XGBoosterDumpModel(BoosterHandle handle, const char *fmap, int with_stats, + xgboost::bst_ulong *len, const char ***out_models) { API_BEGIN(); CHECK_HANDLE(); return XGBoosterDumpModelEx(handle, fmap, with_stats, "text", len, out_models); API_END(); } -XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle, - const char* fmap, - int with_stats, - const char *format, - xgboost::bst_ulong* len, - const char*** out_models) { +XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle, const char *fmap, int with_stats, + const char *format, xgboost::bst_ulong *len, + const char ***out_models) { API_BEGIN(); CHECK_HANDLE(); @@ -1759,25 +1708,16 @@ XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle, API_END(); } -XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle, - int fnum, - const char** fname, - const char** ftype, - int with_stats, - xgboost::bst_ulong* len, - const char*** out_models) { - return XGBoosterDumpModelExWithFeatures(handle, fnum, fname, ftype, - with_stats, "text", len, out_models); +XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle, int fnum, const char **fname, + const char **ftype, int with_stats, + xgboost::bst_ulong *len, const char ***out_models) { + return XGBoosterDumpModelExWithFeatures(handle, fnum, fname, ftype, with_stats, "text", len, + out_models); } -XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle, - int fnum, - const char** fname, - const char** ftype, 
- int with_stats, - const char *format, - xgboost::bst_ulong* len, - const char*** out_models) { +XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle, int fnum, const char **fname, + const char **ftype, int with_stats, const char *format, + xgboost::bst_ulong *len, const char ***out_models) { API_BEGIN(); CHECK_HANDLE(); FeatureMap featmap; @@ -1837,8 +1777,8 @@ XGB_DLL int XGBoosterGetCategoriesExportToArrow(BoosterHandle handle, char const XGB_DLL int XGBoosterGetAttr(BoosterHandle handle, const char *key, const char **out, int *success) { - auto* bst = static_cast(handle); - std::string& ret_str = bst->GetThreadLocal().ret_str; + auto *bst = static_cast(handle); + std::string &ret_str = bst->GetThreadLocal().ret_str; API_BEGIN(); CHECK_HANDLE(); @@ -1855,12 +1795,10 @@ XGB_DLL int XGBoosterGetAttr(BoosterHandle handle, const char *key, const char * API_END(); } -XGB_DLL int XGBoosterSetAttr(BoosterHandle handle, - const char* key, - const char* value) { +XGB_DLL int XGBoosterSetAttr(BoosterHandle handle, const char *key, const char *value) { API_BEGIN(); CHECK_HANDLE(); - auto* bst = static_cast(handle); + auto *bst = static_cast(handle); xgboost_CHECK_C_ARG_PTR(key); if (value == nullptr) { bst->DelAttr(key); @@ -1871,16 +1809,14 @@ XGB_DLL int XGBoosterSetAttr(BoosterHandle handle, API_END(); } -XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle, - xgboost::bst_ulong* out_len, - const char*** out) { +XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle, xgboost::bst_ulong *out_len, + const char ***out) { API_BEGIN(); CHECK_HANDLE(); auto *learner = static_cast(handle); std::vector &str_vecs = learner->GetThreadLocal().ret_vec_str; - std::vector &charp_vecs = - learner->GetThreadLocal().ret_vec_charp; + std::vector &charp_vecs = learner->GetThreadLocal().ret_vec_charp; str_vecs = learner->GetAttrNames(); charp_vecs.resize(str_vecs.size()); for (size_t i = 0; i < str_vecs.size(); ++i) { @@ -1896,8 +1832,7 @@ XGB_DLL int 
XGBoosterGetAttrNames(BoosterHandle handle, } XGB_DLL int XGBoosterSetStrFeatureInfo(BoosterHandle handle, const char *field, - const char **features, - const xgboost::bst_ulong size) { + const char **features, const xgboost::bst_ulong size) { API_BEGIN(); CHECK_HANDLE(); auto *learner = static_cast(handle); @@ -1921,13 +1856,11 @@ XGB_DLL int XGBoosterSetStrFeatureInfo(BoosterHandle handle, const char *field, } XGB_DLL int XGBoosterGetStrFeatureInfo(BoosterHandle handle, const char *field, - xgboost::bst_ulong *len, - const char ***out_features) { + xgboost::bst_ulong *len, const char ***out_features) { API_BEGIN(); CHECK_HANDLE(); auto const *learner = static_cast(handle); - std::vector &charp_vecs = - learner->GetThreadLocal().ret_vec_charp; + std::vector &charp_vecs = learner->GetThreadLocal().ret_vec_charp; std::vector &str_vecs = learner->GetThreadLocal().ret_vec_str; if (!std::strcmp(field, "feature_name")) { learner->GetFeatureNames(&str_vecs); @@ -1985,9 +1918,9 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *config, auto n_features = learner->GetNumFeature(); GenerateFeatureMap(learner, custom_feature_names, n_features, &feature_map); - auto& feature_names = learner->GetThreadLocal().ret_vec_str; + auto &feature_names = learner->GetThreadLocal().ret_vec_str; feature_names.resize(features.size()); - auto& feature_names_c = learner->GetThreadLocal().ret_vec_charp; + auto &feature_names_c = learner->GetThreadLocal().ret_vec_charp; feature_names_c.resize(features.size()); for (bst_feature_t i = 0; i < features.size(); ++i) { @@ -2025,8 +1958,7 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *config, } XGB_DLL int XGBoosterGetLeafSimilarityWeights(BoosterHandle handle, char const *config, - bst_ulong *out_len, - float const **out_weights) { + bst_ulong *out_len, float const **out_weights) { API_BEGIN(); CHECK_HANDLE(); xgboost_CHECK_C_ARG_PTR(config); diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc index 
0aef1e000152..1f93f467bed5 100644 --- a/src/gbm/gblinear.cc +++ b/src/gbm/gblinear.cc @@ -13,7 +13,7 @@ #include #include -#include "../common/error_msg.h" // NoCategorical, DeprecatedFunc +#include "../common/error_msg.h" // NoCategorical, DeprecatedFunc #include "../common/threading_utils.h" #include "../common/timer.h" #include "gblinear_model.h" @@ -35,13 +35,10 @@ struct GBLinearTrainParam : public XGBoostParameter { size_t max_row_perbatch; DMLC_DECLARE_PARAMETER(GBLinearTrainParam) { - DMLC_DECLARE_FIELD(updater) - .set_default("shotgun") - .describe("Update algorithm for linear model. One of shotgun/coord_descent"); - DMLC_DECLARE_FIELD(tolerance) - .set_lower_bound(0.0f) - .set_default(0.0f) - .describe("Stop if largest weight update is smaller than this number."); + DMLC_DECLARE_FIELD(updater).set_default("shotgun").describe( + "Update algorithm for linear model. One of shotgun/coord_descent"); + DMLC_DECLARE_FIELD(tolerance).set_lower_bound(0.0f).set_default(0.0f).describe( + "Stop if largest weight update is smaller than this number."); DMLC_DECLARE_FIELD(max_row_perbatch) .set_default(std::numeric_limits::max()) .describe("Maximum rows per batch."); @@ -86,9 +83,7 @@ class GBLinear : public GradientBooster { updater_->Configure(cfg); } - int32_t BoostedRounds() const override { - return model_.num_boosted_rounds; - } + int32_t BoostedRounds() const override { return model_.num_boosted_rounds; } bool ModelFitted() const override { return BoostedRounds() != 0; } @@ -152,7 +147,7 @@ class GBLinear : public GradientBooster { monitor_.Stop("PredictBatch"); } - void PredictLeaf(DMatrix *, HostDeviceVector *, unsigned, unsigned) override { + void PredictLeaf(DMatrix*, HostDeviceVector*, unsigned, unsigned) override { LOG(FATAL) << "gblinear does not support prediction of leaf index"; } @@ -170,7 +165,7 @@ class GBLinear : public GradientBooster { std::fill(contribs.begin(), contribs.end(), 0); auto base_score = learner_model_param_->BaseScore(ctx_); // start 
collecting the contributions - for (const auto &batch : p_fmat->GetBatches()) { + for (const auto& batch : p_fmat->GetBatches()) { // parallel over local batch const auto nsize = static_cast(batch.Size()); auto page = batch.GetView(); @@ -179,7 +174,7 @@ class GBLinear : public GradientBooster { auto row_idx = static_cast(batch.base_rowid + i); // loop over output groups for (int gid = 0; gid < ngroup; ++gid) { - bst_float *p_contribs = &contribs[(row_idx * ngroup + gid) * ncolumns]; + bst_float* p_contribs = &contribs[(row_idx * ngroup + gid) * ncolumns]; // calculate linear terms' contributions for (auto& ins : inst) { if (ins.index >= model_.learner_model_param->num_feature) continue; @@ -201,8 +196,8 @@ class GBLinear : public GradientBooster { std::vector& contribs = out_contribs->HostVector(); // linear models have no interaction effects - const size_t nelements = model_.learner_model_param->num_feature * - model_.learner_model_param->num_feature; + const size_t nelements = + model_.learner_model_param->num_feature * model_.learner_model_param->num_feature; contribs.resize(p_fmat->Info().num_row_ * nelements * model_.learner_model_param->num_output_group); std::fill(contribs.begin(), contribs.end(), 0); @@ -213,10 +208,9 @@ class GBLinear : public GradientBooster { return model_.DumpModel(fmap, with_stats, format); } - void FeatureScore(std::string const &importance_type, - common::Span trees, - std::vector *out_features, - std::vector *out_scores) const override { + void FeatureScore(std::string const& importance_type, common::Span trees, + std::vector* out_features, + std::vector* out_scores) const override { CHECK(!model_.weight.empty()) << "Model is not initialized"; CHECK(trees.empty()) << "gblinear doesn't support number of trees for feature importance."; CHECK_EQ(importance_type, "weight") @@ -248,18 +242,17 @@ class GBLinear : public GradientBooster { } protected: - void PredictBatchInternal(DMatrix *p_fmat, - std::vector *out_preds) { + void 
PredictBatchInternal(DMatrix* p_fmat, std::vector* out_preds) { monitor_.Start("PredictBatchInternal"); model_.LazyInitModel(); - std::vector &preds = *out_preds; + std::vector& preds = *out_preds; auto base_margin = p_fmat->Info().base_margin_.View(DeviceOrd::CPU()); // start collecting the prediction const int ngroup = model_.learner_model_param->num_output_group; preds.resize(p_fmat->Info().num_row_ * ngroup); auto base_score = learner_model_param_->BaseScore(DeviceOrd::CPU()); - for (const auto &page : p_fmat->GetBatches()) { + for (const auto& page : p_fmat->GetBatches()) { auto const& batch = page.GetView(); // output convention: nrow * k, where nrow is number of rows // k is number of group @@ -289,8 +282,7 @@ class GBLinear : public GradientBooster { } float largest_dw = 0.0; for (size_t i = 0; i < model_.weight.size(); i++) { - largest_dw = std::max( - largest_dw, std::abs(model_.weight[i] - previous_model_.weight[i])); + largest_dw = std::max(largest_dw, std::abs(model_.weight[i] - previous_model_.weight[i])); } previous_model_ = model_; @@ -298,9 +290,9 @@ class GBLinear : public GradientBooster { return is_converged_; } - void LazySumWeights(DMatrix *p_fmat) { + void LazySumWeights(DMatrix* p_fmat) { if (!sum_weight_complete_) { - auto &info = p_fmat->Info(); + auto& info = p_fmat->Info(); for (size_t i = 0; i < info.num_row_; i++) { sum_instance_weight_ += info.GetWeight(i); } @@ -308,8 +300,7 @@ class GBLinear : public GradientBooster { } } - void Pred(const SparsePage::Inst &inst, bst_float *preds, int gid, - bst_float base) { + void Pred(const SparsePage::Inst& inst, bst_float* preds, int gid, bst_float base) { bst_float psum = model_.Bias()[gid] + base; for (const auto& ins : inst) { if (ins.index >= model_.learner_model_param->num_feature) continue; diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 5b441d3b75e1..2261af986677 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -40,10 +40,7 @@ enum class TreeMethod : int { }; // boosting 
process types -enum class TreeProcessType : int { - kDefault = 0, - kUpdate = 1 -}; +enum class TreeProcessType : int { kDefault = 0, kUpdate = 1 }; // Sampling type for dart weights. enum class DartSampleType : std::int32_t { @@ -72,15 +69,16 @@ struct GBTreeTrainParam : public XGBoostParameter { .set_default(TreeProcessType::kDefault) .add_enum("default", TreeProcessType::kDefault) .add_enum("update", TreeProcessType::kUpdate) - .describe("Whether to run the normal boosting process that creates new trees,"\ - " or to update the trees in an existing model."); + .describe( + "Whether to run the normal boosting process that creates new trees," + " or to update the trees in an existing model."); DMLC_DECLARE_ALIAS(updater_seq, updater); DMLC_DECLARE_FIELD(tree_method) .set_default(TreeMethod::kAuto) - .add_enum("auto", TreeMethod::kAuto) - .add_enum("approx", TreeMethod::kApprox) - .add_enum("exact", TreeMethod::kExact) - .add_enum("hist", TreeMethod::kHist) + .add_enum("auto", TreeMethod::kAuto) + .add_enum("approx", TreeMethod::kApprox) + .add_enum("exact", TreeMethod::kExact) + .add_enum("hist", TreeMethod::kHist) .describe("Choice of tree construction method."); } }; @@ -268,10 +266,9 @@ class GBTree : public GradientBooster { } }); } else { - LOG(FATAL) - << "Unknown feature importance type, expected one of: " - << R"({"weight", "total_gain", "total_cover", "gain", "cover"}, got: )" - << importance_type; + LOG(FATAL) << "Unknown feature importance type, expected one of: " + << R"({"weight", "total_gain", "total_cover", "gain", "cover"}, got: )" + << importance_type; } if (importance_type == "gain" || importance_type == "cover") { for (size_t i = 0; i < gain_map.size(); ++i) { @@ -289,8 +286,7 @@ class GBTree : public GradientBooster { } } - void LeafSimilarityWeights(std::string const& weight_type, - bst_layer_t iteration_begin, + void LeafSimilarityWeights(std::string const& weight_type, bst_layer_t iteration_begin, bst_layer_t iteration_end, std::vector* 
weights) const override { auto [tree_begin, tree_end] = detail::LayerToTree(model_, iteration_begin, iteration_end); @@ -298,8 +294,8 @@ class GBTree : public GradientBooster { weights->reserve(tree_end - tree_begin); auto const get_weight = [&](RegTree const& tree) { - CHECK(!tree.IsMultiTarget()) << "Leaf similarity weights for multi-target tree " - << MTNotImplemented(); + CHECK(!tree.IsMultiTarget()) + << "Leaf similarity weights for multi-target tree " << MTNotImplemented(); tree::ScalarTreeView view{&tree}; if (weight_type == "uniform") { @@ -330,9 +326,8 @@ class GBTree : public GradientBooster { [[nodiscard]] CatContainer const* Cats() const override { return this->model_.Cats(); } - void PredictLeaf(DMatrix* p_fmat, - HostDeviceVector* out_preds, - uint32_t layer_begin, uint32_t layer_end) override { + void PredictLeaf(DMatrix* p_fmat, HostDeviceVector* out_preds, uint32_t layer_begin, + uint32_t layer_end) override { auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end); CHECK_EQ(tree_begin, 0) << "Predict leaf supports only iteration end: [0, " "n_iteration), use model slicing instead."; @@ -384,7 +379,7 @@ class GBTree : public GradientBooster { GBTreeTrainParam tparam_; // Tree training parameter tree::TrainParam tree_param_; - bool specified_updater_ {false}; + bool specified_updater_{false}; // the updaters that can be applied to each of tree std::vector> updaters_; // Predictors diff --git a/src/learner.cc b/src/learner.cc index ba80147d8f53..bc2c9e070e18 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -6,38 +6,38 @@ */ #include "xgboost/learner.h" -#include // for Stream -#include // for FieldEntry, DMLC_DECLARE_FIELD, Parameter, DMLC... 
-#include // for ThreadLocalStore - -#include // for equal, max, transform, sort, find_if, all_of -#include // for atomic -#include // for isalpha, isspace -#include // for isnan, isinf -#include // for int32_t, uint32_t, int64_t, uint64_t -#include // for atoi -#include // for memcpy, size_t, memset -#include // for operator<<, setiosflags -#include // for back_insert_iterator, distance, back_inserter -#include // for numeric_limits -#include // for allocator, unique_ptr, shared_ptr, operator== -#include // for mutex, lock_guard -#include // for operator<<, basic_ostream, basic_ostream::opera... -#include // for stack -#include // for basic_string, char_traits, operator<, string -#include // for errc -#include // for operator!=, unordered_map -#include // for pair, as_const, move, swap -#include // for vector +#include // for Stream +#include // for FieldEntry, DMLC_DECLARE_FIELD, Parameter, DMLC... +#include // for ThreadLocalStore + +#include // for equal, max, transform, sort, find_if, all_of +#include // for atomic +#include // for isalpha, isspace +#include // for isnan, isinf +#include // for int32_t, uint32_t, int64_t, uint64_t +#include // for atoi +#include // for memcpy, size_t, memset +#include // for operator<<, setiosflags +#include // for back_insert_iterator, distance, back_inserter +#include // for numeric_limits +#include // for allocator, unique_ptr, shared_ptr, operator== +#include // for mutex, lock_guard +#include // for operator<<, basic_ostream, basic_ostream::opera... 
+#include // for stack +#include // for basic_string, char_traits, operator<, string +#include // for errc +#include // for operator!=, unordered_map +#include // for pair, as_const, move, swap +#include // for vector #include "collective/aggregator.h" // for ApplyWithLabels #include "collective/communicator-inl.h" // for Allreduce, Broadcast, GetRank, IsDistributed #include "common/api_entry.h" // for XGBAPIThreadLocalEntry -#include "common/param_array.h" // for ParamArray #include "common/charconv.h" // for to_chars, to_chars_result, NumericLimits, from_... #include "common/error_msg.h" // for MaxFeatureSize, WarnOldSerialization, ... #include "common/io.h" // for PeekableInStream, ReadAll, FixedSizeStream, Mem... #include "common/observer.h" // for TrainingObserver +#include "common/param_array.h" // for ParamArray #include "common/random.h" // for GlobalRandom #include "common/timer.h" // for Monitor #include "common/version.h" // for Version @@ -300,7 +300,7 @@ void LearnerModelParam::Copy(LearnerModelParam const& that) { struct LearnerTrainParam : public XGBoostParameter { // flag to disable default metric - bool disable_default_eval_metric {false}; + bool disable_default_eval_metric{false}; // FIXME(trivialfis): The following parameters belong to model itself, but can be // specified by users. Move them to model parameter once we can get rid of binary IO. 
std::string booster; @@ -328,12 +328,11 @@ struct LearnerTrainParam : public XGBoostParameter { } }; - DMLC_REGISTER_PARAMETER(LearnerModelParamLegacy); DMLC_REGISTER_PARAMETER(LearnerTrainParam); using LearnerAPIThreadLocalStore = - dmlc::ThreadLocalStore>; + dmlc::ThreadLocalStore>; namespace { /** @@ -614,7 +613,7 @@ class LearnerConfiguration : public Intercept { void SaveConfig(Json* p_out) const override { CHECK(!this->need_configuration_) << "Call Configure before saving model."; Version::Save(p_out); - Json& out { *p_out }; + Json& out{*p_out}; // parameters out["learner"] = Object(); auto& learner_parameters = out["learner"]; @@ -642,8 +641,7 @@ class LearnerConfiguration : public Intercept { void SetParam(const std::string& key, const std::string& value) override { this->need_configuration_ = true; if (key == kEvalMetric) { - if (std::find(metric_names_.cbegin(), metric_names_.cend(), - value) == metric_names_.cend()) { + if (std::find(metric_names_.cbegin(), metric_names_.cend(), value) == metric_names_.cend()) { metric_names_.emplace_back(value); } } else { @@ -657,9 +655,7 @@ class LearnerConfiguration : public Intercept { } } - uint32_t GetNumFeature() const override { - return learner_model_param_.num_feature; - } + uint32_t GetNumFeature() const override { return learner_model_param_.num_feature; } void SetAttr(const std::string& key, const std::string& value) override { attributes_[key] = value; @@ -674,22 +670,18 @@ class LearnerConfiguration : public Intercept { bool DelAttr(const std::string& key) override { auto it = attributes_.find(key); - if (it == attributes_.end()) { return false; } + if (it == attributes_.end()) { + return false; + } attributes_.erase(it); return true; } - void SetFeatureNames(std::vector const& fn) override { - feature_names_ = fn; - } + void SetFeatureNames(std::vector const& fn) override { feature_names_ = fn; } - void GetFeatureNames(std::vector* fn) const override { - *fn = feature_names_; - } + void 
GetFeatureNames(std::vector* fn) const override { *fn = feature_names_; } - void SetFeatureTypes(std::vector const& ft) override { - this->feature_types_ = ft; - } + void SetFeatureTypes(std::vector const& ft) override { this->feature_types_ = ft; } void GetFeatureTypes(std::vector* p_ft) const override { auto& ft = *p_ft; @@ -716,13 +708,13 @@ class LearnerConfiguration : public Intercept { private: void ValidateParameters() { - Json config { Object() }; + Json config{Object()}; this->SaveConfig(&config); std::stack stack; stack.push(config); std::string const postfix{"_param"}; - auto is_parameter = [&postfix](std::string const &key) { + auto is_parameter = [&postfix](std::string const& key) { return key.size() > postfix.size() && std::equal(postfix.rbegin(), postfix.rend(), key.rbegin()); }; @@ -738,7 +730,7 @@ class LearnerConfiguration : public Intercept { while (!stack.empty()) { auto j_obj = stack.top(); stack.pop(); - auto const &obj = get(j_obj); + auto const& obj = get(j_obj); for (auto const& kv : obj) { if (is_parameter(kv.first)) { @@ -766,7 +758,7 @@ class LearnerConfiguration : public Intercept { std::sort(keys.begin(), keys.end()); std::vector provided; - for (auto const &kv : cfg_) { + for (auto const& kv : cfg_) { if (std::any_of(kv.first.cbegin(), kv.first.cend(), [](char ch) { return std::isspace(ch); })) { LOG(FATAL) << "Invalid parameter \"" << kv.first << "\" contains whitespace."; @@ -776,8 +768,8 @@ class LearnerConfiguration : public Intercept { std::sort(provided.begin(), provided.end()); std::vector diff; - std::set_difference(provided.begin(), provided.end(), keys.begin(), - keys.end(), std::back_inserter(diff)); + std::set_difference(provided.begin(), provided.end(), keys.begin(), keys.end(), + std::back_inserter(diff)); if (diff.size() != 0) { std::stringstream ss; ss << "\nParameters: { "; @@ -817,8 +809,7 @@ class LearnerConfiguration : public Intercept { void ConfigureGBM(LearnerTrainParam const& old, Args const& args) { if (gbm_ 
== nullptr || old.booster != tparam_.booster) { - gbm_.reset(GradientBooster::Create(tparam_.booster, &ctx_, - &learner_model_param_)); + gbm_.reset(GradientBooster::Create(tparam_.booster, &ctx_, &learner_model_param_)); } gbm_->Configure(args); } @@ -833,8 +824,7 @@ class LearnerConfiguration : public Intercept { } } - if (cfg_.find("max_delta_step") == cfg_.cend() && - cfg_.find("objective") != cfg_.cend() && + if (cfg_.find("max_delta_step") == cfg_.cend() && cfg_.find("objective") != cfg_.cend() && tparam_.objective == "count:poisson") { // max_delta_step is a duplicated parameter in Poisson regression and tree param. // Rename one of them once binary IO is gone. @@ -844,7 +834,7 @@ class LearnerConfiguration : public Intercept { obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_)); } - bool has_nc {cfg_.find("num_class") != cfg_.cend()}; + bool has_nc{cfg_.find("num_class") != cfg_.cend()}; // Inject num_class into configuration. // FIXME(jiamingy): Remove the duplicated parameter in softmax cfg_["num_class"] = std::to_string(mparam_.num_class); @@ -858,7 +848,9 @@ class LearnerConfiguration : public Intercept { void ConfigureMetrics(Args const& args) { for (auto const& name : metric_names_) { - auto DupCheck = [&name](std::unique_ptr const& m) { return m->Name() != name; }; + auto DupCheck = [&name](std::unique_ptr const& m) { + return m->Name() != name; + }; if (std::all_of(metrics_.begin(), metrics_.end(), DupCheck)) { metrics_.emplace_back(std::unique_ptr(Metric::Create(name, &ctx_))); } @@ -877,7 +869,7 @@ class LearnerConfiguration : public Intercept { } }; -std::string const LearnerConfiguration::kEvalMetric {"eval_metric"}; // NOLINT +std::string const LearnerConfiguration::kEvalMetric{"eval_metric"}; // NOLINT class LearnerIO : public LearnerConfiguration { protected: @@ -908,8 +900,7 @@ class LearnerIO : public LearnerConfiguration { auto const& gradient_booster = learner.at("gradient_booster"); name = get(gradient_booster["name"]); 
tparam_.UpdateAllowUnknown(Args{{"booster", name}}); - gbm_.reset( - GradientBooster::Create(tparam_.booster, &ctx_, &learner_model_param_)); + gbm_.reset(GradientBooster::Create(tparam_.booster, &ctx_, &learner_model_param_)); gbm_->LoadModel(gradient_booster); auto const& j_attributes = get(learner.at("attributes")); @@ -943,7 +934,7 @@ class LearnerIO : public LearnerConfiguration { this->CheckModelInitialized(); Version::Save(p_out); - Json& out { *p_out }; + Json& out{*p_out}; out["learner"] = Object(); auto& learner = out["learner"]; @@ -1020,8 +1011,7 @@ class LearnerIO : public LearnerConfiguration { */ class LearnerImpl : public LearnerIO { public: - explicit LearnerImpl(std::vector > cache) - : LearnerIO{cache} {} + explicit LearnerImpl(std::vector> cache) : LearnerIO{cache} {} ~LearnerImpl() override { auto local_map = LearnerAPIThreadLocalStore::Get(); if (local_map->find(this) != local_map->cend()) { @@ -1148,8 +1138,7 @@ class LearnerImpl : public LearnerIO { this->monitor_.Stop(__func__); } - std::string EvalOneIter(int iter, - const std::vector>& data_sets, + std::string EvalOneIter(int iter, const std::vector>& data_sets, const std::vector& data_names) override { monitor_.Start("EvalOneIter"); this->Configure(); @@ -1173,7 +1162,7 @@ class LearnerImpl : public LearnerIO { this->ValidateDMatrix(m.get(), false); this->PredictRaw(m.get(), predt.get(), false, 0, 0); - auto &out = output_predictions_.Cache(m, ctx_.Device())->predictions; + auto& out = output_predictions_.Cache(m, ctx_.Device())->predictions; out.Resize(predt->predictions.Size()); out.Copy(predt->predictions); @@ -1191,8 +1180,7 @@ class LearnerImpl : public LearnerIO { HostDeviceVector* out_preds, bst_layer_t layer_begin, bst_layer_t layer_end, bool training, bool pred_leaf, bool pred_contribs, bool approx_contribs, bool pred_interactions) override { - int multiple_predictions = static_cast(pred_leaf) + - static_cast(pred_interactions) + + int multiple_predictions = 
static_cast(pred_leaf) + static_cast(pred_interactions) + static_cast(pred_contribs); this->Configure(); if (training) { @@ -1222,7 +1210,9 @@ class LearnerImpl : public LearnerIO { } int32_t BoostedRounds() const override { - if (!this->gbm_) { return 0; } // haven't call train or LoadModel. + if (!this->gbm_) { + return 0; + } // haven't call train or LoadModel. CHECK(!this->need_configuration_); return this->gbm_->BoostedRounds(); } @@ -1266,10 +1256,8 @@ class LearnerImpl : public LearnerIO { gbm_->FeatureScore(importance_type, trees, features, scores); } - void CalcLeafSimilarityWeights(std::string const& weight_type, - bst_layer_t iteration_begin, - bst_layer_t iteration_end, - std::vector* weights) override { + void CalcLeafSimilarityWeights(std::string const& weight_type, bst_layer_t iteration_begin, + bst_layer_t iteration_end, std::vector* weights) override { this->Configure(); this->CheckModelInitialized(); @@ -1289,7 +1277,7 @@ class LearnerImpl : public LearnerIO { * predictor, when it equals 0, this means we are using all the trees * \param training allow dropout when the DART booster is being used */ - void PredictRaw(DMatrix *data, PredictionCacheEntry *out_preds, bool training, + void PredictRaw(DMatrix* data, PredictionCacheEntry* out_preds, bool training, unsigned layer_begin, unsigned layer_end) const { CHECK(gbm_ != nullptr) << "Predict must happen after Load or configuration"; this->CheckModelInitialized(); @@ -1335,8 +1323,7 @@ class LearnerImpl : public LearnerIO { constexpr int32_t LearnerImpl::kRandSeedMagic; -Learner* Learner::Create( - const std::vector >& cache_data) { +Learner* Learner::Create(const std::vector>& cache_data) { return new LearnerImpl(cache_data); } } // namespace xgboost diff --git a/tests/python/test_leaf_similarity.py b/tests/python/test_leaf_similarity.py index 9a0ae8743546..f9158a34ce06 100644 --- a/tests/python/test_leaf_similarity.py +++ b/tests/python/test_leaf_similarity.py @@ -4,9 +4,9 @@ import numpy as np 
import pytest -from sklearn.datasets import load_diabetes, load_iris - import xgboost as xgb +from sklearn.datasets import load_diabetes, load_iris +from xgboost import testing as tm from xgboost.core import ( _LIB, _check_call, @@ -14,9 +14,6 @@ ctypes2numpy, from_pystr_to_cstr, ) -from xgboost import testing as tm - -rng = np.random.RandomState(1994) class TestLeafSimilarity: @@ -46,7 +43,10 @@ class TestLeafSimilarity: ), ], ) - @pytest.mark.parametrize(("weight_type", "column"), [("gain", "Gain"), ("cover", "Cover")]) + @pytest.mark.parametrize( + ("weight_type", "column"), [("gain", "Gain"), ("cover", "Cover")] + ) + @pytest.mark.skipif(**tm.no_pandas()) def test_leaf_similarity_weight_api( self, param: dict, num_boost_round: int, weight_type: str, column: str ) -> None: @@ -64,13 +64,7 @@ def test_leaf_similarity_weight_api( expected_weights[int(tree_id)] = weight config = from_pystr_to_cstr( - ( - "{" - f'"weight_type":"{weight_type}",' - '"iteration_begin":0,' - '"iteration_end":0' - "}" - ) + (f'{{"weight_type":"{weight_type}","iteration_begin":0,"iteration_end":0}}') ) out_len = c_bst_ulong() out_weights = ctypes.POINTER(ctypes.c_float)() @@ -112,7 +106,9 @@ def test_leaf_similarity(self) -> None: assert sim_gain.shape == sim_cover.shape # Default should be uniform - sim_uniform = bst.compute_leaf_similarity(dm_query, dm_ref, weight_type="uniform") + sim_uniform = bst.compute_leaf_similarity( + dm_query, dm_ref, weight_type="uniform" + ) sim_default = bst.compute_leaf_similarity(dm_query, dm_ref) np.testing.assert_array_equal(sim_default, sim_uniform) @@ -162,7 +158,9 @@ def test_leaf_similarity_supported_tree_modes( rounds = 8 if param.get("booster") == "dart" else 5 bst = xgb.train(param, dtrain, num_boost_round=rounds) - similarity = bst.compute_leaf_similarity(dm_query, dm_ref, weight_type=weight_type) + similarity = bst.compute_leaf_similarity( + dm_query, dm_ref, weight_type=weight_type + ) assert similarity.shape == (dm_query.num_row(), 
dm_ref.num_row()) assert similarity.min() >= 0.0 assert similarity.max() <= 1.0 + 1e-6 @@ -202,7 +200,9 @@ def test_leaf_similarity_gblinear_error(self, weight_type: str) -> None: bst = xgb.train({"booster": "gblinear", "objective": "binary:logistic"}, dtrain) X = dtrain.get_data() - with pytest.raises(xgb.core.XGBoostError, match="Leaf similarity is only defined"): + with pytest.raises( + xgb.core.XGBoostError, match="Leaf similarity is only defined" + ): bst.compute_leaf_similarity( xgb.DMatrix(X[:5]), xgb.DMatrix(X[10:20]), weight_type=weight_type )