diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h
index f1bb9c904ca5..dfc88c83dbb3 100644
--- a/include/xgboost/c_api.h
+++ b/include/xgboost/c_api.h
@@ -112,7 +112,8 @@ XGB_DLL int XGBSetGlobalConfig(char const *config);
 /**
  * @brief Get current global configuration (collection of parameters that apply globally).
- * @param out_config pointer to received returned global configuration, represented as a JSON string.
+ * @param out_config pointer to receive the returned global configuration, represented as a JSON
+ *   string.
  * @return 0 when success, -1 when failure happens
  */
 XGB_DLL int XGBGetGlobalConfig(char const **out_config);
@@ -149,12 +150,14 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle
  * @brief load a data matrix
  *
  * @param config JSON encoded parameters for DMatrix construction.  Accepted fields are:
- *   - uri: The URI of the input file. The URI parameter `format` is required when loading text data.
+ *   - uri: The URI of the input file. The URI parameter `format` is required when loading text
+ *     data.
  *     @verbatim embed:rst:leading-asterisk
  *     See :doc:`/tutorials/input_format` for more info.
  *     @endverbatim
  *   - silent (optional): Whether to print message during loading. Default to true.
- *   - data_split_mode (optional): Whether the file was split by row or column beforehand for distributed computing. Default to row.
+ *   - data_split_mode (optional): Whether the file was split by row or column beforehand for
+ *     distributed computing. Default to row.
  * @param out a loaded data matrix
  * @return 0 when success, -1 when failure happens
  */
@@ -243,7 +246,8 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr, char const *indices, char
  * @param config JSON encoded configuration.  Required values are:
  *   - missing: Which value to represent missing value.
  *   - nthread (optional): Number of threads used for initializing DMatrix.
- *   - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row.
+ *   - data_split_mode (optional): Whether the data was split by row or column beforehand. Default
+ *     to row.
  * @param out The created DMatrix
  *
  * @return 0 when success, -1 when failure happens
@@ -310,7 +314,8 @@ XGB_DLL int XGDMatrixCreateFromCudaColumnar(char const *data, char const *config
  * @param config JSON encoded configuration.  Required values are:
  *   - missing: Which value to represent missing value.
  *   - nthread (optional): Number of threads used for initializing DMatrix.
- *   - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row.
+ *   - data_split_mode (optional): Whether the data was split by row or column beforehand. Default
+ *     to row.
  * @param out created dmatrix
  * @return 0 when success, -1 when failure happens
  */
@@ -493,7 +498,8 @@ XGB_EXTERN_C typedef void DataIterResetCallback(DataIterHandle handle); // NOLINT
  * @param next Callback function yielding the next batch of data.
  * @param config JSON encoded parameters for DMatrix construction.  Accepted fields are:
  *   - missing: Which value to represent missing value
- *   - cache_prefix: The path of cache file, caller must initialize all the directories in this path.
+ *   - cache_prefix: The path of cache file, caller must initialize all the directories in this
+ *     path.
  *   - nthread (optional): Number of threads used for initializing DMatrix.
 * @param[out] out The created external memory DMatrix
@@ -558,7 +564,8 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
  * @param next Callback function yielding the next batch of data.
  * @param config JSON encoded parameters for DMatrix construction.  Accepted fields are:
  *   - missing: Which value to represent missing value
- *   - cache_prefix: The path of cache file, caller must initialize all the directories in this path.
+ *   - cache_prefix: The path of cache file, caller must initialize all the directories in this
+ *     path.
  *   - nthread (optional): Number of threads used for initializing DMatrix.
  *   - max_bin (optional): Maximum number of bins for building histogram. Must be consistent with
  *     the corresponding booster training parameter.
@@ -1143,16 +1150,16 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle, int iter, DMatrixHandle d
  *    1:output margin instead of transformed value
  *    2:output leaf index of trees instead of leaf value, note leaf index is unique per tree
  *    4:output feature contributions to individual predictions
- * @param ntree_limit limit number of trees used for prediction, this is only valid for boosted trees
- *    when the parameter is set to 0, we will use all the trees
+ * @param ntree_limit limit number of trees used for prediction, this is only valid for boosted
+ *    trees; when the parameter is set to 0, we will use all the trees
  * @param training Whether the prediction function is used as part of a training loop.
  *    Prediction can be run in 2 scenarios:
  *    1. Given data matrix X, obtain prediction y_pred from the model.
  *    2. Obtain the prediction for computing gradients. For example, DART booster performs dropout
- *       during training, and the prediction result will be different from the one obtained by normal
- *       inference step due to dropped trees.
- *    Set training=false for the first scenario. Set training=true for the second scenario.
- *    The second scenario applies when you are defining a custom objective function.
+ *       during training, and the prediction result will be different from the one obtained by
+ *       the normal inference step due to dropped trees. Set training=false for the first
+ *       scenario. Set training=true for the second scenario. The second scenario applies when
+ *       you are defining a custom objective function.
  * @param out_len used to store length of returning result
  * @param out_result used to set a pointer to array
  * @return 0 when success, -1 when failure happens
@@ -1183,21 +1190,17 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle, DMatrixHandle dmat, int optio
  *
  * Prediction can be run in 2 scenarios:
  *   1. Given data matrix X, obtain prediction y_pred from the model.
- *   2. Obtain the prediction for computing gradients. For example, DART booster performs dropout
- *      during training, and the prediction result will be different from the one obtained by normal
- *      inference step due to dropped trees.
- *      Set training=false for the first scenario. Set training=true for the second
- *      scenario. The second scenario applies when you are defining a custom objective
- *      function.
- *   "iteration_begin": int
- *     Beginning iteration of prediction.
+ *   2. Obtain the prediction for computing gradients. For example, DART booster performs
+ *      dropout during training, and the prediction result will be different from the one
+ *      obtained by the normal inference step due to dropped trees. Set training=false for the
+ *      first scenario. Set training=true for the second scenario. The second scenario applies
+ *      when you are defining a custom objective function.
+ *   "iteration_begin": int
+ *     Beginning iteration of prediction.
  *   "iteration_end": int
- *     End iteration of prediction. Set to 0 this will become the size of tree model (all the trees).
- *   "strict_shape": bool
- *     Whether should we reshape the output with stricter rules.  If set to true,
- *     normal/margin/contrib/interaction predict will output consistent shape
- *     disregarding the use of multi-class model, and leaf prediction will output 4-dim
- *     array representing: (n_samples, n_iterations, n_classes, n_trees_in_forest)
+ *     End iteration of prediction. Set to 0 and this becomes the size of the tree model (all
+ *     the trees).
+ *   "strict_shape": bool
+ *     Whether we should reshape the output with stricter rules. If set to true,
+ *     normal/margin/contrib/interaction predict will output a consistent shape disregarding
+ *     the use of a multi-class model, and leaf prediction will output a 4-dim array
+ *     representing: (n_samples, n_iterations, n_classes, n_trees_in_forest)
 *
 * Example JSON input for running a normal prediction with strict output shape, 2 dim
 * for softprob , 1 dim for others.
@@ -1217,7 +1220,8 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle, DMatrixHandle dmat, int optio
 *
 * @return 0 when success, -1 when failure happens
 *
- * @see XGBoosterPredictFromDense XGBoosterPredictFromCSR XGBoosterPredictFromCudaArray XGBoosterPredictFromCudaColumnar
+ * @see XGBoosterPredictFromDense XGBoosterPredictFromCSR XGBoosterPredictFromCudaArray
+ *   XGBoosterPredictFromCudaColumnar
 */
XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle, DMatrixHandle dmat,
                                        char const *config, bst_ulong const **out_shape,
@@ -1649,6 +1653,33 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config,
                                   bst_ulong *out_n_features, char const ***out_features,
                                   bst_ulong *out_dim, bst_ulong const **out_shape,
                                   float const **out_scores);
+
+/**
+ * @brief Get per-tree weights for leaf similarity computation.
+ *
+ * @param handle Handle to the booster.
+ * @param config A JSON string with the following format:
+ *
+ *   {
+ *     "weight_type": str,
+ *     "iteration_begin": int,
+ *     "iteration_end": int
+ *   }
+ *
+ *   - weight_type: A JSON string with the following possible values:
+ *     * 'uniform': assign equal weight to each tree.
+ *     * 'gain': sum of split gains in each tree.
+ *     * 'cover': sum of split covers in each tree.
+ *   - iteration_begin: Beginning iteration used when extracting tree weights.
+ *   - iteration_end: End iteration used when extracting tree weights. 0 means
+ *     using all remaining iterations.
+ *
+ * @param out_len Length of output tree weight array.
+ * @param out_weights Pointer to the output tree weight array.
+ *
+ * @return 0 when success, -1 when failure happens
+ */
+XGB_DLL int XGBoosterGetLeafSimilarityWeights(BoosterHandle handle, const char *config,
+                                              bst_ulong *out_len, float const **out_weights);
 /**@}*/  // End of Booster
 
 /**
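A minimal sketch of calling the new entry point from Python via `ctypes`, mirroring the JSON contract documented above. It leans on private helpers from `xgboost.core` (`_LIB`, `_check_call`, `ctypes2numpy`) exactly as the test added in this patch does, and the tiny random dataset is illustrative only; note the returned buffer lives in booster-owned thread-local storage, so copy it before the next C API call.

```python
import ctypes
import json

import numpy as np
import xgboost as xgb
from xgboost.core import _LIB, _check_call, c_bst_ulong, ctypes2numpy

X, y = np.random.rand(64, 4), np.random.randint(0, 2, 64)
bst = xgb.train({"max_depth": 3}, xgb.DMatrix(X, label=y), num_boost_round=4)

# Request per-tree "gain" weights over all iterations (iteration_end=0).
config = json.dumps(
    {"weight_type": "gain", "iteration_begin": 0, "iteration_end": 0}
).encode("utf-8")
out_len = c_bst_ulong()
out_weights = ctypes.POINTER(ctypes.c_float)()
_check_call(
    _LIB.XGBoosterGetLeafSimilarityWeights(
        bst.handle, config, ctypes.byref(out_len), ctypes.byref(out_weights)
    )
)
weights = ctypes2numpy(out_weights, out_len.value, np.float32)
assert weights.shape == (4,)  # one weight per tree, in model order
```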
diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h
index 65940773ffee..b979a070f868 100644
--- a/include/xgboost/gbm.h
+++ b/include/xgboost/gbm.h
@@ -116,8 +116,7 @@ class GradientBooster : public Model, public Configurable {
   * \param layer_begin Beginning of boosted tree layer used for prediction.
   * \param layer_end End of booster layer. 0 means do not limit trees.
   */
-  virtual void PredictLeaf(DMatrix *dmat,
-                           HostDeviceVector<bst_float> *out_preds,
+  virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
                            unsigned layer_begin, unsigned layer_end) = 0;
   /*!
@@ -147,10 +146,13 @@ class GradientBooster : public Model, public Configurable {
   [[nodiscard]] virtual std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
                                                            std::string format) const = 0;
 
-  virtual void FeatureScore(std::string const& importance_type,
-                            common::Span<int32_t const> trees,
+  virtual void FeatureScore(std::string const& importance_type, common::Span<int32_t const> trees,
                             std::vector<bst_feature_t>* features,
                             std::vector<float>* scores) const = 0;
+
+  /** @brief Per-tree weights used for leaf similarity computation. */
+  virtual void LeafSimilarityWeights(std::string const& weight_type, bst_layer_t iteration_begin,
+                                     bst_layer_t iteration_end,
+                                     std::vector<float>* weights) const = 0;
   /**
    * @brief Getter for categories.
    */
@@ -190,10 +192,10 @@ struct GradientBoosterReg
  *   });
  * \endcode
  */
-#define XGBOOST_REGISTER_GBM(UniqueId, Name) \
-  static DMLC_ATTRIBUTE_UNUSED ::xgboost::GradientBoosterReg & \
-  __make_ ## GradientBoosterReg ## _ ## UniqueId ## __ = \
-  ::dmlc::Registry< ::xgboost::GradientBoosterReg>::Get()->__REGISTER__(Name)
+#define XGBOOST_REGISTER_GBM(UniqueId, Name)                  \
+  static DMLC_ATTRIBUTE_UNUSED ::xgboost::GradientBoosterReg& \
+      __make_##GradientBoosterReg##_##UniqueId##__ =          \
+          ::dmlc::Registry< ::xgboost::GradientBoosterReg>::Get()->__REGISTER__(Name)
 }  // namespace xgboost
 #endif  // XGBOOST_GBM_H_
diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h
index ffaddfbe6442..c45992c33f34 100644
--- a/include/xgboost/learner.h
+++ b/include/xgboost/learner.h
@@ -96,8 +96,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
    * \param data_names name of each dataset
    * \return a string corresponding to the evaluation result
    */
-  virtual std::string EvalOneIter(int iter,
-                                  const std::vector<std::shared_ptr<DMatrix>>& data_sets,
+  virtual std::string EvalOneIter(int iter, const std::vector<std::shared_ptr<DMatrix>>& data_sets,
                                   const std::vector<std::string>& data_names) = 0;
   /*!
    * \brief get prediction given the model.
@@ -107,7 +106,8 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
    * \param layer_begin Beginning of boosted tree layer used for prediction.
    * \param layer_end End of booster layer. 0 means do not limit trees.
    * \param training Whether the prediction result is used for training
-   * \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor
+   * \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree
+   *                  predictor
    * \param pred_contribs whether to only predict the feature contributions
    * \param approx_contribs whether to approximate the feature contributions for speed
    * \param pred_interactions whether to compute the feature pair contributions
@@ -140,6 +140,10 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
                             std::vector<bst_feature_t>* features,
                             std::vector<float>* scores) = 0;
 
+  virtual void CalcLeafSimilarityWeights(std::string const& weight_type,
+                                         bst_layer_t iteration_begin, bst_layer_t iteration_end,
+                                         std::vector<float>* weights) = 0;
+
   /*
    * \brief Get number of boosted rounds from gradient booster.
    */
@@ -206,7 +210,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
    * \brief Set the feature names for current booster.
    * \param fn Input feature names
    */
-  virtual void SetFeatureNames(std::vector<std::string> const &fn) = 0;
+  virtual void SetFeatureNames(std::vector<std::string> const& fn) = 0;
   /*!
    * \brief Get the feature names for current booster.
    * \param fn Output feature names
@@ -245,8 +249,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
    * \param format the format to dump the model in
    * \return a vector of dump for boosters.
    */
-  virtual std::vector<std::string> DumpModel(const FeatureMap& fmap,
-                                             bool with_stats,
+  virtual std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
                                              std::string format) = 0;
 
   virtual XGBAPIThreadLocalEntry& GetThreadLocal() const = 0;
@@ -259,7 +262,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
    * \param cache_data The matrix to cache the prediction.
    * \return Created learner.
    */
-  static Learner* Create(const std::vector<std::shared_ptr<DMatrix> >& cache_data);
+  static Learner* Create(const std::vector<std::shared_ptr<DMatrix>>& cache_data);
   /**
    * \brief Return the context object of this Booster.
    */
@@ -276,7 +279,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
   /*! \brief The gradient booster used by the model*/
   std::unique_ptr<GradientBooster> gbm_;
   /*! \brief The evaluation metrics used to evaluate the model. */
-  std::vector<std::unique_ptr<Metric> > metrics_;
+  std::vector<std::unique_ptr<Metric>> metrics_;
   /*! \brief Training parameter. */
   Context ctx_;
 };
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index c805fa486fb8..0f44fd7bcc39 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -2765,6 +2765,127 @@ def inplace_predict(
             "Data type:" + str(type(data)) + " not supported by inplace prediction."
         )
 
+    def compute_leaf_similarity(
+        self,
+        data: DMatrix,
+        reference: DMatrix,
+        weight_type: str = "uniform",
+    ) -> np.ndarray:
+        """Compute similarity between observations based on leaf node co-occurrence.
+
+        Two samples are similar if they land in the same leaf nodes across trees.
+        This is similar to Random Forest proximity matrices.
+
+        Parameters
+        ----------
+        data :
+            Query dataset (m samples).
+        reference :
+            Reference dataset (n samples).
+        weight_type :
+            How to weight trees: "uniform" (equal tree weights), "gain"
+            (by loss improvement), or "cover" (by hessian sum, approximately
+            sample count for regression).
+
+        Returns
+        -------
+        similarity : ndarray of shape (m, n)
+            Similarity scores in [0, 1].
+        """
+        if weight_type not in ("uniform", "gain", "cover"):
+            raise ValueError(
+                "weight_type must be 'uniform', 'gain', or 'cover', "
+                f"got '{weight_type}'"
+            )
+
+        config = json.loads(self.save_config())["learner"]
+        booster = config["gradient_booster"]["name"]
+        if booster == "gblinear":
+            raise XGBoostError(
+                "Leaf similarity is only defined for tree boosters, got gblinear."
+            )
+
+        if config["learner_train_param"]["multi_strategy"] == "multi_output_tree":
+            raise XGBoostError("Leaf similarity does not support multi_output_tree.")
+
+        query_leaves = self.predict(data, pred_leaf=True, strict_shape=True)
+        ref_leaves = self.predict(reference, pred_leaf=True, strict_shape=True)
+
+        query_leaves = np.asarray(query_leaves, dtype=np.int64).reshape(
+            query_leaves.shape[0], -1
+        )
+        ref_leaves = np.asarray(ref_leaves, dtype=np.int64).reshape(
+            ref_leaves.shape[0], -1
+        )
+
+        m, n = query_leaves.shape[0], ref_leaves.shape[0]
+        if query_leaves.shape[1] != ref_leaves.shape[1]:
+            raise ValueError(
+                "Query and reference leaf predictions have different shapes."
+            )
+
+        n_trees = query_leaves.shape[1]
+        if m == 0 or n == 0 or n_trees == 0:
+            return np.zeros((m, n), dtype=np.float32)
+
+        if weight_type == "uniform":
+            weights = np.ones(n_trees, dtype=np.float32)
+        else:
+            out_len = c_bst_ulong()
+            out_weights = ctypes.POINTER(ctypes.c_float)()
+            _check_call(
+                _LIB.XGBoosterGetLeafSimilarityWeights(
+                    self.handle,
+                    make_jcargs(
+                        weight_type=weight_type,
+                        iteration_begin=0,
+                        iteration_end=0,
+                    ),
+                    ctypes.byref(out_len),
+                    ctypes.byref(out_weights),
+                )
+            )
+            weights = ctypes2numpy(out_weights, out_len.value, np.float32)
+            if weights.shape[0] != n_trees:
+                raise ValueError(
+                    "Tree weight count does not match leaf prediction shape: "
+                    f"{weights.shape[0]} != {n_trees}"
+                )
+
+        # Fall back to uniform weights when every tree has zero weight (e.g. a
+        # model of stumps has zero total gain).
+        total_weight = weights.sum()
+        if total_weight == 0:
+            weights = np.ones(n_trees, dtype=np.float32)
+            total_weight = weights.sum()
+
+        # Offset leaf indices so every (tree, leaf) pair maps to a unique
+        # column of a sparse indicator matrix.
+        leaf_upper = np.maximum(query_leaves.max(axis=0), ref_leaves.max(axis=0)) + 1
+        offsets = np.zeros(n_trees, dtype=np.int64)
+        if n_trees > 1:
+            offsets[1:] = np.cumsum(leaf_upper[:-1], dtype=np.int64)
+
+        # Scale each indicator by sqrt(w_t / W) so the inner product of two
+        # rows equals the weighted fraction of trees with matching leaves.
+        weight_values = np.sqrt(weights / total_weight, dtype=np.float32)
+        q_cols = (query_leaves + offsets).reshape(-1)
+        r_cols = (ref_leaves + offsets).reshape(-1)
+        q_rows = np.repeat(np.arange(m), n_trees)
+        r_rows = np.repeat(np.arange(n), n_trees)
+        feature_dim = int(offsets[-1] + leaf_upper[-1])
+
+        import scipy.sparse  # lazily imported; scipy is a required dependency
+
+        query_matrix = scipy.sparse.csr_matrix(
+            (np.tile(weight_values, m), (q_rows, q_cols)),
+            shape=(m, feature_dim),
+            dtype=np.float32,
+        )
+        ref_matrix = scipy.sparse.csr_matrix(
+            (np.tile(weight_values, n), (r_rows, r_cols)),
+            shape=(n, feature_dim),
+            dtype=np.float32,
+        )
+
+        similarity = query_matrix @ ref_matrix.T
+        return similarity.toarray()
+
     def save_model(self, fname: PathLike) -> None:
         """Save the model to a file.
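For query row x and reference row z, the method above computes s(x, z) = Σ_t w_t · 1[leaf_t(x) = leaf_t(z)] / Σ_t w_t; scaling each one-hot leaf indicator by sqrt(w_t / Σ_t w_t) turns that sum into a plain sparse inner product. A short usage sketch (the scikit-learn dataset is illustrative only):

```python
import numpy as np
import xgboost as xgb
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=200, n_features=10, random_state=0)
dtrain = xgb.DMatrix(X, label=y)
bst = xgb.train(
    {"max_depth": 4, "objective": "binary:logistic"}, dtrain, num_boost_round=20
)

# Proximity of 5 query rows to 50 reference rows, weighting trees by gain.
sim = bst.compute_leaf_similarity(
    xgb.DMatrix(X[:5]), xgb.DMatrix(X[50:100]), weight_type="gain"
)
print(sim.shape)  # (5, 50); every entry lies in [0, 1]

# A row compared with itself lands in identical leaves in every tree.
self_sim = bst.compute_leaf_similarity(xgb.DMatrix(X[:5]), xgb.DMatrix(X[:5]))
assert np.allclose(np.diag(self_sim), 1.0)
```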
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 0be531d78815..902e3b27e2d6 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -1961,3 +1961,25 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *config,
   *out_features = dmlc::BeginPtr(feature_names_c);
   API_END();
 }
+
+XGB_DLL int XGBoosterGetLeafSimilarityWeights(BoosterHandle handle, char const *config,
+                                              bst_ulong *out_len, float const **out_weights) {
+  API_BEGIN();
+  CHECK_HANDLE();
+  xgboost_CHECK_C_ARG_PTR(config);
+  auto *learner = static_cast<Learner *>(handle);
+  auto jconfig = Json::Load(StringView{config});
+
+  auto weight_type = RequiredArg<String>(jconfig, "weight_type", __func__);
+  auto iteration_begin = RequiredArg<Integer>(jconfig, "iteration_begin", __func__);
+  auto iteration_end = RequiredArg<Integer>(jconfig, "iteration_end", __func__);
+
+  auto &weights = learner->GetThreadLocal().ret_vec_float;
+  learner->CalcLeafSimilarityWeights(weight_type, iteration_begin, iteration_end, &weights);
+
+  xgboost_CHECK_C_ARG_PTR(out_len);
+  xgboost_CHECK_C_ARG_PTR(out_weights);
+  *out_len = weights.size();
+  *out_weights = dmlc::BeginPtr(weights);
+  API_END();
+}
diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc
index 3bd03a3b4a41..1f93f467bed5 100644
--- a/src/gbm/gblinear.cc
+++ b/src/gbm/gblinear.cc
@@ -13,7 +13,7 @@
 #include <limits>
 #include <vector>
 
-#include "../common/error_msg.h" // NoCategorical, DeprecatedFunc
+#include "../common/error_msg.h"  // NoCategorical, DeprecatedFunc
 #include "../common/threading_utils.h"
 #include "../common/timer.h"
 #include "gblinear_model.h"
@@ -35,13 +35,10 @@ struct GBLinearTrainParam : public XGBoostParameter<GBLinearTrainParam> {
   size_t max_row_perbatch;
   DMLC_DECLARE_PARAMETER(GBLinearTrainParam) {
-    DMLC_DECLARE_FIELD(updater)
-        .set_default("shotgun")
-        .describe("Update algorithm for linear model. One of shotgun/coord_descent");
-    DMLC_DECLARE_FIELD(tolerance)
-        .set_lower_bound(0.0f)
-        .set_default(0.0f)
-        .describe("Stop if largest weight update is smaller than this number.");
+    DMLC_DECLARE_FIELD(updater).set_default("shotgun").describe(
+        "Update algorithm for linear model. One of shotgun/coord_descent");
+    DMLC_DECLARE_FIELD(tolerance).set_lower_bound(0.0f).set_default(0.0f).describe(
+        "Stop if largest weight update is smaller than this number.");
     DMLC_DECLARE_FIELD(max_row_perbatch)
         .set_default(std::numeric_limits<size_t>::max())
         .describe("Maximum rows per batch.");
@@ -86,9 +83,7 @@ class GBLinear : public GradientBooster {
     updater_->Configure(cfg);
   }
 
-  int32_t BoostedRounds() const override {
-    return model_.num_boosted_rounds;
-  }
+  int32_t BoostedRounds() const override { return model_.num_boosted_rounds; }
 
   bool ModelFitted() const override { return BoostedRounds() != 0; }
 
@@ -152,7 +147,7 @@ class GBLinear : public GradientBooster {
     monitor_.Stop("PredictBatch");
   }
 
-  void PredictLeaf(DMatrix *, HostDeviceVector<bst_float> *, unsigned, unsigned) override {
+  void PredictLeaf(DMatrix*, HostDeviceVector<bst_float>*, unsigned, unsigned) override {
     LOG(FATAL) << "gblinear does not support prediction of leaf index";
   }
 
@@ -170,7 +165,7 @@ class GBLinear : public GradientBooster {
     std::fill(contribs.begin(), contribs.end(), 0);
     auto base_score = learner_model_param_->BaseScore(ctx_);
     // start collecting the contributions
-    for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
+    for (const auto& batch : p_fmat->GetBatches<SparsePage>()) {
       // parallel over local batch
       const auto nsize = static_cast<bst_omp_uint>(batch.Size());
       auto page = batch.GetView();
@@ -179,7 +174,7 @@ class GBLinear : public GradientBooster {
         auto row_idx = static_cast<size_t>(batch.base_rowid + i);
         // loop over output groups
         for (int gid = 0; gid < ngroup; ++gid) {
-          bst_float *p_contribs = &contribs[(row_idx * ngroup + gid) * ncolumns];
+          bst_float* p_contribs = &contribs[(row_idx * ngroup + gid) * ncolumns];
           // calculate linear terms' contributions
           for (auto& ins : inst) {
             if (ins.index >= model_.learner_model_param->num_feature) continue;
@@ -201,8 +196,8 @@ class GBLinear : public GradientBooster {
     std::vector<bst_float>& contribs = out_contribs->HostVector();
 
     // linear models have no interaction effects
-    const size_t nelements = model_.learner_model_param->num_feature *
-                             model_.learner_model_param->num_feature;
+    const size_t nelements =
+        model_.learner_model_param->num_feature * model_.learner_model_param->num_feature;
     contribs.resize(p_fmat->Info().num_row_ * nelements *
                     model_.learner_model_param->num_output_group);
     std::fill(contribs.begin(), contribs.end(), 0);
@@ -213,10 +208,9 @@ class GBLinear : public GradientBooster {
     return model_.DumpModel(fmap, with_stats, format);
   }
 
-  void FeatureScore(std::string const &importance_type,
-                    common::Span<int32_t const> trees,
-                    std::vector<bst_feature_t> *out_features,
-                    std::vector<float> *out_scores) const override {
+  void FeatureScore(std::string const& importance_type, common::Span<int32_t const> trees,
+                    std::vector<bst_feature_t>* out_features,
+                    std::vector<float>* out_scores) const override {
     CHECK(!model_.weight.empty()) << "Model is not initialized";
     CHECK(trees.empty()) << "gblinear doesn't support number of trees for feature importance.";
     CHECK_EQ(importance_type, "weight")
@@ -237,19 +231,28 @@ class GBLinear : public GradientBooster {
     }
   }
 
+  void LeafSimilarityWeights(std::string const& weight_type, bst_layer_t iteration_begin,
+                             bst_layer_t iteration_end,
+                             std::vector<float>* weights) const override {
+    (void)weight_type;
+    (void)iteration_begin;
+    (void)iteration_end;
+    (void)weights;
+    LOG(FATAL) << "Leaf similarity weights are not defined for gblinear booster.";
+  }
+
  protected:
-  void PredictBatchInternal(DMatrix *p_fmat,
-                            std::vector<bst_float> *out_preds) {
+  void PredictBatchInternal(DMatrix* p_fmat, std::vector<bst_float>* out_preds) {
     monitor_.Start("PredictBatchInternal");
     model_.LazyInitModel();
-    std::vector<bst_float> &preds = *out_preds;
+    std::vector<bst_float>& preds = *out_preds;
     auto base_margin = p_fmat->Info().base_margin_.View(DeviceOrd::CPU());
     // start collecting the prediction
     const int ngroup = model_.learner_model_param->num_output_group;
     preds.resize(p_fmat->Info().num_row_ * ngroup);
     auto base_score = learner_model_param_->BaseScore(DeviceOrd::CPU());
-    for (const auto &page : p_fmat->GetBatches<SparsePage>()) {
+    for (const auto& page : p_fmat->GetBatches<SparsePage>()) {
       auto const& batch = page.GetView();
       // output convention: nrow * k, where nrow is number of rows
       // k is number of group
@@ -279,8 +282,7 @@ class GBLinear : public GradientBooster {
     }
     float largest_dw = 0.0;
     for (size_t i = 0; i < model_.weight.size(); i++) {
-      largest_dw = std::max(
-          largest_dw, std::abs(model_.weight[i] - previous_model_.weight[i]));
+      largest_dw = std::max(largest_dw, std::abs(model_.weight[i] - previous_model_.weight[i]));
     }
     previous_model_ = model_;
 
@@ -288,9 +290,9 @@ class GBLinear : public GradientBooster {
     return is_converged_;
   }
 
-  void LazySumWeights(DMatrix *p_fmat) {
+  void LazySumWeights(DMatrix* p_fmat) {
     if (!sum_weight_complete_) {
-      auto &info = p_fmat->Info();
+      auto& info = p_fmat->Info();
       for (size_t i = 0; i < info.num_row_; i++) {
         sum_instance_weight_ += info.GetWeight(i);
       }
@@ -298,8 +300,7 @@ class GBLinear : public GradientBooster {
     }
   }
 
-  void Pred(const SparsePage::Inst &inst, bst_float *preds, int gid,
-            bst_float base) {
+  void Pred(const SparsePage::Inst& inst, bst_float* preds, int gid, bst_float base) {
     bst_float psum = model_.Bias()[gid] + base;
     for (const auto& ins : inst) {
       if (ins.index >= model_.learner_model_param->num_feature) continue;
diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h
index 2d1e63133f52..5a9daa0d5849 100644
--- a/src/gbm/gbtree.h
+++ b/src/gbm/gbtree.h
@@ -287,6 +287,44 @@ class GBTree : public GradientBooster {
     }
   }
 
+  void LeafSimilarityWeights(std::string const& weight_type, bst_layer_t iteration_begin,
+                             bst_layer_t iteration_end,
+                             std::vector<float>* weights) const override {
+    auto [tree_begin, tree_end] = detail::LayerToTree(model_, iteration_begin, iteration_end);
+    weights->clear();
+    weights->reserve(tree_end - tree_begin);
+
+    auto const get_weight = [&](RegTree const& tree) {
+      CHECK(!tree.IsMultiTarget())
+          << "Leaf similarity weights for multi-target tree " << MTNotImplemented();
+      tree::ScalarTreeView view{&tree};
+
+      if (weight_type == "uniform") {
+        return 1.0f;
+      }
+
+      float weight = 0.0f;
+      for (bst_node_t nidx = 0; nidx < view.Size(); ++nidx) {
+        if (!view.IsLeaf(nidx)) {
+          if (weight_type == "gain") {
+            weight += view.LossChg(nidx);
+          } else if (weight_type == "cover") {
+            weight += view.SumHess(nidx);
+          } else {
+            LOG(FATAL) << "Unknown leaf similarity weight type, expected one of: "
+                       << R"({"uniform", "gain", "cover"}, got: )" << weight_type;
+          }
+        }
+      }
+      return weight;
+    };
+
+    for (bst_tree_t tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
+      auto const& tree = *model_.trees.at(tree_idx);
+      weights->push_back(get_weight(tree));
+    }
+  }
+
   [[nodiscard]] CatContainer const* Cats() const override { return this->model_.Cats(); }
 
   void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_preds, uint32_t layer_begin,
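The weight definition above (sum of `LossChg` or `SumHess` over split nodes, one value per tree in model order) can be cross-checked from Python against the public model dump; the test file added below relies on exactly this identity. A sketch, assuming pandas is installed and using a throwaway random dataset:

```python
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X, y = rng.random((128, 6)), rng.random(128)
bst = xgb.train({"max_depth": 3}, xgb.DMatrix(X, label=y), num_boost_round=5)

# "gain" weight = sum of Gain over split nodes; "cover" = sum of Cover.
df = bst.trees_to_dataframe()
splits = df[df["Feature"] != "Leaf"]
gain_weights = splits.groupby("Tree")["Gain"].sum().to_numpy(dtype=np.float32)
cover_weights = splits.groupby("Tree")["Cover"].sum().to_numpy(dtype=np.float32)
print(gain_weights, cover_weights)  # one entry per tree, in model order
```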
diff --git a/src/learner.cc b/src/learner.cc
index 8e28e78c8857..8ff456de1e19 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -1274,6 +1274,14 @@ class LearnerImpl : public LearnerIO {
     gbm_->FeatureScore(importance_type, trees, features, scores);
   }
 
+  void CalcLeafSimilarityWeights(std::string const& weight_type, bst_layer_t iteration_begin,
+                                 bst_layer_t iteration_end, std::vector<float>* weights) override {
+    this->Configure();
+    this->CheckModelInitialized();
+
+    gbm_->LeafSimilarityWeights(weight_type, iteration_begin, iteration_end, weights);
+  }
+
   const std::map<std::string, std::string>& GetConfigurationArguments() const override {
     return cfg_;
   }
diff --git a/tests/python/test_leaf_similarity.py b/tests/python/test_leaf_similarity.py
new file mode 100644
index 000000000000..f9158a34ce06
--- /dev/null
+++ b/tests/python/test_leaf_similarity.py
@@ -0,0 +1,235 @@
+"""Tests for leaf similarity computation."""
+
+import ctypes
+
+import numpy as np
+import pytest
+import xgboost as xgb
+from sklearn.datasets import load_diabetes, load_iris
+from xgboost import testing as tm
+from xgboost.core import (
+    _LIB,
+    _check_call,
+    c_bst_ulong,
+    ctypes2numpy,
+    from_pystr_to_cstr,
+)
+
+
+class TestLeafSimilarity:
+    """Tests for Booster.compute_leaf_similarity()."""
+
+    @pytest.mark.parametrize(
+        ("param", "num_boost_round"),
+        [
+            ({"max_depth": 4, "eta": 0.3, "objective": "binary:logistic"}, 8),
+            (
+                {
+                    "max_depth": 3,
+                    "eta": 0.3,
+                    "objective": "multi:softprob",
+                    "num_class": 3,
+                },
+                6,
+            ),
+            (
+                {
+                    "max_depth": 4,
+                    "eta": 1.0,
+                    "objective": "binary:logistic",
+                    "num_parallel_tree": 3,
+                },
+                5,
+            ),
+        ],
+    )
+    @pytest.mark.parametrize(
+        ("weight_type", "column"), [("gain", "Gain"), ("cover", "Cover")]
+    )
+    @pytest.mark.skipif(**tm.no_pandas())
+    def test_leaf_similarity_weight_api(
+        self, param: dict, num_boost_round: int, weight_type: str, column: str
+    ) -> None:
+        """Test the low-level tree weight API shape and order contract."""
+        dtrain, _ = tm.load_agaricus(__file__)
+        bst = xgb.train(param, dtrain, num_boost_round=num_boost_round)
+
+        leaves = bst.predict(dtrain, pred_leaf=True, strict_shape=True)
+        expected_len = int(np.prod(leaves.shape[1:]))
+        expected_weights = np.zeros(expected_len, dtype=np.float32)
+        trees_df = bst.trees_to_dataframe()
+        split_nodes = trees_df[trees_df["Feature"] != "Leaf"]
+        tree_weights = split_nodes.groupby("Tree")[column].sum()
+        for tree_id, weight in tree_weights.items():
+            expected_weights[int(tree_id)] = weight
+
+        config = from_pystr_to_cstr(
+            f'{{"weight_type":"{weight_type}","iteration_begin":0,"iteration_end":0}}'
+        )
+        out_len = c_bst_ulong()
+        out_weights = ctypes.POINTER(ctypes.c_float)()
+
+        _check_call(
+            _LIB.XGBoosterGetLeafSimilarityWeights(
+                bst.handle, config, ctypes.byref(out_len), ctypes.byref(out_weights)
+            )
+        )
+
+        assert out_len.value == expected_len
+        weights = ctypes2numpy(out_weights, out_len.value, np.float32)
+        np.testing.assert_allclose(weights, expected_weights, rtol=1e-6, atol=1e-3)
+
+    def test_leaf_similarity(self) -> None:
+        """Test basic leaf similarity computation."""
+        dtrain, _ = tm.load_agaricus(__file__)
+        param = {"max_depth": 4, "eta": 0.3, "objective": "binary:logistic"}
+        bst = xgb.train(param, dtrain, num_boost_round=10)
+
+        X = dtrain.get_data()
+        dm_query = xgb.DMatrix(X[:10])
+        dm_ref = xgb.DMatrix(X[100:150])
+
+        # Test shape and range
+        similarity = bst.compute_leaf_similarity(dm_query, dm_ref)
+        assert similarity.shape == (10, 50)
+        assert similarity.min() >= 0.0
+        assert similarity.max() <= 1.0 + 1e-6
+
+        # Self-similarity diagonal should be 1.0
+        dm_self = xgb.DMatrix(X[:20])
+        self_sim = bst.compute_leaf_similarity(dm_self, dm_self)
+        np.testing.assert_allclose(np.diag(self_sim), 1.0, rtol=1e-5)
+
+        # Test weight types
+        sim_gain = bst.compute_leaf_similarity(dm_query, dm_ref, weight_type="gain")
+        sim_cover = bst.compute_leaf_similarity(dm_query, dm_ref, weight_type="cover")
+        assert sim_gain.shape == sim_cover.shape
+
+        # Default should be uniform
+        sim_uniform = bst.compute_leaf_similarity(
+            dm_query, dm_ref, weight_type="uniform"
+        )
+        sim_default = bst.compute_leaf_similarity(dm_query, dm_ref)
+        np.testing.assert_array_equal(sim_default, sim_uniform)
+
+        # Invalid weight_type
+        with pytest.raises(ValueError, match="weight_type must be"):
+            bst.compute_leaf_similarity(dm_query, dm_ref, weight_type="invalid")
+
+    @pytest.mark.parametrize(
+        "param",
+        [
+            {
+                "max_depth": 3,
+                "eta": 0.3,
+                "objective": "multi:softprob",
+                "num_class": 3,
+            },
+            {
+                "booster": "dart",
+                "max_depth": 4,
+                "eta": 0.3,
+                "objective": "binary:logistic",
+            },
+            {
+                "max_depth": 4,
+                "eta": 1.0,
+                "objective": "binary:logistic",
+                "num_parallel_tree": 3,
+            },
+        ],
+    )
+    @pytest.mark.parametrize("weight_type", ["uniform", "gain", "cover"])
+    def test_leaf_similarity_supported_tree_modes(
+        self, param: dict, weight_type: str
+    ) -> None:
+        """Test supported tree model modes."""
+        if param.get("objective") == "multi:softprob":
+            X, y = load_iris(return_X_y=True)
+            dtrain = xgb.DMatrix(X, label=y)
+            dm_query = xgb.DMatrix(X[:5])
+            dm_ref = xgb.DMatrix(X[10:20])
+            rounds = 8
+        else:
+            dtrain, _ = tm.load_agaricus(__file__)
+            X = dtrain.get_data()
+            dm_query = xgb.DMatrix(X[:10])
+            dm_ref = xgb.DMatrix(X[100:130])
+            rounds = 8 if param.get("booster") == "dart" else 5
+
+        bst = xgb.train(param, dtrain, num_boost_round=rounds)
+        similarity = bst.compute_leaf_similarity(
+            dm_query, dm_ref, weight_type=weight_type
+        )
+        assert similarity.shape == (dm_query.num_row(), dm_ref.num_row())
+        assert similarity.min() >= 0.0
+        assert similarity.max() <= 1.0 + 1e-6
+
+    @pytest.mark.parametrize("weight_type", ["uniform", "gain", "cover"])
+    def test_leaf_similarity_one_output_per_tree_multi_target(
+        self, weight_type: str
+    ) -> None:
+        """Test multi-target model with one output per tree."""
+        X, y = load_diabetes(return_X_y=True)
+        y = np.column_stack([y, y * 0.5])
+        dtrain = xgb.DMatrix(X, label=y)
+        bst = xgb.train(
+            {
+                "max_depth": 3,
+                "eta": 0.3,
+                "tree_method": "hist",
+                "objective": "reg:squarederror",
+                "multi_strategy": "one_output_per_tree",
+                "num_target": 2,
+            },
+            dtrain,
+            num_boost_round=6,
+        )
+
+        similarity = bst.compute_leaf_similarity(
+            xgb.DMatrix(X[:5]), xgb.DMatrix(X[10:20]), weight_type=weight_type
+        )
+        assert similarity.shape == (5, 10)
+        assert similarity.min() >= 0.0
+        assert similarity.max() <= 1.0 + 1e-6
+
+    @pytest.mark.parametrize("weight_type", ["uniform", "gain", "cover"])
+    def test_leaf_similarity_gblinear_error(self, weight_type: str) -> None:
+        """Test unsupported gblinear booster with stable error."""
+        dtrain, _ = tm.load_agaricus(__file__)
+        bst = xgb.train({"booster": "gblinear", "objective": "binary:logistic"}, dtrain)
+        X = dtrain.get_data()
+
+        with pytest.raises(
+            xgb.core.XGBoostError, match="Leaf similarity is only defined"
+        ):
+            bst.compute_leaf_similarity(
+                xgb.DMatrix(X[:5]), xgb.DMatrix(X[10:20]), weight_type=weight_type
+            )
+
+    @pytest.mark.parametrize("weight_type", ["uniform", "gain", "cover"])
+    def test_leaf_similarity_multi_output_tree_error(self, weight_type: str) -> None:
+        """Test unsupported multi-output tree with stable error."""
+        X, y = load_diabetes(return_X_y=True)
+        y = np.column_stack([y, y * 0.5])
+        dtrain = xgb.DMatrix(X, label=y)
+        bst = xgb.train(
+            {
+                "max_depth": 3,
+                "eta": 0.3,
+                "tree_method": "hist",
+                "objective": "reg:squarederror",
"multi_strategy": "multi_output_tree", + "num_target": 2, + }, + dtrain, + num_boost_round=6, + ) + + with pytest.raises( + xgb.core.XGBoostError, + match="Leaf similarity does not support multi_output_tree", + ): + bst.compute_leaf_similarity( + xgb.DMatrix(X[:5]), xgb.DMatrix(X[10:20]), weight_type=weight_type + )