From 28112160c626b84662f99af59ad00ac1190780cd Mon Sep 17 00:00:00 2001 From: Muhammad Ali Date: Mon, 4 May 2026 00:58:00 +0500 Subject: [PATCH 1/2] [doc] Add R code tabs to custom_metric_obj tutorial Port Python code blocks to multi-language tab format for the Custom Objective and Evaluation Metric tutorial. Adds equivalent R code alongside existing Python examples. Contributes to #11413 --- doc/tutorials/custom_metric_obj.rst | 376 +++++++++++++++++++--------- 1 file changed, 257 insertions(+), 119 deletions(-) diff --git a/doc/tutorials/custom_metric_obj.rst b/doc/tutorials/custom_metric_obj.rst index 08bf99b328b9..8cc536700971 100644 --- a/doc/tutorials/custom_metric_obj.rst +++ b/doc/tutorials/custom_metric_obj.rst @@ -63,32 +63,47 @@ information, both first and second order gradient, based on model predictions an data labels (or targets). Therefore, a valid objective function should accept two inputs, namely prediction and labels. For implementing ``SLE``, we define: -.. code-block:: python - - import numpy as np - import xgboost as xgb - from typing import Tuple - - def gradient(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray: - '''Compute the gradient squared log error.''' - y = dtrain.get_label() - return (np.log1p(predt) - np.log1p(y)) / (predt + 1) - - def hessian(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray: - '''Compute the hessian for squared log error.''' - y = dtrain.get_label() - return ((-np.log1p(predt) + np.log1p(y) + 1) / - np.power(predt + 1, 2)) - - def squared_log(predt: np.ndarray, - dtrain: xgb.DMatrix) -> Tuple[np.ndarray, np.ndarray]: - '''Squared Log Error objective. A simplified version for RMSLE used as - objective function. - ''' - predt[predt < -1] = -1 + 1e-6 - grad = gradient(predt, dtrain) - hess = hessian(predt, dtrain) - return grad, hess +.. tabs:: + .. code-tab:: py + + import numpy as np + import xgboost as xgb + from typing import Tuple + + def gradient(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray: + '''Compute the gradient squared log error.''' + y = dtrain.get_label() + return (np.log1p(predt) - np.log1p(y)) / (predt + 1) + + def hessian(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray: + '''Compute the hessian for squared log error.''' + y = dtrain.get_label() + return ((-np.log1p(predt) + np.log1p(y) + 1) / + np.power(predt + 1, 2)) + + def squared_log(predt: np.ndarray, + dtrain: xgb.DMatrix) -> Tuple[np.ndarray, np.ndarray]: + '''Squared Log Error objective. A simplified version for RMSLE used as + objective function. + ''' + predt[predt < -1] = -1 + 1e-6 + grad = gradient(predt, dtrain) + hess = hessian(predt, dtrain) + return grad, hess + + .. code-tab:: r R + + library(xgboost) + + squared_log <- function(preds, dtrain) { + labels <- getinfo(dtrain, "label") + preds <- pmax(preds, -1 + 1e-6) + # Gradient + grad <- (log1p(preds) - log1p(labels)) / (preds + 1) + # Hessian + hess <- (-log1p(preds) + log1p(labels) + 1) / (preds + 1)^2 + return(list(grad = grad, hess = hess)) + } In the above code snippet, ``squared_log`` is the objective function we want. It accepts a @@ -97,12 +112,22 @@ information, including labels and weights (not used here). This objective is th a callback function for XGBoost during training by passing it as an argument to ``xgb.train``: -.. code-block:: python +.. tabs:: + .. code-tab:: py - xgb.train({'tree_method': 'hist', 'seed': 1994}, # any other tree method is fine. - dtrain=dtrain, - num_boost_round=10, - obj=squared_log) + xgb.train({'tree_method': 'hist', 'seed': 1994}, # any other tree method is fine. + dtrain=dtrain, + num_boost_round=10, + obj=squared_log) + + .. code-tab:: r R + + model <- xgb.train( + params = list(tree_method = "hist", seed = 1994), + data = dtrain, + nrounds = 10, + obj = squared_log + ) Notice that in our definition of the objective, whether we subtract the labels from the prediction or the other way around is important. If you find the training error goes up @@ -117,14 +142,25 @@ So after having a customized objective, we might also need a corresponding metri monitor our model's performance. As mentioned above, the default metric for ``SLE`` is ``RMSLE``. Similarly we define another callback like function as the new metric: -.. code-block:: python +.. tabs:: + .. code-tab:: py + + def rmsle(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]: + ''' Root mean squared log error metric.''' + y = dtrain.get_label() + predt[predt < -1] = -1 + 1e-6 + elements = np.power(np.log1p(y) - np.log1p(predt), 2) + return 'PyRMSLE', float(np.sqrt(np.sum(elements) / len(y))) - def rmsle(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]: - ''' Root mean squared log error metric.''' - y = dtrain.get_label() - predt[predt < -1] = -1 + 1e-6 - elements = np.power(np.log1p(y) - np.log1p(predt), 2) - return 'PyRMSLE', float(np.sqrt(np.sum(elements) / len(y))) + .. code-tab:: r R + + rmsle <- function(preds, dtrain) { + labels <- getinfo(dtrain, "label") + preds <- pmax(preds, -1 + 1e-6) + elements <- (log1p(labels) - log1p(preds))^2 + err <- sqrt(sum(elements) / length(labels)) + return(list(metric = "RRMSLE", value = err)) + } Since we are demonstrating in Python, the metric or objective need not be a function, any callable object should suffice. Similar to the objective function, our metric also @@ -132,16 +168,31 @@ accepts ``predt`` and ``dtrain`` as inputs, but returns the name of the metric i a floating point value as the result. After passing it into XGBoost as argument of ``custom_metric`` parameter: -.. code-block:: python - - xgb.train({'tree_method': 'hist', 'seed': 1994, - 'disable_default_eval_metric': 1}, - dtrain=dtrain, - num_boost_round=10, - obj=squared_log, - custom_metric=rmsle, - evals=[(dtrain, 'dtrain'), (dtest, 'dtest')], - evals_result=results) +.. tabs:: + .. code-tab:: py + + xgb.train({'tree_method': 'hist', 'seed': 1994, + 'disable_default_eval_metric': 1}, + dtrain=dtrain, + num_boost_round=10, + obj=squared_log, + custom_metric=rmsle, + evals=[(dtrain, 'dtrain'), (dtest, 'dtest')], + evals_result=results) + + .. code-tab:: r R + + results <- list() + model <- xgb.train( + params = list(tree_method = "hist", seed = 1994, + disable_default_eval_metric = TRUE), + data = dtrain, + nrounds = 10, + obj = squared_log, + feval = rmsle, + evals = list(dtrain = dtrain, dtest = dtest), + evals_result = results + ) We will be able to see XGBoost printing something like: @@ -182,94 +233,175 @@ metric functions implementing the same underlying metric for comparison, `merror_with_transform` is used when custom objective is also used, otherwise the simpler `merror` is preferred since XGBoost can perform the transformation itself. -.. code-block:: python - - import xgboost as xgb - import numpy as np - - def merror_with_transform(predt: np.ndarray, dtrain: xgb.DMatrix): - """Used when custom objective is supplied.""" - y = dtrain.get_label() - n_classes = predt.size // y.shape[0] - # Like custom objective, the predt is untransformed leaf weight when custom objective - # is provided. - - # With the use of `custom_metric` parameter in train function, custom metric receives - # raw input only when custom objective is also being used. Otherwise custom metric - # will receive transformed prediction. - assert predt.shape == (d_train.num_row(), n_classes) - out = np.zeros(dtrain.num_row()) - for r in range(predt.shape[0]): - i = np.argmax(predt[r]) - out[r] = i - - assert y.shape == out.shape - - errors = np.zeros(dtrain.num_row()) - errors[y != out] = 1.0 - return 'PyMError', np.sum(errors) / dtrain.num_row() +.. tabs:: + .. code-tab:: py + + import xgboost as xgb + import numpy as np + + def merror_with_transform(predt: np.ndarray, dtrain: xgb.DMatrix): + """Used when custom objective is supplied.""" + y = dtrain.get_label() + n_classes = predt.size // y.shape[0] + # Like custom objective, the predt is untransformed leaf weight when custom objective + # is provided. + + # With the use of `custom_metric` parameter in train function, custom metric receives + # raw input only when custom objective is also being used. Otherwise custom metric + # will receive transformed prediction. + assert predt.shape == (d_train.num_row(), n_classes) + out = np.zeros(dtrain.num_row()) + for r in range(predt.shape[0]): + i = np.argmax(predt[r]) + out[r] = i + + assert y.shape == out.shape + + errors = np.zeros(dtrain.num_row()) + errors[y != out] = 1.0 + return 'PyMError', np.sum(errors) / dtrain.num_row() + + .. code-tab:: r R + + library(xgboost) + + merror_with_transform <- function(preds, dtrain) { + # Used when custom objective is supplied. + # Predictions are raw (untransformed) when custom objective is provided. + labels <- getinfo(dtrain, "label") + n_samples <- length(labels) + n_classes <- length(preds) / n_samples + # Reshape predictions into matrix (n_samples x n_classes) + pred_matrix <- matrix(preds, nrow = n_samples, ncol = n_classes, byrow = TRUE) + # Get predicted class (0-indexed to match labels) + out <- max.col(pred_matrix) - 1 + err <- sum(labels != out) / n_samples + return(list(metric = "RMError", value = err)) + } The above function is only needed when we want to use custom objective and XGBoost doesn't know how to transform the prediction. The normal implementation for multi-class error function is: -.. code-block:: python - - def merror(predt: np.ndarray, dtrain: xgb.DMatrix): - """Used when there's no custom objective.""" - # No need to do transform, XGBoost handles it internally. - errors = np.zeros(dtrain.num_row()) - errors[y != out] = 1.0 - return 'PyMError', np.sum(errors) / dtrain.num_row() +.. tabs:: + .. code-tab:: py + def merror(predt: np.ndarray, dtrain: xgb.DMatrix): + """Used when there's no custom objective.""" + # No need to do transform, XGBoost handles it internally. + errors = np.zeros(dtrain.num_row()) + errors[y != out] = 1.0 + return 'PyMError', np.sum(errors) / dtrain.num_row() -Next we need the custom softprob objective: + .. code-tab:: r R -.. code-block:: python + merror <- function(preds, dtrain) { + # Used when there's no custom objective. + # No need to transform, XGBoost handles it internally. + labels <- getinfo(dtrain, "label") + err <- sum(labels != preds) / length(labels) + return(list(metric = "RMError", value = err)) + } - def softprob_obj(predt: np.ndarray, data: xgb.DMatrix): - """Loss function. Computing the gradient and approximated hessian (diagonal). - Reimplements the `multi:softprob` inside XGBoost. - """ - # Full implementation is available in the Python demo script linked below - ... +Next we need the custom softprob objective: - return grad, hess +.. tabs:: + .. code-tab:: py + + def softprob_obj(predt: np.ndarray, data: xgb.DMatrix): + """Loss function. Computing the gradient and approximated hessian (diagonal). + Reimplements the `multi:softprob` inside XGBoost. + """ + + # Full implementation is available in the Python demo script linked below + ... + + return grad, hess + + .. code-tab:: r R + + softprob_obj <- function(preds, dtrain) { + # Loss function. Computing the gradient and approximated hessian (diagonal). + # Reimplements the `multi:softprob` inside XGBoost. + labels <- getinfo(dtrain, "label") + n_samples <- length(labels) + n_classes <- length(preds) / n_samples + # Reshape predictions + pred_matrix <- matrix(preds, nrow = n_samples, ncol = n_classes, byrow = TRUE) + # Softmax transform + pred_matrix <- exp(pred_matrix) + pred_matrix <- pred_matrix / rowSums(pred_matrix) + # Gradient and hessian + grad <- pred_matrix + for (i in seq_len(n_samples)) { + grad[i, labels[i] + 1] <- grad[i, labels[i] + 1] - 1 + } + hess <- pmax(2 * pred_matrix * (1 - pred_matrix), 1e-6) + return(list(grad = as.vector(t(grad)), hess = as.vector(t(hess)))) + } Lastly we can train the model using ``obj`` and ``custom_metric`` parameters: -.. code-block:: python - - Xy = xgb.DMatrix(X, y) - booster = xgb.train( - {"num_class": kClasses, "disable_default_eval_metric": True}, - m, - num_boost_round=kRounds, - obj=softprob_obj, - custom_metric=merror_with_transform, - evals_result=custom_results, - evals=[(m, "train")], - ) +.. tabs:: + .. code-tab:: py + + Xy = xgb.DMatrix(X, y) + booster = xgb.train( + {"num_class": kClasses, "disable_default_eval_metric": True}, + m, + num_boost_round=kRounds, + obj=softprob_obj, + custom_metric=merror_with_transform, + evals_result=custom_results, + evals=[(m, "train")], + ) + + .. code-tab:: r R + + dtrain <- xgb.DMatrix(data = X, label = y) + model <- xgb.train( + params = list(num_class = kClasses, + disable_default_eval_metric = TRUE), + data = dtrain, + nrounds = kRounds, + obj = softprob_obj, + feval = merror_with_transform, + evals = list(train = dtrain) + ) Or if you don't need the custom objective and just want to supply a metric that's not available in XGBoost: -.. code-block:: python - - booster = xgb.train( - { - "num_class": kClasses, - "disable_default_eval_metric": True, - "objective": "multi:softmax", - }, - m, - num_boost_round=kRounds, - # Use a simpler metric implementation. - custom_metric=merror, - evals_result=custom_results, - evals=[(m, "train")], - ) +.. tabs:: + .. code-tab:: py + + booster = xgb.train( + { + "num_class": kClasses, + "disable_default_eval_metric": True, + "objective": "multi:softmax", + }, + m, + num_boost_round=kRounds, + # Use a simpler metric implementation. + custom_metric=merror, + evals_result=custom_results, + evals=[(m, "train")], + ) + + .. code-tab:: r R + + model <- xgb.train( + params = list(num_class = kClasses, + disable_default_eval_metric = TRUE, + objective = "multi:softmax"), + data = dtrain, + nrounds = kRounds, + # Use a simpler metric implementation. + feval = merror, + evals = list(train = dtrain) + ) We use ``multi:softmax`` to illustrate the differences of transformed prediction. With ``softprob`` the output prediction array has shape ``(n_samples, n_classes)`` while for @@ -282,6 +414,12 @@ available at :ref:`sphx_glr_python_examples_custom_softmax.py`. Also, see Scikit-Learn Interface ********************** +.. note:: + + The scikit-learn interface is Python-specific. R users can use the native + ``xgb.train()`` interface with custom objective and evaluation functions as shown + in the examples above. + The scikit-learn interface of XGBoost has some utilities to improve the integration with standard scikit-learn functions. For instance, after XGBoost 1.6.0 users can use the cost function (not scoring functions) from scikit-learn out of the box: @@ -321,4 +459,4 @@ access ``DMatrix``: hess = hess.reshape((rows * classes, 1)) return grad, hess - clf = xgb.XGBClassifier(tree_method="hist", objective=softprob_obj) + clf = xgb.XGBClassifier(tree_method="hist", objective=softprob_obj) \ No newline at end of file From a7abc439fd0f500e260aea3d828ba35a64266494 Mon Sep 17 00:00:00 2001 From: Muhammad Ali Date: Thu, 7 May 2026 00:38:42 +0500 Subject: [PATCH 2/2] doc: fix R API issues and Python undefined vars in custom_metric_obj tutorial --- doc/tutorials/custom_metric_obj.rst | 43 +++++++++++++++++------------ 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/doc/tutorials/custom_metric_obj.rst b/doc/tutorials/custom_metric_obj.rst index 8cc536700971..25fc639e925a 100644 --- a/doc/tutorials/custom_metric_obj.rst +++ b/doc/tutorials/custom_metric_obj.rst @@ -126,7 +126,7 @@ a callback function for XGBoost during training by passing it as an argument to params = list(tree_method = "hist", seed = 1994), data = dtrain, nrounds = 10, - obj = squared_log + objective = squared_log ) Notice that in our definition of the objective, whether we subtract the labels from the @@ -162,7 +162,7 @@ monitor our model's performance. As mentioned above, the default metric for ``S return(list(metric = "RRMSLE", value = err)) } -Since we are demonstrating in Python, the metric or objective need not be a function, any +For the Python tab, the metric or objective need not be a function, any callable object should suffice. Similar to the objective function, our metric also accepts ``predt`` and ``dtrain`` as inputs, but returns the name of the metric itself and a floating point value as the result. After passing it into XGBoost as argument of @@ -185,11 +185,11 @@ a floating point value as the result. After passing it into XGBoost as argument results <- list() model <- xgb.train( params = list(tree_method = "hist", seed = 1994, - disable_default_eval_metric = TRUE), + disable_default_eval_metric = TRUE, + objective = squared_log), data = dtrain, nrounds = 10, - obj = squared_log, - feval = rmsle, + custom_metric = rmsle, evals = list(dtrain = dtrain, dtest = dtest), evals_result = results ) @@ -249,7 +249,7 @@ metric functions implementing the same underlying metric for comparison, # With the use of `custom_metric` parameter in train function, custom metric receives # raw input only when custom objective is also being used. Otherwise custom metric # will receive transformed prediction. - assert predt.shape == (d_train.num_row(), n_classes) + assert predt.shape == (dtrain.num_row(), n_classes) out = np.zeros(dtrain.num_row()) for r in range(predt.shape[0]): i = np.argmax(predt[r]) @@ -270,9 +270,10 @@ metric functions implementing the same underlying metric for comparison, # Predictions are raw (untransformed) when custom objective is provided. labels <- getinfo(dtrain, "label") n_samples <- length(labels) - n_classes <- length(preds) / n_samples - # Reshape predictions into matrix (n_samples x n_classes) - pred_matrix <- matrix(preds, nrow = n_samples, ncol = n_classes, byrow = TRUE) + # In the R package, multi-class predictions are already provided as a + # matrix with shape (n_samples x n_classes). + pred_matrix <- preds + stopifnot(is.matrix(pred_matrix), nrow(pred_matrix) == n_samples) # Get predicted class (0-indexed to match labels) out <- max.col(pred_matrix) - 1 err <- sum(labels != out) / n_samples @@ -289,6 +290,8 @@ function is: def merror(predt: np.ndarray, dtrain: xgb.DMatrix): """Used when there's no custom objective.""" # No need to do transform, XGBoost handles it internally. + y = dtrain.get_label() + out = predt errors = np.zeros(dtrain.num_row()) errors[y != out] = 1.0 return 'PyMError', np.sum(errors) / dtrain.num_row() @@ -298,8 +301,12 @@ function is: merror <- function(preds, dtrain) { # Used when there's no custom objective. # No need to transform, XGBoost handles it internally. + # For multi-class custom metrics in R, preds contains per-class scores. labels <- getinfo(dtrain, "label") - err <- sum(labels != preds) / length(labels) + n_samples <- length(labels) + pred_matrix <- matrix(preds, nrow = n_samples, ncol = length(preds) / n_samples, byrow = TRUE) + out <- max.col(pred_matrix) - 1 + err <- sum(labels != out) / n_samples return(list(metric = "RMError", value = err)) } @@ -326,9 +333,9 @@ Next we need the custom softprob objective: # Reimplements the `multi:softprob` inside XGBoost. labels <- getinfo(dtrain, "label") n_samples <- length(labels) - n_classes <- length(preds) / n_samples - # Reshape predictions - pred_matrix <- matrix(preds, nrow = n_samples, ncol = n_classes, byrow = TRUE) + # In the R package, multi-class predictions are already provided as a + # matrix with shape (n_samples x n_classes). + pred_matrix <- preds # Softmax transform pred_matrix <- exp(pred_matrix) pred_matrix <- pred_matrix / rowSums(pred_matrix) @@ -338,7 +345,7 @@ Next we need the custom softprob objective: grad[i, labels[i] + 1] <- grad[i, labels[i] + 1] - 1 } hess <- pmax(2 * pred_matrix * (1 - pred_matrix), 1e-6) - return(list(grad = as.vector(t(grad)), hess = as.vector(t(hess)))) + return(list(grad = grad, hess = hess)) } Lastly we can train the model using ``obj`` and ``custom_metric`` parameters: @@ -362,11 +369,11 @@ Lastly we can train the model using ``obj`` and ``custom_metric`` parameters: dtrain <- xgb.DMatrix(data = X, label = y) model <- xgb.train( params = list(num_class = kClasses, - disable_default_eval_metric = TRUE), + disable_default_eval_metric = TRUE, + objective = softprob_obj), data = dtrain, nrounds = kRounds, - obj = softprob_obj, - feval = merror_with_transform, + custom_metric = merror_with_transform, evals = list(train = dtrain) ) @@ -399,7 +406,7 @@ available in XGBoost: data = dtrain, nrounds = kRounds, # Use a simpler metric implementation. - feval = merror, + custom_metric = merror, evals = list(train = dtrain) )