Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,3 @@
path = dmlc-core
url = https://github.com/dmlc/dmlc-core
branch = main
[submodule "gputreeshap"]
path = gputreeshap
url = https://github.com/rapidsai/gputreeshap.git
4 changes: 0 additions & 4 deletions cmake/Utils.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,6 @@ function(xgboost_set_cuda_flags target)
target_link_libraries(${target} PRIVATE CCCL::CCCL CUDA::cudart_static)
endif()
target_compile_definitions(${target} PRIVATE -DXGBOOST_USE_CUDA=1)
target_include_directories(
${target} PRIVATE
${xgboost_SOURCE_DIR}/gputreeshap)

if(MSVC)
xgboost_cuda_wrap_host_compiler_options(cuda_utf8_flags /utf-8)
target_compile_options(${target} PRIVATE ${cuda_utf8_flags})
Expand Down
4 changes: 2 additions & 2 deletions doc/contrib/featuremap.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ XGBoost includes features designed to improve understanding of the model. Here's
- Tree visualization.
- Tree as dataframe.

For GPU support, the SHAP value uses the `GPUTreeShap <https://github.com/rapidsai/gputreeshap/tree/main>`_ project in rapidsai. They all support categorical features, while vector-leaf is still in progress.
For GPU support, the SHAP value uses XGBoost's in-tree ``QuadratureTreeSHAP`` implementation. It supports categorical features, while vector-leaf support is still in progress.

----------
Evaluation
Expand All @@ -66,4 +66,4 @@ Inference normally doesn't require any special treatment since we are using samp
*****************
Language Bindings
*****************
We have a list of bindings for various languages. Inside the XGBoost repository, there's Python, R, Java, Scala, and C. All language bindings are built on top of the C version. Some others, like Julia and Rust, have their own repository. For guideline on adding a new binding, please see :doc:`/contrib/consistency`.
We have a list of bindings for various languages. Inside the XGBoost repository, there's Python, R, Java, Scala, and C. All language bindings are built on top of the C version. Some others, like Julia and Rust, have their own repository. For guidance on adding a new binding, please see :doc:`/contrib/consistency`.
2 changes: 1 addition & 1 deletion doc/gpu/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ The GPU algorithms currently work with CLI, Python, R, and JVM packages. See :do
GPU-Accelerated SHAP values
=============================
XGBoost makes use of `GPUTreeShap <https://github.com/rapidsai/gputreeshap>`_ as a backend for computing shap values when the GPU is used.
XGBoost provides an in-tree GPU implementation of ``QuadratureTreeSHAP`` for computing SHAP values when the GPU is used.

.. code-block:: python
Expand Down
1 change: 0 additions & 1 deletion gputreeshap
Submodule gputreeshap deleted from d28571
1 change: 0 additions & 1 deletion python-package/packager/sdist.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ def copy_cpp_src_tree(
"src",
"include",
"dmlc-core",
"gputreeshap",
"cmake",
"plugin",
]:
Expand Down
3 changes: 2 additions & 1 deletion python-package/xgboost/testing/ordinal.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,8 @@ def _run_predt(
pred_interactions=pred_interactions,
pred_leaf=pred_leaf,
)
assert_allclose(device, predt_0, predt_1)
atol = 1e-6 if pred_contribs or pred_interactions else 0
assert_allclose(device, predt_0, predt_1, atol=atol)


def run_cat_shap(device: Device) -> None:
Expand Down
116 changes: 82 additions & 34 deletions src/predictor/interpretability/quadrature.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,42 @@
#define XGBOOST_PREDICTOR_INTERPRETABILITY_QUADRATURE_H_

#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

#include "xgboost/base.h"
#include "xgboost/logging.h"
#include "xgboost/tree_model.h"

namespace xgboost::interpretability::detail {

// Pi to full double precision; used for the Gauss-Legendre root initial guesses.
constexpr double kPi = 3.141592653589793238462643383279502884;
// Fixed number of quadrature points shared by the CPU and GPU QuadratureTreeSHAP paths.
constexpr std::size_t kQuadratureTreeShapPoints = 8;
// Sentinel value; presumably marks an unseen/missing feature along a path -- usage is
// outside this view, confirm against the SHAP kernels.
constexpr float kQuadratureTreeShapUnseen = -999.0f;
// Floor applied to zero-cover branch probabilities so every path stays reachable.
constexpr float kQuadratureTreeShapMinChildWeight = 1e-12f;

template <std::size_t MaxPoints>
struct EndpointQuadratureRule {
std::size_t points{0};
std::array<double, MaxPoints> nodes{};
std::array<double, MaxPoints> weights{};
// Branch probability of a child given its parent, expressed as a cover ratio.
//
// A well-formed tree has positive cover at every split and each child's cover is a valid
// fraction of the parent's. Degenerate covers can still appear after model refresh or
// when loading externally produced models, so instead of letting NaN/Inf leak into the
// SHAP computation we fall back to a neutral 0.5 probability when the parent cover is
// non-positive, and clamp a zero-cover child to a tiny positive floor so its path stays
// reachable. Ordinary nonzero cover ratios pass through unchanged.
XGBOOST_DEVICE inline float BranchWeight(float cover, float parent_cover) {
  if (parent_cover <= 0.0f) {
    return 0.5f;
  }
  float const ratio = cover / parent_cover;
  return (ratio < kQuadratureTreeShapMinChildWeight) ? kQuadratureTreeShapMinChildWeight
                                                     : ratio;
}

// Fixed-size quadrature rule: `nodes` are the evaluation points and `weights` the
// matching quadrature weights, each holding exactly kQuadratureTreeShapPoints entries.
// Plain C arrays keep the type usable from both host and device code.
struct QuadratureRule {
  float nodes[kQuadratureTreeShapPoints];
  float weights[kQuadratureTreeShapPoints];
};

inline double LegendrePolynomial(std::size_t n, double x) {
Expand All @@ -48,50 +67,79 @@ inline double LegendreDerivative(std::size_t n, double x, double pn) {
return n_d * (x * pn - LegendrePolynomial(n - 1, x)) / (x * x - 1.0);
}

template <std::size_t MaxPoints>
inline EndpointQuadratureRule<MaxPoints> MakeEndpointQuadrature(std::size_t n,
double convergence_eps) {
CHECK_GE(n, 2);
CHECK_LE(n, MaxPoints);
inline QuadratureRule MakeEndpointQuadrature() {
constexpr std::size_t kN = kQuadratureTreeShapPoints;
constexpr double kConvergenceEps = 1e-15;
QuadratureRule rule;

EndpointQuadratureRule<MaxPoints> rule;
rule.points = n;
std::vector<std::pair<double, double>> nodes_weights;
nodes_weights.reserve(n);

for (std::size_t i = 0; i < n; ++i) {
double theta = kPi * (static_cast<double>(i) + 0.75) / (static_cast<double>(n) + 0.5);
for (std::size_t i = 0; i < kN; ++i) {
double theta = kPi * (static_cast<double>(i) + 0.75) / (static_cast<double>(kN) + 0.5);
double x = std::cos(theta);
for (std::size_t iter = 0; iter < 64; ++iter) {
auto pn = LegendrePolynomial(n, x);
auto dpn = LegendreDerivative(n, x, pn);
auto pn = LegendrePolynomial(kN, x);
auto dpn = LegendreDerivative(kN, x, pn);
auto dx = pn / dpn;
x -= dx;
if (std::abs(dx) < convergence_eps) {
if (std::abs(dx) < kConvergenceEps) {
break;
}
}

auto pn = LegendrePolynomial(n, x);
auto dpn = LegendreDerivative(n, x, pn);
auto pn = LegendrePolynomial(kN, x);
auto dpn = LegendreDerivative(kN, x, pn);
auto w = 2.0 / ((1.0 - x * x) * dpn * dpn);
double s = 0.5 * (x + 1.0);
double ws = 0.5 * w;
nodes_weights.emplace_back(s * s, 2.0 * s * ws);
auto out_idx = kN - 1 - i;
rule.nodes[out_idx] = static_cast<float>(s * s);
rule.weights[out_idx] = static_cast<float>(2.0 * s * ws);
}
return rule;
}

std::sort(nodes_weights.begin(), nodes_weights.end(),
[](auto const &l, auto const &r) { return l.first < r.first; });
for (std::size_t i = 0; i < n; ++i) {
rule.nodes[i] = nodes_weights[i].first;
rule.weights[i] = nodes_weights[i].second;
// Return the process-wide quadrature rule, built once on first use.
inline QuadratureRule const& GetQuadratureRule() {
  // Function-local static: constructed exactly once, thread-safe since C++11.
  static QuadratureRule const kSharedRule = MakeEndpointQuadrature();
  return kSharedRule;
}

// Recursively compute the cover-weighted mean leaf value of the subtree rooted at
// `nidx` -- the expected prediction of the tree, used as the SHAP bias term.
//
// Covers must be non-negative; negative cover makes the weighted mean meaningless.
// A zero-cover split (possible after model refresh or with externally produced
// models) falls back to the unweighted average of its children so no division by
// zero occurs.
template <typename Tree>
double FillRootMeanValue(Tree const& tree, bst_node_t nidx) {
  if (tree.IsLeaf(nidx)) {
    return tree.LeafValue(nidx);
  }
  auto left = tree.LeftChild(nidx);
  auto right = tree.RightChild(nidx);
  CHECK_GE(tree.SumHess(nidx), 0.0f)
      << "QuadratureTreeSHAP is undefined for trees with negative cover at split nodes.";
  CHECK_GE(tree.SumHess(left), 0.0f)
      << "QuadratureTreeSHAP is undefined for trees with negative child cover.";
  CHECK_GE(tree.SumHess(right), 0.0f)
      << "QuadratureTreeSHAP is undefined for trees with negative child cover.";
  auto const parent_cover = tree.SumHess(nidx);
  auto const left_mean = FillRootMeanValue(tree, left);
  auto const right_mean = FillRootMeanValue(tree, right);
  if (parent_cover == 0.0f) {
    // Zero-cover split: neither child carries weight, use the plain average.
    return 0.5 * (left_mean + right_mean);
  }
  return (left_mean * tree.SumHess(left) + right_mean * tree.SumHess(right)) / parent_cover;
}

template <std::size_t Points>
inline EndpointQuadratureRule<Points> MakeEndpointQuadrature(double convergence_eps) {
return MakeEndpointQuadrature<Points>(Points, convergence_eps);
template <typename TreeGroups, typename GetTree>
std::vector<float> MakeGroupRootMeanSums(TreeGroups const& tree_groups, bst_target_t n_groups,
bst_tree_t tree_end,
std::vector<float> const* tree_weights,
GetTree&& get_tree) {
std::vector<double> group_root_mean_sums(n_groups, 0.0);
for (bst_tree_t tree_idx = 0; tree_idx < tree_end; ++tree_idx) {
auto const weight = tree_weights == nullptr ? 1.0f : (*tree_weights)[tree_idx];
group_root_mean_sums[tree_groups[tree_idx]] +=
FillRootMeanValue(get_tree(tree_idx), RegTree::kRoot) * weight;
}

std::vector<float> out(group_root_mean_sums.size());
std::transform(group_root_mean_sums.cbegin(), group_root_mean_sums.cend(), out.begin(),
[](double v) { return static_cast<float>(v); });
return out;
}

} // namespace xgboost::interpretability::detail
Expand Down
71 changes: 11 additions & 60 deletions src/predictor/interpretability/shap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,37 +58,6 @@ void FillNodeMeanValues(tree::ScalarTreeView const &tree, std::vector<float> *me
FillNodeMeanValues(tree, 0, mean_values);
}

// Cover-weighted mean leaf value of the subtree rooted at `nidx`.
// NOTE(review): this is removed code in the diff (superseded by the templated header
// version); it divides by SumHess(nidx) with no zero guard.
double FillRootMeanValue(tree::ScalarTreeView const &tree, bst_node_t nidx) {
  if (tree.IsLeaf(nidx)) {
    return tree.LeafValue(nidx);
  }
  auto left = tree.LeftChild(nidx);
  auto right = tree.RightChild(nidx);
  // Weight each child's mean by its cover, then normalize by the parent cover.
  double result = FillRootMeanValue(tree, left) * tree.SumHess(left);
  result += FillRootMeanValue(tree, right) * tree.SumHess(right);
  result /= tree.SumHess(nidx);
  return result;
}

// Recursively assert that every split node and both of its children carry strictly
// positive cover, which the quadrature recurrence previously required.
// NOTE(review): removed in this diff -- replaced by the tolerant zero-cover handling
// in detail::BranchWeight / the header FillRootMeanValue.
void ValidateQuadratureTreeShapCovers(tree::ScalarTreeView const &tree, bst_node_t nidx) {
  if (tree.IsLeaf(nidx)) {
    return;
  }

  CHECK_GT(tree.SumHess(nidx), 0.0f)
      << "CPU QuadratureTreeSHAP is undefined for trees with non-positive cover at split nodes.";

  auto left = tree.LeftChild(nidx);
  auto right = tree.RightChild(nidx);
  CHECK_GT(tree.SumHess(left), 0.0f)
      << "CPU QuadratureTreeSHAP is undefined for trees with non-positive cover at child nodes.";
  CHECK_GT(tree.SumHess(right), 0.0f)
      << "CPU QuadratureTreeSHAP is undefined for trees with non-positive cover at child nodes.";

  ValidateQuadratureTreeShapCovers(tree, left);
  ValidateQuadratureTreeShapCovers(tree, right);
}

void CalculateApproxContributions(tree::ScalarTreeView const &tree, RegTree::FVec const &feats,
std::vector<float> *mean_values,
std::vector<bst_float> *out_contribs) {
Expand Down Expand Up @@ -116,30 +85,12 @@ void CalculateApproxContributions(tree::ScalarTreeView const &tree, RegTree::FVe

// Keep the CPU quadrature recurrence on the same fixed 8-point rule as the GPU path so the hot
// loops stay small and the compiler can fully unroll the basis update and extraction work.
// Mirror the shared detail:: constants locally so the surrounding CPU code keeps its
// short names.
constexpr std::size_t kQuadratureTreeShapPoints = detail::kQuadratureTreeShapPoints;
constexpr float kQuadratureTreeShapUnseen = detail::kQuadratureTreeShapUnseen;

// The rule type now lives in the shared header so the CPU and GPU paths agree.
using QuadratureRule = detail::QuadratureRule;
using QuadratureBuffer = std::array<float, kQuadratureTreeShapPoints>;

// Build the float-precision CPU rule once from the double-precision generator.
// NOTE(review): removed in this diff in favor of detail::GetQuadratureRule();
// `kQuadratureTreeShapBuildQeps` is defined earlier in this file.
QuadratureRule const &GetQuadratureRule() {
  static QuadratureRule const kRule = [] {
    auto const rule_d =
        detail::MakeEndpointQuadrature<kQuadratureTreeShapPoints>(kQuadratureTreeShapBuildQeps);
    QuadratureRule out;
    // Narrow the double-precision nodes/weights to float for the CPU hot loops.
    for (std::size_t i = 0; i < kQuadratureTreeShapPoints; ++i) {
      out.nodes[i] = static_cast<float>(rule_d.nodes[i]);
      out.weights[i] = static_cast<float>(rule_d.weights[i]);
    }
    return out;
  }();
  return kRule;
}

void AddInPlace(QuadratureBuffer *lhs, QuadratureBuffer const &rhs) {
for (std::size_t i = 0; i < kQuadratureTreeShapPoints; ++i) {
(*lhs)[i] += rhs[i];
Expand Down Expand Up @@ -395,8 +346,9 @@ struct QuadratureTreeShapRunner {

// Branch probability of `child` under `parent`, delegated to the shared helper.
//
// Negative cover would make the probability meaningless, so it is rejected here;
// zero cover is tolerated and handled by detail::BranchWeight's fallbacks
// (neutral 0.5 for a zero parent, tiny positive floor for a zero child).
[[nodiscard]] float ChildWeight(bst_node_t parent, bst_node_t child) const {
  auto parent_cover = tree.Stat(parent).sum_hess;
  CHECK_GE(parent_cover, 0.0f);
  CHECK_GE(tree.Stat(child).sum_hess, 0.0f);
  return detail::BranchWeight(tree.Stat(child).sum_hess, parent_cover);
}

void VisitChild(bst_node_t split_node, bst_node_t child_node, float child_weight, bool satisfies,
Expand Down Expand Up @@ -485,7 +437,6 @@ QuadratureTreeShapModelData MakeQuadratureTreeShapModelData(
out.trees.reserve(n_trees);
out.trees_by_group.resize(n_groups);
out.weights.resize(n_trees, 1.0f);
out.group_root_mean_sums.resize(n_groups, 0.0f);

for (std::size_t i = 0; i < n_trees; ++i) {
out.trees.emplace_back(model.trees[i].get());
Expand All @@ -495,10 +446,10 @@ QuadratureTreeShapModelData MakeQuadratureTreeShapModelData(
auto weight = tree_weights == nullptr ? 1.0f : (*tree_weights)[i];
out.trees_by_group[gid].push_back(i);
out.weights[i] = weight;
ValidateQuadratureTreeShapCovers(out.trees[i], RegTree::kRoot);
out.group_root_mean_sums[gid] +=
static_cast<float>(FillRootMeanValue(out.trees[i], RegTree::kRoot) * weight);
}
out.group_root_mean_sums = detail::MakeGroupRootMeanSums(
h_tree_groups, n_groups, tree_end, tree_weights,
[&](bst_tree_t tree_idx) -> tree::ScalarTreeView const & { return out.trees[tree_idx]; });
return out;
}

Expand Down Expand Up @@ -576,7 +527,7 @@ void QuadratureTreeShapValues(Context const *ctx, DMatrix *p_fmat,
contribs.resize(info.num_row_ * ncolumns * model.learner_model_param->num_output_group);
std::fill(contribs.begin(), contribs.end(), 0.0f);
CHECK_NE(n_groups, 0);
auto const &rule = GetQuadratureRule();
auto const &rule = detail::GetQuadratureRule();
auto const base_score = model.learner_model_param->BaseScore(DeviceOrd::CPU());
auto model_data = MakeQuadratureTreeShapModelData(model, tree_end, tree_weights);
std::vector<RegTree::FVec> feats_tloc(n_threads);
Expand Down Expand Up @@ -656,7 +607,7 @@ void QuadratureTreeShapInteractionValues(Context const *ctx, DMatrix *p_fmat,
contribs.resize(info.num_row_ * row_chunk);
std::fill(contribs.begin(), contribs.end(), 0.0f);

auto const &rule = GetQuadratureRule();
auto const &rule = detail::GetQuadratureRule();
auto const base_score = model.learner_model_param->BaseScore(DeviceOrd::CPU());
auto model_data = MakeQuadratureTreeShapModelData(model, tree_end, tree_weights);
std::vector<RegTree::FVec> feats_tloc(n_threads);
Expand Down
Loading
Loading