From 1f8e85f17854e17e54074e1bff1fba4805779559 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 17 Mar 2026 02:53:36 -0700 Subject: [PATCH 1/7] Pass objective args through factory --- include/xgboost/objective.h | 15 +++--- plugin/example/custom_obj.cc | 20 ++++---- src/learner.cc | 10 ++-- src/objective/aft_obj.cu | 86 +++++++++++++++------------------ src/objective/hinge.cu | 2 +- src/objective/lambdarank_obj.cc | 15 ++++-- src/objective/multiclass_obj.cu | 9 +++- src/objective/objective.cc | 9 ++-- src/objective/quantile_obj.cu | 21 +++++--- src/objective/regression_obj.cu | 56 +++++++++++++++------ 10 files changed, 140 insertions(+), 103 deletions(-) diff --git a/include/xgboost/objective.h b/include/xgboost/objective.h index 497821590bc9..4a29ff75faa3 100644 --- a/include/xgboost/objective.h +++ b/include/xgboost/objective.h @@ -136,15 +136,14 @@ class ObjFunction : public Configurable { * @param name Name of the objective. * @param ctx Pointer to the context. */ - static ObjFunction* Create(const std::string& name, Context const* ctx); + static ObjFunction* Create(const std::string& name, Context const* ctx, Args const& args = {}); }; /*! * \brief Registry entry for objective factory functions. */ struct ObjFunctionReg - : public dmlc::FunctionRegEntryBase > { + : public dmlc::FunctionRegEntryBase > { }; /*! @@ -154,14 +153,14 @@ struct ObjFunctionReg * // example of registering a objective * XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:squarederror") * .describe("Linear regression objective") - * .set_body([]() { + * .set_body([](Args const&) { * return new RegLossObj(LossType::kLinearSquare); * }); * \endcode */ -#define XGBOOST_REGISTER_OBJECTIVE(UniqueId, Name) \ - static DMLC_ATTRIBUTE_UNUSED ::xgboost::ObjFunctionReg & \ - __make_ ## ObjFunctionReg ## _ ## UniqueId ## __ = \ - ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->__REGISTER__(Name) +#define XGBOOST_REGISTER_OBJECTIVE(UniqueId, Name) \ + static DMLC_ATTRIBUTE_UNUSED ::xgboost::ObjFunctionReg& \ + __make_##ObjFunctionReg##_##UniqueId##__ = \ + ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->__REGISTER__(Name) } // namespace xgboost #endif // XGBOOST_OBJECTIVE_H_ diff --git a/plugin/example/custom_obj.cc b/plugin/example/custom_obj.cc index 86f941945518..7b278a9d00cf 100644 --- a/plugin/example/custom_obj.cc +++ b/plugin/example/custom_obj.cc @@ -19,7 +19,9 @@ struct MyLogisticParam : public XGBoostParameter { float scale_neg_weight; // declare parameters DMLC_DECLARE_PARAMETER(MyLogisticParam) { - DMLC_DECLARE_FIELD(scale_neg_weight).set_default(1.0f).set_lower_bound(0.0f) + DMLC_DECLARE_FIELD(scale_neg_weight) + .set_default(1.0f) + .set_lower_bound(0.0f) .describe("Scale the weight of negative examples by this factor"); } }; @@ -53,12 +55,10 @@ class MyLogistic : public ObjFunction { out_gpair_h(i) = GradientPair(grad, hess); } } - [[nodiscard]] const char* DefaultEvalMetric() const override { - return "logloss"; - } - void PredTransform(HostDeviceVector *io_preds) const override { + [[nodiscard]] const char* DefaultEvalMetric() const override { return "logloss"; } + void PredTransform(HostDeviceVector* io_preds) const override { // transform margin value to probability. - std::vector &preds = io_preds->HostVector(); + std::vector& preds = io_preds->HostVector(); for (auto& pred : preds) { pred = 1.0f / (1.0f + std::exp(-pred)); } @@ -77,9 +77,7 @@ class MyLogistic : public ObjFunction { out["my_logistic_param"] = ToJson(param_); } - void LoadConfig(Json const& in) override { - FromJson(in["my_logistic_param"], ¶m_); - } + void LoadConfig(Json const& in) override { FromJson(in["my_logistic_param"], ¶m_); } private: MyLogisticParam param_; @@ -88,7 +86,7 @@ class MyLogistic : public ObjFunction { // Finally register the objective function. // After it succeeds you can try use xgboost with objective=mylogistic XGBOOST_REGISTER_OBJECTIVE(MyLogistic, "mylogistic") -.describe("User defined logistic regression plugin") -.set_body([]() { return new MyLogistic(); }); + .describe("User defined logistic regression plugin") + .set_body([](Args const&) { return new MyLogistic(); }); } // namespace xgboost::obj diff --git a/src/learner.cc b/src/learner.cc index 23049e8e3a33..e1dc75e06aeb 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -847,17 +847,17 @@ class LearnerConfiguration : public Intercept { // Rename one of them once binary IO is gone. cfg_["max_delta_step"] = kMaxDeltaStepDefaultValue; } - if (obj_ == nullptr || tparam_.objective != old.objective) { - obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_)); - } - bool has_nc{cfg_.find("num_class") != cfg_.cend()}; // Inject num_class into configuration. // FIXME(jiamingy): Remove the duplicated parameter in softmax cfg_["num_class"] = std::to_string(mparam_.num_class); auto& args = *p_args; args = {cfg_.cbegin(), cfg_.cend()}; // renew - obj_->Configure(args); + if (obj_ == nullptr || tparam_.objective != old.objective) { + obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_, args)); + } else { + obj_->Configure(args); + } if (!has_nc) { cfg_.erase("num_class"); } diff --git a/src/objective/aft_obj.cu b/src/objective/aft_obj.cu index f535fa0aecae..f09719bbae18 100644 --- a/src/objective/aft_obj.cu +++ b/src/objective/aft_obj.cu @@ -31,9 +31,10 @@ DMLC_REGISTRY_FILE_TAG(aft_obj_gpu); class AFTObj : public ObjFunction { public: - void Configure(Args const& args) override { - param_.UpdateAllowUnknown(args); - } + explicit AFTObj(Args const& args) { param_.UpdateAllowUnknown(args); } + AFTObj() = default; + + void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } ObjInfo Task() const override { return ObjInfo::kSurvival; } @@ -42,27 +43,24 @@ class AFTObj : public ObjFunction { linalg::Matrix* out_gpair, size_t ndata, DeviceOrd device, bool is_null_weight, float aft_loss_distribution_scale) { common::Transform<>::Init( - [=] XGBOOST_DEVICE(size_t _idx, - common::Span _out_gpair, - common::Span _preds, - common::Span _labels_lower_bound, - common::Span _labels_upper_bound, - common::Span _weights) { - const double pred = static_cast(_preds[_idx]); - const double label_lower_bound = static_cast(_labels_lower_bound[_idx]); - const double label_upper_bound = static_cast(_labels_upper_bound[_idx]); - const float grad = static_cast( - AFTLoss::Gradient(label_lower_bound, label_upper_bound, - pred, aft_loss_distribution_scale)); - const float hess = static_cast( - AFTLoss::Hessian(label_lower_bound, label_upper_bound, - pred, aft_loss_distribution_scale)); - const bst_float w = is_null_weight ? 1.0f : _weights[_idx]; - _out_gpair[_idx] = GradientPair(grad * w, hess * w); - }, - common::Range{0, static_cast(ndata)}, this->ctx_->Threads(), device).Eval( - out_gpair->Data(), &preds, &info.labels_lower_bound_, &info.labels_upper_bound_, - &info.weights_); + [=] XGBOOST_DEVICE(size_t _idx, common::Span _out_gpair, + common::Span _preds, + common::Span _labels_lower_bound, + common::Span _labels_upper_bound, + common::Span _weights) { + const double pred = static_cast(_preds[_idx]); + const double label_lower_bound = static_cast(_labels_lower_bound[_idx]); + const double label_upper_bound = static_cast(_labels_upper_bound[_idx]); + const float grad = static_cast(AFTLoss::Gradient( + label_lower_bound, label_upper_bound, pred, aft_loss_distribution_scale)); + const float hess = static_cast(AFTLoss::Hessian( + label_lower_bound, label_upper_bound, pred, aft_loss_distribution_scale)); + const bst_float w = is_null_weight ? 1.0f : _weights[_idx]; + _out_gpair[_idx] = GradientPair(grad * w, hess * w); + }, + common::Range{0, static_cast(ndata)}, this->ctx_->Threads(), device) + .Eval(out_gpair->Data(), &preds, &info.labels_lower_bound_, &info.labels_upper_bound_, + &info.weights_); } void GetGradient(const HostDeviceVector& preds, const MetaInfo& info, int /*iter*/, @@ -77,28 +75,28 @@ class AFTObj : public ObjFunction { const bool is_null_weight = info.weights_.Size() == 0; if (!is_null_weight) { CHECK_EQ(info.weights_.Size(), ndata) - << "Number of weights should be equal to number of data points."; + << "Number of weights should be equal to number of data points."; } switch (param_.aft_loss_distribution) { - case common::ProbabilityDistributionType::kNormal: - GetGradientImpl(preds, info, out_gpair, ndata, device, - is_null_weight, aft_loss_distribution_scale); - break; - case common::ProbabilityDistributionType::kLogistic: - GetGradientImpl(preds, info, out_gpair, ndata, device, + case common::ProbabilityDistributionType::kNormal: + GetGradientImpl(preds, info, out_gpair, ndata, device, is_null_weight, aft_loss_distribution_scale); - break; - case common::ProbabilityDistributionType::kExtreme: - GetGradientImpl(preds, info, out_gpair, ndata, device, - is_null_weight, aft_loss_distribution_scale); - break; - default: - LOG(FATAL) << "Unrecognized distribution"; + break; + case common::ProbabilityDistributionType::kLogistic: + GetGradientImpl(preds, info, out_gpair, ndata, device, + is_null_weight, aft_loss_distribution_scale); + break; + case common::ProbabilityDistributionType::kExtreme: + GetGradientImpl(preds, info, out_gpair, ndata, device, + is_null_weight, aft_loss_distribution_scale); + break; + default: + LOG(FATAL) << "Unrecognized distribution"; } } - void PredTransform(HostDeviceVector *io_preds) const override { + void PredTransform(HostDeviceVector* io_preds) const override { // Trees give us a prediction in log scale, so exponentiate common::Transform<>::Init( [] XGBOOST_DEVICE(size_t _idx, common::Span _preds) { @@ -120,9 +118,7 @@ class AFTObj : public ObjFunction { }); } - const char* DefaultEvalMetric() const override { - return "aft-nloglik"; - } + const char* DefaultEvalMetric() const override { return "aft-nloglik"; } void SaveConfig(Json* p_out) const override { auto& out = *p_out; @@ -130,9 +126,7 @@ class AFTObj : public ObjFunction { out["aft_loss_param"] = ToJson(param_); } - void LoadConfig(Json const& in) override { - FromJson(in["aft_loss_param"], ¶m_); - } + void LoadConfig(Json const& in) override { FromJson(in["aft_loss_param"], ¶m_); } Json DefaultMetricConfig() const override { Json config{Object{}}; config["name"] = String{this->DefaultEvalMetric()}; @@ -147,7 +141,7 @@ class AFTObj : public ObjFunction { // register the objective functions XGBOOST_REGISTER_OBJECTIVE(AFTObj, "survival:aft") .describe("AFT loss function") - .set_body([]() { return new AFTObj(); }); + .set_body([](Args const& args) { return new AFTObj{args}; }); } // namespace obj } // namespace xgboost diff --git a/src/objective/hinge.cu b/src/objective/hinge.cu index 285f65c6f4f5..fee7c4d2525f 100644 --- a/src/objective/hinge.cu +++ b/src/objective/hinge.cu @@ -99,6 +99,6 @@ class HingeObj : public FitIntercept { // register the objective functions XGBOOST_REGISTER_OBJECTIVE(HingeObj, "binary:hinge") .describe("Hinge loss. Expects labels to be in [0,1f]") - .set_body([]() { return new HingeObj(); }); + .set_body([](Args const &) { return new HingeObj(); }); } // namespace xgboost::obj diff --git a/src/objective/lambdarank_obj.cc b/src/objective/lambdarank_obj.cc index 3660a5b59ce2..eb6256f27242 100644 --- a/src/objective/lambdarank_obj.cc +++ b/src/objective/lambdarank_obj.cc @@ -250,6 +250,9 @@ class LambdaRankObj : public FitIntercept { } public: + explicit LambdaRankObj(Args const& args) { param_.UpdateAllowUnknown(args); } + LambdaRankObj() = default; + void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } void SaveConfig(Json* p_out) const override { auto& out = *p_out; @@ -327,6 +330,8 @@ class LambdaRankObj : public FitIntercept { class LambdaRankNDCG : public LambdaRankObj { public: + using LambdaRankObj::LambdaRankObj; + template void CalcLambdaForGroupNDCG(std::int32_t iter, common::Span g_predt, linalg::VectorView g_label, float w, @@ -474,6 +479,8 @@ void MAPStat(Context const* ctx, linalg::VectorView label, class LambdaRankMAP : public LambdaRankObj { public: + using LambdaRankObj::LambdaRankObj; + void GetGradientImpl(std::int32_t iter, const HostDeviceVector& predt, const MetaInfo& info, linalg::Matrix* out_gpair) { if (ctx_->IsCUDA()) { @@ -574,6 +581,8 @@ void LambdaRankGetGradientMAP(Context const*, std::int32_t, HostDeviceVector { public: + using LambdaRankObj::LambdaRankObj; + void GetGradientImpl(std::int32_t iter, const HostDeviceVector& predt, const MetaInfo& info, linalg::Matrix* out_gpair) { if (ctx_->IsCUDA()) { @@ -657,15 +666,15 @@ void LambdaRankGetGradientPairwise(Context const*, std::int32_t, HostDeviceVecto XGBOOST_REGISTER_OBJECTIVE(LambdaRankNDCG, LambdaRankNDCG::Name()) .describe("LambdaRank with NDCG loss as objective") - .set_body([]() { return new LambdaRankNDCG{}; }); + .set_body([](Args const& args) { return new LambdaRankNDCG{args}; }); XGBOOST_REGISTER_OBJECTIVE(LambdaRankPairwise, LambdaRankPairwise::Name()) .describe("LambdaRank with RankNet loss as objective") - .set_body([]() { return new LambdaRankPairwise{}; }); + .set_body([](Args const& args) { return new LambdaRankPairwise{args}; }); XGBOOST_REGISTER_OBJECTIVE(LambdaRankMAP, LambdaRankMAP::Name()) .describe("LambdaRank with MAP loss as objective.") - .set_body([]() { return new LambdaRankMAP{}; }); + .set_body([](Args const& args) { return new LambdaRankMAP{args}; }); DMLC_REGISTRY_FILE_TAG(lambdarank_obj); } // namespace xgboost::obj diff --git a/src/objective/multiclass_obj.cu b/src/objective/multiclass_obj.cu index 46bfbff686fc..d181f67e94bf 100644 --- a/src/objective/multiclass_obj.cu +++ b/src/objective/multiclass_obj.cu @@ -69,6 +69,11 @@ void ValidateLabel(Context const* ctx, MetaInfo const& info, std::int64_t n_clas class SoftmaxMultiClassObj : public ObjFunction { public: explicit SoftmaxMultiClassObj(bool output_prob) : output_prob_(output_prob) {} + SoftmaxMultiClassObj(bool output_prob, Args const& args) : output_prob_(output_prob) { + if (!args.empty()) { + param_.UpdateAllowUnknown(args); + } + } void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } @@ -233,9 +238,9 @@ DMLC_REGISTER_PARAMETER(SoftmaxMultiClassParam); XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClass, "multi:softmax") .describe("Softmax for multi-class classification, output class index.") - .set_body([]() { return new SoftmaxMultiClassObj(false); }); + .set_body([](Args const& args) { return new SoftmaxMultiClassObj(false, args); }); XGBOOST_REGISTER_OBJECTIVE(SoftprobMultiClass, "multi:softprob") .describe("Softmax for multi-class classification, output probability distribution.") - .set_body([]() { return new SoftmaxMultiClassObj(true); }); + .set_body([](Args const& args) { return new SoftmaxMultiClassObj(true, args); }); } // namespace xgboost::obj diff --git a/src/objective/objective.cc b/src/objective/objective.cc index 8731394dfc25..e5ed37e4ac39 100644 --- a/src/objective/objective.cc +++ b/src/objective/objective.cc @@ -16,18 +16,17 @@ DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg); namespace xgboost { // implement factory functions -ObjFunction* ObjFunction::Create(const std::string& name, Context const* ctx) { +ObjFunction* ObjFunction::Create(const std::string& name, Context const* ctx, Args const& args) { std::string obj_name = name; - auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(obj_name); + auto* e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(obj_name); if (e == nullptr) { std::stringstream ss; for (const auto& entry : ::dmlc::Registry< ::xgboost::ObjFunctionReg>::List()) { ss << "Objective candidate: " << entry->name << "\n"; } - LOG(FATAL) << "Unknown objective function: `" << name << "`\n" - << ss.str(); + LOG(FATAL) << "Unknown objective function: `" << name << "`\n" << ss.str(); } - auto pobj = (e->body)(); + auto pobj = (e->body)(args); pobj->ctx_ = ctx; return pobj; } diff --git a/src/objective/quantile_obj.cu b/src/objective/quantile_obj.cu index c9241bb627f0..ed717c09a521 100644 --- a/src/objective/quantile_obj.cu +++ b/src/objective/quantile_obj.cu @@ -1,10 +1,10 @@ /** * Copyright 2023-2026, XGBoost contributors */ -#include // std::array -#include // std::size_t -#include // std::int32_t -#include // std::vector +#include // std::array +#include // std::size_t +#include // std::int32_t +#include // std::vector #include "../common/linalg_op.h" // ElementWiseKernel,cbegin,cend #include "../common/quantile_loss_utils.h" // QuantileLossParam @@ -20,9 +20,9 @@ #if defined(XGBOOST_USE_CUDA) -#include "../common/stats.cuh" // SegmentedQuantile +#include "../common/stats.cuh" // SegmentedQuantile -#endif // defined(XGBOOST_USE_CUDA) +#endif // defined(XGBOOST_USE_CUDA) namespace xgboost::obj { class QuantileRegression : public ObjFunction { @@ -45,6 +45,13 @@ class QuantileRegression : public ObjFunction { } public: + explicit QuantileRegression(Args const& args) { + if (!args.empty()) { + this->Configure(args); + } + } + QuantileRegression() = default; + void GetGradient(HostDeviceVector const& preds, const MetaInfo& info, std::int32_t iter, linalg::Matrix* out_gpair) override { if (iter == 0) { @@ -207,7 +214,7 @@ class QuantileRegression : public ObjFunction { XGBOOST_REGISTER_OBJECTIVE(QuantileRegression, QuantileRegression::Name()) .describe("Regression with quantile loss.") - .set_body([]() { return new QuantileRegression(); }); + .set_body([](Args const& args) { return new QuantileRegression{args}; }); #if defined(XGBOOST_USE_CUDA) DMLC_REGISTRY_FILE_TAG(quantile_obj_gpu); diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index ce8203a000c0..b6ba8e508801 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -131,6 +131,7 @@ class RegLossObj : public FitInterceptGlmLike { public: // 0 - scale_pos_weight, 1 - is_null_weight RegLossObj() : additional_input_(2) {} + explicit RegLossObj(Args const& args) : additional_input_(2) { param_.UpdateAllowUnknown(args); } void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } @@ -247,32 +248,32 @@ DMLC_REGISTER_PARAMETER(RegLossParam); XGBOOST_REGISTER_OBJECTIVE(SquaredLossRegression, LinearSquareLoss::Name()) .describe("Regression with squared error.") - .set_body([]() { return new RegLossObj(); }); + .set_body([](Args const& args) { return new RegLossObj{args}; }); XGBOOST_REGISTER_OBJECTIVE(LogisticRegression, LogisticRegression::Name()) .describe("Logistic regression for probability regression task.") - .set_body([]() { return new RegLossObj(); }); + .set_body([](Args const& args) { return new RegLossObj{args}; }); XGBOOST_REGISTER_OBJECTIVE(LogisticClassification, LogisticClassification::Name()) .describe("Logistic regression for binary classification task.") - .set_body([]() { return new RegLossObj(); }); + .set_body([](Args const& args) { return new RegLossObj{args}; }); XGBOOST_REGISTER_OBJECTIVE(LogisticRaw, LogisticRaw::Name()) .describe( "Logistic regression for classification, output score " "before logistic transformation.") - .set_body([]() { return new RegLossObj(); }); + .set_body([](Args const& args) { return new RegLossObj{args}; }); XGBOOST_REGISTER_OBJECTIVE(GammaRegression, GammaDeviance::Name()) .describe("Gamma regression using the gamma deviance loss with log link.") - .set_body([]() { return new RegLossObj(); }); + .set_body([](Args const& args) { return new RegLossObj{args}; }); // Deprecated functions XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear") .describe("Regression with squared error.") - .set_body([]() { + .set_body([](Args const& args) { LOG(WARNING) << "reg:linear is now deprecated in favor of reg:squarederror."; - return new RegLossObj(); + return new RegLossObj{args}; }); // End deprecated @@ -322,12 +323,15 @@ class SquaredLogErrorRegression : public FitIntercept { XGBOOST_REGISTER_OBJECTIVE(SquaredLogErrorRegression, SquaredLogErrorRegression::Name()) .describe("Root mean squared log error.") - .set_body([]() { return new SquaredLogErrorRegression(); }); + .set_body([](Args const&) { return new SquaredLogErrorRegression(); }); class PseudoHuberRegression : public FitIntercept { PseudoHuberParam param_; public: + explicit PseudoHuberRegression(Args const& args) { param_.UpdateAllowUnknown(args); } + PseudoHuberRegression() = default; + void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } [[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override { @@ -390,7 +394,7 @@ class PseudoHuberRegression : public FitIntercept { XGBOOST_REGISTER_OBJECTIVE(PseudoHuberRegression, "reg:pseudohubererror") .describe("Regression Pseudo Huber error.") - .set_body([]() { return new PseudoHuberRegression(); }); + .set_body([](Args const& args) { return new PseudoHuberRegression{args}; }); class ExpectileRegression : public FitIntercept { common::ExpectileLossParam param_; @@ -409,6 +413,13 @@ class ExpectileRegression : public FitIntercept { } public: + explicit ExpectileRegression(Args const& args) { + if (!args.empty()) { + this->Configure(args); + } + } + ExpectileRegression() = default; + void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); param_.Validate(); @@ -541,7 +552,7 @@ class ExpectileRegression : public FitIntercept { XGBOOST_REGISTER_OBJECTIVE(ExpectileRegression, "reg:expectileerror") .describe("Regression with expectile loss.") - .set_body([]() { return new ExpectileRegression(); }); + .set_body([](Args const& args) { return new ExpectileRegression{args}; }); // declare parameter struct PoissonRegressionParam : public XGBoostParameter { @@ -559,6 +570,9 @@ struct PoissonRegressionParam : public XGBoostParameter // poisson regression for count class PoissonRegression : public FitInterceptGlmLike { public: + explicit PoissonRegression(Args const& args) { param_.UpdateAllowUnknown(args); } + PoissonRegression() = default; + // declare functions void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } @@ -627,7 +641,7 @@ DMLC_REGISTER_PARAMETER(PoissonRegressionParam); XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson") .describe("Poisson regression for count data.") - .set_body([]() { return new PoissonRegression(); }); + .set_body([](Args const& args) { return new PoissonRegression{args}; }); // cox regression for survival data (negative values mean they are censored) class CoxRegression : public FitIntercept { @@ -720,7 +734,7 @@ class CoxRegression : public FitIntercept { XGBOOST_REGISTER_OBJECTIVE(CoxRegression, "survival:cox") .describe( "Cox regression for censored survival data (negative labels are considered censored).") - .set_body([]() { return new CoxRegression(); }); + .set_body([](Args const&) { return new CoxRegression(); }); // declare parameter struct TweedieRegressionParam : public XGBoostParameter { @@ -736,6 +750,13 @@ struct TweedieRegressionParam : public XGBoostParameter // tweedie regression class TweedieRegression : public FitInterceptGlmLike { public: + explicit TweedieRegression(Args const& args) { + if (!args.empty()) { + this->Configure(args); + } + } + TweedieRegression() = default; + // declare functions void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); @@ -798,7 +819,12 @@ class TweedieRegression : public FitInterceptGlmLike { out["name"] = String("reg:tweedie"); out["tweedie_regression_param"] = ToJson(param_); } - void LoadConfig(Json const& in) override { FromJson(in["tweedie_regression_param"], ¶m_); } + void LoadConfig(Json const& in) override { + FromJson(in["tweedie_regression_param"], ¶m_); + std::ostringstream os; + os << "tweedie-nloglik@" << param_.tweedie_variance_power; + metric_ = os.str(); + } private: std::string metric_; @@ -810,7 +836,7 @@ DMLC_REGISTER_PARAMETER(TweedieRegressionParam); XGBOOST_REGISTER_OBJECTIVE(TweedieRegression, "reg:tweedie") .describe("Tweedie regression for insurance data.") - .set_body([]() { return new TweedieRegression(); }); + .set_body([](Args const& args) { return new TweedieRegression{args}; }); class MeanAbsoluteError : public ObjFunction { public: @@ -905,5 +931,5 @@ class MeanAbsoluteError : public ObjFunction { XGBOOST_REGISTER_OBJECTIVE(MeanAbsoluteError, "reg:absoluteerror") .describe("Mean absoluate error.") - .set_body([]() { return new MeanAbsoluteError(); }); + .set_body([](Args const&) { return new MeanAbsoluteError(); }); } // namespace xgboost::obj From ac9f672a6d443fdf6e5e01b549875bb254e6d559 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 17 Mar 2026 03:05:50 -0700 Subject: [PATCH 2/7] Use constructor-time objective init in tests --- tests/cpp/objective/test_aft_obj.cc | 188 +++++++++--------- tests/cpp/objective/test_hinge.cc | 12 +- tests/cpp/objective/test_multiclass_obj.cc | 53 ++--- tests/cpp/objective/test_objective.cc | 27 ++- tests/cpp/objective/test_quantile_obj.cc | 9 +- tests/cpp/objective/test_regression_obj.cc | 67 ++----- .../cpp/objective/test_regression_obj_cpu.cc | 2 +- tests/cpp/plugin/test_example_objective.cc | 7 +- tests/cpp/plugin/test_sycl_regression_obj.cc | 8 +- tests/cpp/predictor/test_shap.cc | 3 +- 10 files changed, 173 insertions(+), 203 deletions(-) diff --git a/tests/cpp/objective/test_aft_obj.cc b/tests/cpp/objective/test_aft_obj.cc index cd031b6bcdf5..23cad514e49f 100644 --- a/tests/cpp/objective/test_aft_obj.cc +++ b/tests/cpp/objective/test_aft_obj.cc @@ -1,25 +1,26 @@ /** * Copyright 2020-2024, XGBoost Contributors */ +#include "test_aft_obj.h" + #include + +#include +#include #include #include -#include -#include -#include "xgboost/objective.h" -#include "xgboost/logging.h" #include "../helpers.h" -#include "test_aft_obj.h" +#include "xgboost/logging.h" +#include "xgboost/objective.h" namespace xgboost::common { void TestAFTObjConfiguration(const Context* ctx) { - std::unique_ptr objective(ObjFunction::Create("survival:aft", ctx)); - objective->Configure({ {"aft_loss_distribution", "logistic"}, - {"aft_loss_distribution_scale", "5"} }); + Args args{{"aft_loss_distribution", "logistic"}, {"aft_loss_distribution_scale", "5"}}; + std::unique_ptr objective(ObjFunction::Create("survival:aft", ctx, args)); // Configuration round-trip test - Json j_obj{ Object() }; + Json j_obj{Object()}; objective->SaveConfig(&j_obj); EXPECT_EQ(get(j_obj["name"]), "survival:aft"); auto aft_param_json = j_obj["aft_loss_param"]; @@ -35,27 +36,22 @@ void TestAFTObjConfiguration(const Context* ctx) { // Generate prediction value ranging from 2**1 to 2**15, using grid points in log scale // Then check prediction against the reference values -static inline void CheckGPairOverGridPoints( - ObjFunction* obj, - bst_float true_label_lower_bound, - bst_float true_label_upper_bound, - const std::string& dist_type, - const std::vector& expected_grad, - const std::vector& expected_hess, - float ftol = 1e-4f) { +static inline void CheckGPairOverGridPoints(ObjFunction* obj, bst_float true_label_lower_bound, + bst_float true_label_upper_bound, + const std::string& dist_type, + const std::vector& expected_grad, + const std::vector& expected_hess, + float ftol = 1e-4f) { const int num_point = 20; const double log_y_low = 1.0; const double log_y_high = 15.0; - obj->Configure({ {"aft_loss_distribution", dist_type}, - {"aft_loss_distribution_scale", "1"} }); + obj->Configure({{"aft_loss_distribution", dist_type}, {"aft_loss_distribution_scale", "1"}}); MetaInfo info; info.num_row_ = num_point; - info.labels_lower_bound_.HostVector() - = std::vector(num_point, true_label_lower_bound); - info.labels_upper_bound_.HostVector() - = std::vector(num_point, true_label_upper_bound); + info.labels_lower_bound_.HostVector() = std::vector(num_point, true_label_lower_bound); + info.labels_upper_bound_.HostVector() = std::vector(num_point, true_label_upper_bound); info.weights_.HostVector() = std::vector(); std::vector preds(num_point); for (int i = 0; i < num_point; ++i) { @@ -76,90 +72,102 @@ static inline void CheckGPairOverGridPoints( void TestAFTObjGPairUncensoredLabels(const Context* ctx) { std::unique_ptr obj(ObjFunction::Create("survival:aft", ctx)); - CheckGPairOverGridPoints(obj.get(), 100.0f, 100.0f, "normal", - { -3.9120f, -3.4013f, -2.8905f, -2.3798f, -1.8691f, -1.3583f, -0.8476f, -0.3368f, 0.1739f, - 0.6846f, 1.1954f, 1.7061f, 2.2169f, 2.7276f, 3.2383f, 3.7491f, 4.2598f, 4.7706f, 5.2813f, - 5.7920f }, - { 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, - 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f }); - CheckGPairOverGridPoints(obj.get(), 100.0f, 100.0f, "logistic", - { -0.9608f, -0.9355f, -0.8948f, -0.8305f, -0.7327f, -0.5910f, -0.4001f, -0.1668f, 0.0867f, - 0.3295f, 0.5354f, 0.6927f, 0.8035f, 0.8773f, 0.9245f, 0.9540f, 0.9721f, 0.9832f, 0.9899f, - 0.9939f }, - { 0.0384f, 0.0624f, 0.0997f, 0.1551f, 0.2316f, 0.3254f, 0.4200f, 0.4861f, 0.4962f, 0.4457f, - 0.3567f, 0.2601f, 0.1772f, 0.1152f, 0.0726f, 0.0449f, 0.0275f, 0.0167f, 0.0101f, 0.0061f }); + CheckGPairOverGridPoints( + obj.get(), 100.0f, 100.0f, "normal", + {-3.9120f, -3.4013f, -2.8905f, -2.3798f, -1.8691f, -1.3583f, -0.8476f, + -0.3368f, 0.1739f, 0.6846f, 1.1954f, 1.7061f, 2.2169f, 2.7276f, + 3.2383f, 3.7491f, 4.2598f, 4.7706f, 5.2813f, 5.7920f}, + {1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, + 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f}); + CheckGPairOverGridPoints( + obj.get(), 100.0f, 100.0f, "logistic", + {-0.9608f, -0.9355f, -0.8948f, -0.8305f, -0.7327f, -0.5910f, -0.4001f, + -0.1668f, 0.0867f, 0.3295f, 0.5354f, 0.6927f, 0.8035f, 0.8773f, + 0.9245f, 0.9540f, 0.9721f, 0.9832f, 0.9899f, 0.9939f}, + {0.0384f, 0.0624f, 0.0997f, 0.1551f, 0.2316f, 0.3254f, 0.4200f, 0.4861f, 0.4962f, 0.4457f, + 0.3567f, 0.2601f, 0.1772f, 0.1152f, 0.0726f, 0.0449f, 0.0275f, 0.0167f, 0.0101f, 0.0061f}); CheckGPairOverGridPoints(obj.get(), 100.0f, 100.0f, "extreme", - { -15.0000f, -15.0000f, -15.0000f, -9.8028f, -5.4822f, -2.8897f, -1.3340f, -0.4005f, 0.1596f, - 0.4957f, 0.6974f, 0.8184f, 0.8910f, 0.9346f, 0.9608f, 0.9765f, 0.9859f, 0.9915f, 0.9949f, - 0.9969f }, - { 15.0000f, 15.0000f, 15.0000f, 10.8028f, 6.4822f, 3.8897f, 2.3340f, 1.4005f, 0.8404f, 0.5043f, - 0.3026f, 0.1816f, 0.1090f, 0.0654f, 0.0392f, 0.0235f, 0.0141f, 0.0085f, 0.0051f, 0.0031f }); + {-15.0000f, -15.0000f, -15.0000f, -9.8028f, -5.4822f, -2.8897f, -1.3340f, + -0.4005f, 0.1596f, 0.4957f, 0.6974f, 0.8184f, 0.8910f, 0.9346f, + 0.9608f, 0.9765f, 0.9859f, 0.9915f, 0.9949f, 0.9969f}, + {15.0000f, 15.0000f, 15.0000f, 10.8028f, 6.4822f, 3.8897f, 2.3340f, + 1.4005f, 0.8404f, 0.5043f, 0.3026f, 0.1816f, 0.1090f, 0.0654f, + 0.0392f, 0.0235f, 0.0141f, 0.0085f, 0.0051f, 0.0031f}); } void TestAFTObjGPairLeftCensoredLabels(const Context* ctx) { std::unique_ptr obj(ObjFunction::Create("survival:aft", ctx)); - CheckGPairOverGridPoints(obj.get(), 0.0f, 20.0f, "normal", - { 0.0285f, 0.0832f, 0.1951f, 0.3804f, 0.6403f, 0.9643f, 1.3379f, 1.7475f, 2.1828f, 2.6361f, - 3.1023f, 3.5779f, 4.0603f, 4.5479f, 5.0394f, 5.5340f, 6.0309f, 6.5298f, 7.0303f, 7.5326f }, - { 0.0663f, 0.1559f, 0.2881f, 0.4378f, 0.5762f, 0.6878f, 0.7707f, 0.8300f, 0.8719f, 0.9016f, - 0.9229f, 0.9385f, 0.9501f, 0.9588f, 0.9656f, 0.9709f, 0.9751f, 0.9785f, 0.9813f, 0.9877f }); - CheckGPairOverGridPoints(obj.get(), 0.0f, 20.0f, "logistic", - { 0.0909f, 0.1428f, 0.2174f, 0.3164f, 0.4355f, 0.5625f, 0.6818f, 0.7812f, 0.8561f, 0.9084f, - 0.9429f, 0.9650f, 0.9787f, 0.9871f, 0.9922f, 0.9953f, 0.9972f, 0.9983f, 0.9990f, 0.9994f }, - { 0.0826f, 0.1224f, 0.1701f, 0.2163f, 0.2458f, 0.2461f, 0.2170f, 0.1709f, 0.1232f, 0.0832f, - 0.0538f, 0.0338f, 0.0209f, 0.0127f, 0.0077f, 0.0047f, 0.0028f, 0.0017f, 0.0010f, 0.0006f }); - CheckGPairOverGridPoints(obj.get(), 0.0f, 20.0f, "extreme", - { 0.0005f, 0.0149f, 0.1011f, 0.2815f, 0.4881f, 0.6610f, 0.7847f, 0.8665f, 0.9183f, 0.9504f, - 0.9700f, 0.9820f, 0.9891f, 0.9935f, 0.9961f, 0.9976f, 0.9986f, 0.9992f, 0.9995f, 0.9997f }, - { 0.0041f, 0.0747f, 0.2731f, 0.4059f, 0.3829f, 0.2901f, 0.1973f, 0.1270f, 0.0793f, 0.0487f, - 0.0296f, 0.0179f, 0.0108f, 0.0065f, 0.0039f, 0.0024f, 0.0014f, 0.0008f, 0.0005f, 0.0003f }); + CheckGPairOverGridPoints( + obj.get(), 0.0f, 20.0f, "normal", + {0.0285f, 0.0832f, 0.1951f, 0.3804f, 0.6403f, 0.9643f, 1.3379f, 1.7475f, 2.1828f, 2.6361f, + 3.1023f, 3.5779f, 4.0603f, 4.5479f, 5.0394f, 5.5340f, 6.0309f, 6.5298f, 7.0303f, 7.5326f}, + {0.0663f, 0.1559f, 0.2881f, 0.4378f, 0.5762f, 0.6878f, 0.7707f, 0.8300f, 0.8719f, 0.9016f, + 0.9229f, 0.9385f, 0.9501f, 0.9588f, 0.9656f, 0.9709f, 0.9751f, 0.9785f, 0.9813f, 0.9877f}); + CheckGPairOverGridPoints( + obj.get(), 0.0f, 20.0f, "logistic", + {0.0909f, 0.1428f, 0.2174f, 0.3164f, 0.4355f, 0.5625f, 0.6818f, 0.7812f, 0.8561f, 0.9084f, + 0.9429f, 0.9650f, 0.9787f, 0.9871f, 0.9922f, 0.9953f, 0.9972f, 0.9983f, 0.9990f, 0.9994f}, + {0.0826f, 0.1224f, 0.1701f, 0.2163f, 0.2458f, 0.2461f, 0.2170f, 0.1709f, 0.1232f, 0.0832f, + 0.0538f, 0.0338f, 0.0209f, 0.0127f, 0.0077f, 0.0047f, 0.0028f, 0.0017f, 0.0010f, 0.0006f}); + CheckGPairOverGridPoints( + obj.get(), 0.0f, 20.0f, "extreme", + {0.0005f, 0.0149f, 0.1011f, 0.2815f, 0.4881f, 0.6610f, 0.7847f, 0.8665f, 0.9183f, 0.9504f, + 0.9700f, 0.9820f, 0.9891f, 0.9935f, 0.9961f, 0.9976f, 0.9986f, 0.9992f, 0.9995f, 0.9997f}, + {0.0041f, 0.0747f, 0.2731f, 0.4059f, 0.3829f, 0.2901f, 0.1973f, 0.1270f, 0.0793f, 0.0487f, + 0.0296f, 0.0179f, 0.0108f, 0.0065f, 0.0039f, 0.0024f, 0.0014f, 0.0008f, 0.0005f, 0.0003f}); } void TestAFTObjGPairRightCensoredLabels(const Context* ctx) { std::unique_ptr obj(ObjFunction::Create("survival:aft", ctx)); - CheckGPairOverGridPoints(obj.get(), 60.0f, std::numeric_limits::infinity(), "normal", - { -3.6583f, -3.1815f, -2.7135f, -2.2577f, -1.8190f, -1.4044f, -1.0239f, -0.6905f, -0.4190f, - -0.2209f, -0.0973f, -0.0346f, -0.0097f, -0.0021f, -0.0004f, -0.0000f, -0.0000f, -0.0000f, - -0.0000f, -0.0000f }, - { 0.9407f, 0.9259f, 0.9057f, 0.8776f, 0.8381f, 0.7821f, 0.7036f, 0.5970f, 0.4624f, 0.3128f, - 0.1756f, 0.0780f, 0.0265f, 0.0068f, 0.0013f, 0.0002f, 0.0000f, 0.0000f, 0.0000f, 0.0000f }); - CheckGPairOverGridPoints(obj.get(), 60.0f, std::numeric_limits::infinity(), "logistic", - { -0.9677f, -0.9474f, -0.9153f, -0.8663f, -0.7955f, -0.7000f, -0.5834f, -0.4566f, -0.3352f, - -0.2323f, -0.1537f, -0.0982f, -0.0614f, -0.0377f, -0.0230f, -0.0139f, -0.0084f, -0.0051f, - -0.0030f, -0.0018f }, - { 0.0312f, 0.0499f, 0.0776f, 0.1158f, 0.1627f, 0.2100f, 0.2430f, 0.2481f, 0.2228f, 0.1783f, - 0.1300f, 0.0886f, 0.0576f, 0.0363f, 0.0225f, 0.0137f, 0.0083f, 0.0050f, 0.0030f, 0.0018f }); + CheckGPairOverGridPoints( + obj.get(), 60.0f, std::numeric_limits::infinity(), "normal", + {-3.6583f, -3.1815f, -2.7135f, -2.2577f, -1.8190f, -1.4044f, -1.0239f, + -0.6905f, -0.4190f, -0.2209f, -0.0973f, -0.0346f, -0.0097f, -0.0021f, + -0.0004f, -0.0000f, -0.0000f, -0.0000f, -0.0000f, -0.0000f}, + {0.9407f, 0.9259f, 0.9057f, 0.8776f, 0.8381f, 0.7821f, 0.7036f, 0.5970f, 0.4624f, 0.3128f, + 0.1756f, 0.0780f, 0.0265f, 0.0068f, 0.0013f, 0.0002f, 0.0000f, 0.0000f, 0.0000f, 0.0000f}); + CheckGPairOverGridPoints( + obj.get(), 60.0f, std::numeric_limits::infinity(), "logistic", + {-0.9677f, -0.9474f, -0.9153f, -0.8663f, -0.7955f, -0.7000f, -0.5834f, + -0.4566f, -0.3352f, -0.2323f, -0.1537f, -0.0982f, -0.0614f, -0.0377f, + -0.0230f, -0.0139f, -0.0084f, -0.0051f, -0.0030f, -0.0018f}, + {0.0312f, 0.0499f, 0.0776f, 0.1158f, 0.1627f, 0.2100f, 0.2430f, 0.2481f, 0.2228f, 0.1783f, + 0.1300f, 0.0886f, 0.0576f, 0.0363f, 0.0225f, 0.0137f, 0.0083f, 0.0050f, 0.0030f, 0.0018f}); CheckGPairOverGridPoints(obj.get(), 60.0f, std::numeric_limits::infinity(), "extreme", - { -15.0000f, -15.0000f, -10.8018f, -6.4817f, -3.8893f, -2.3338f, -1.4004f, -0.8403f, -0.5042f, - -0.3026f, -0.1816f, -0.1089f, -0.0654f, -0.0392f, -0.0235f, -0.0141f, -0.0085f, -0.0051f, - -0.0031f, -0.0018f }, - { 15.0000f, 15.0000f, 10.8018f, 6.4817f, 3.8893f, 2.3338f, 1.4004f, 0.8403f, 0.5042f, 0.3026f, - 0.1816f, 0.1089f, 0.0654f, 0.0392f, 0.0235f, 0.0141f, 0.0085f, 0.0051f, 0.0031f, 0.0018f }); + {-15.0000f, -15.0000f, -10.8018f, -6.4817f, -3.8893f, -2.3338f, -1.4004f, + -0.8403f, -0.5042f, -0.3026f, -0.1816f, -0.1089f, -0.0654f, -0.0392f, + -0.0235f, -0.0141f, -0.0085f, -0.0051f, -0.0031f, -0.0018f}, + {15.0000f, 15.0000f, 10.8018f, 6.4817f, 3.8893f, 2.3338f, 1.4004f, + 0.8403f, 0.5042f, 0.3026f, 0.1816f, 0.1089f, 0.0654f, 0.0392f, + 0.0235f, 0.0141f, 0.0085f, 0.0051f, 0.0031f, 0.0018f}); } void TestAFTObjGPairIntervalCensoredLabels(const Context* ctx) { std::unique_ptr obj(ObjFunction::Create("survival:aft", ctx)); - CheckGPairOverGridPoints(obj.get(), 16.0f, 200.0f, "normal", - { -2.4435f, -1.9965f, -1.5691f, -1.1679f, -0.7990f, -0.4649f, -0.1596f, 0.1336f, 0.4370f, - 0.7682f, 1.1340f, 1.5326f, 1.9579f, 2.4035f, 2.8639f, 3.3351f, 3.8143f, 4.2995f, 4.7891f, - 5.2822f }, - { 0.8909f, 0.8579f, 0.8134f, 0.7557f, 0.6880f, 0.6221f, 0.5789f, 0.5769f, 0.6171f, 0.6818f, - 0.7500f, 0.8088f, 0.8545f, 0.8884f, 0.9131f, 0.9312f, 0.9446f, 0.9547f, 0.9624f, 0.9684f }); - CheckGPairOverGridPoints(obj.get(), 16.0f, 200.0f, "logistic", - { -0.8790f, -0.8112f, -0.7153f, -0.5893f, -0.4375f, -0.2697f, -0.0955f, 0.0800f, 0.2545f, - 0.4232f, 0.5768f, 0.7054f, 0.8040f, 0.8740f, 0.9210f, 0.9513f, 0.9703f, 0.9820f, 0.9891f, - 0.9934f }, - { 0.1086f, 0.1588f, 0.2176f, 0.2745f, 0.3164f, 0.3374f, 0.3433f, 0.3434f, 0.3384f, 0.3191f, - 0.2789f, 0.2229f, 0.1637f, 0.1125f, 0.0737f, 0.0467f, 0.0290f, 0.0177f, 0.0108f, 0.0065f }); - CheckGPairOverGridPoints(obj.get(), 16.0f, 200.0f, "extreme", - { -8.0000f, -4.8004f, -2.8805f, -1.7284f, -1.0371f, -0.6168f, -0.3140f, -0.0121f, 0.2841f, - 0.5261f, 0.6989f, 0.8132f, 0.8857f, 0.9306f, 0.9581f, 0.9747f, 0.9848f, 0.9909f, 0.9945f, - 0.9967f }, - { 8.0000f, 4.8004f, 2.8805f, 1.7284f, 1.0380f, 0.6567f, 0.5727f, 0.6033f, 0.5384f, 0.4051f, - 0.2757f, 0.1776f, 0.1110f, 0.0682f, 0.0415f, 0.0251f, 0.0151f, 0.0091f, 0.0055f, 0.0033f }); + CheckGPairOverGridPoints( + obj.get(), 16.0f, 200.0f, "normal", + {-2.4435f, -1.9965f, -1.5691f, -1.1679f, -0.7990f, -0.4649f, -0.1596f, + 0.1336f, 0.4370f, 0.7682f, 1.1340f, 1.5326f, 1.9579f, 2.4035f, + 2.8639f, 3.3351f, 3.8143f, 4.2995f, 4.7891f, 5.2822f}, + {0.8909f, 0.8579f, 0.8134f, 0.7557f, 0.6880f, 0.6221f, 0.5789f, 0.5769f, 0.6171f, 0.6818f, + 0.7500f, 0.8088f, 0.8545f, 0.8884f, 0.9131f, 0.9312f, 0.9446f, 0.9547f, 0.9624f, 0.9684f}); + CheckGPairOverGridPoints( + obj.get(), 16.0f, 200.0f, "logistic", + {-0.8790f, -0.8112f, -0.7153f, -0.5893f, -0.4375f, -0.2697f, -0.0955f, + 0.0800f, 0.2545f, 0.4232f, 0.5768f, 0.7054f, 0.8040f, 0.8740f, + 0.9210f, 0.9513f, 0.9703f, 0.9820f, 0.9891f, 0.9934f}, + {0.1086f, 0.1588f, 0.2176f, 0.2745f, 0.3164f, 0.3374f, 0.3433f, 0.3434f, 0.3384f, 0.3191f, + 0.2789f, 0.2229f, 0.1637f, 0.1125f, 0.0737f, 0.0467f, 0.0290f, 0.0177f, 0.0108f, 0.0065f}); + CheckGPairOverGridPoints( + obj.get(), 16.0f, 200.0f, "extreme", + {-8.0000f, -4.8004f, -2.8805f, -1.7284f, -1.0371f, -0.6168f, -0.3140f, + -0.0121f, 0.2841f, 0.5261f, 0.6989f, 0.8132f, 0.8857f, 0.9306f, + 0.9581f, 0.9747f, 0.9848f, 0.9909f, 0.9945f, 0.9967f}, + {8.0000f, 4.8004f, 2.8805f, 1.7284f, 1.0380f, 0.6567f, 0.5727f, 0.6033f, 0.5384f, 0.4051f, + 0.2757f, 0.1776f, 0.1110f, 0.0682f, 0.0415f, 0.0251f, 0.0151f, 0.0091f, 0.0055f, 0.0033f}); } } // namespace xgboost::common diff --git a/tests/cpp/objective/test_hinge.cc b/tests/cpp/objective/test_hinge.cc index d7c6e3d89548..6b68b25aea09 100644 --- a/tests/cpp/objective/test_hinge.cc +++ b/tests/cpp/objective/test_hinge.cc @@ -1,21 +1,23 @@ /** * Copyright 2018-2023, XGBoost Contributors */ -#include +#include "test_hinge.h" + #include +#include + #include -#include "../helpers.h" -#include "test_hinge.h" #include "../../../src/common/linalg_op.h" +#include "../helpers.h" namespace xgboost { void TestHingeObj(const Context* ctx) { - std::unique_ptr obj{ObjFunction::Create("binary:hinge", ctx)}; + std::unique_ptr obj{ObjFunction::Create("binary:hinge", ctx, Args{})}; float eps = std::numeric_limits::min(); std::vector predt{-1.0f, -0.5f, 0.5f, 1.0f, -1.0f, -0.5f, 0.5f, 1.0f}; - std::vector label{ 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + std::vector label{0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 1.0f}; std::vector grad{0.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 0.0f}; std::vector hess{eps, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, eps}; diff --git a/tests/cpp/objective/test_multiclass_obj.cc b/tests/cpp/objective/test_multiclass_obj.cc index ae3427b61398..fe0751c01e5f 100644 --- a/tests/cpp/objective/test_multiclass_obj.cc +++ b/tests/cpp/objective/test_multiclass_obj.cc @@ -1,35 +1,31 @@ /** * Copyright 2018-2025, XGBoost contributors */ -#include +#include "test_multiclass_obj.h" + #include +#include + #include "../helpers.h" -#include "test_multiclass_obj.h" namespace xgboost { void TestSoftmaxMultiClassObjGPair(const Context* ctx) { - std::vector> args {{"num_class", "3"}}; - std::unique_ptr obj { - ObjFunction::Create("multi:softmax", ctx) - }; - - obj->Configure(args); + std::vector> args{{"num_class", "3"}}; + std::unique_ptr obj{ObjFunction::Create("multi:softmax", ctx, args)}; CheckConfigReload(obj, "multi:softmax"); - CheckObjFunction(obj, - {1.0f, 0.0f, 2.0f, 2.0f, 0.0f, 1.0f}, // preds - {1.0f, 0.0f}, // labels - {1.0f, 1.0f}, // weights - {0.24f, -0.91f, 0.66f, -0.33f, 0.09f, 0.24f}, // grad - {0.36f, 0.16f, 0.44f, 0.45f, 0.16f, 0.37f}); // hess + CheckObjFunction(obj, {1.0f, 0.0f, 2.0f, 2.0f, 0.0f, 1.0f}, // preds + {1.0f, 0.0f}, // labels + {1.0f, 1.0f}, // weights + {0.24f, -0.91f, 0.66f, -0.33f, 0.09f, 0.24f}, // grad + {0.36f, 0.16f, 0.44f, 0.45f, 0.16f, 0.37f}); // hess - CheckObjFunction(obj, - {1.0f, 0.0f, 2.0f, 2.0f, 0.0f, 1.0f}, // preds - {1.0f, 0.0f}, // labels - {}, // weights - {0.24f, -0.91f, 0.66f, -0.33f, 0.09f, 0.24f}, // grad - {0.36f, 0.16f, 0.44f, 0.45f, 0.16f, 0.37f}); // hess + CheckObjFunction(obj, {1.0f, 0.0f, 2.0f, 2.0f, 0.0f, 1.0f}, // preds + {1.0f, 0.0f}, // labels + {}, // weights + {0.24f, -0.91f, 0.66f, -0.33f, 0.09f, 0.24f}, // grad + {0.36f, 0.16f, 0.44f, 0.45f, 0.16f, 0.37f}); // hess ASSERT_NO_THROW({ [[maybe_unused]] auto _ = obj->DefaultEvalMetric(); }); } @@ -38,12 +34,10 @@ void TestSoftmaxMultiClassBasic(const Context* ctx) { std::vector> args{ std::pair("num_class", "3")}; - std::unique_ptr obj{ObjFunction::Create("multi:softmax", ctx)}; - obj->Configure(args); + std::unique_ptr obj{ObjFunction::Create("multi:softmax", ctx, args)}; CheckConfigReload(obj, "multi:softmax"); - HostDeviceVector io_preds = {2.0f, 0.0f, 1.0f, - 1.0f, 0.0f, 2.0f}; + HostDeviceVector io_preds = {2.0f, 0.0f, 1.0f, 1.0f, 0.0f, 2.0f}; std::vector out_preds = {0.0f, 2.0f}; obj->PredTransform(&io_preds); @@ -55,16 +49,13 @@ void TestSoftmaxMultiClassBasic(const Context* ctx) { } void TestSoftprobMultiClassBasic(const Context* ctx) { - std::vector> args { - std::pair("num_class", "3")}; + std::vector> args{ + std::pair("num_class", "3")}; - std::unique_ptr obj { - ObjFunction::Create("multi:softprob", ctx) - }; - obj->Configure(args); + std::unique_ptr obj{ObjFunction::Create("multi:softprob", ctx, args)}; CheckConfigReload(obj, "multi:softprob"); - HostDeviceVector io_preds = {2.0f, 0.0f, 1.0f}; + HostDeviceVector io_preds = {2.0f, 0.0f, 1.0f}; std::vector out_preds = {0.66524096f, 0.09003057f, 0.24472847f}; obj->PredTransform(&io_preds); diff --git a/tests/cpp/objective/test_objective.cc b/tests/cpp/objective/test_objective.cc index 5df765789c3c..38a2a1f7b5df 100644 --- a/tests/cpp/objective/test_objective.cc +++ b/tests/cpp/objective/test_objective.cc @@ -29,16 +29,16 @@ TEST(Objective, PredTransform) { size_t n = 100; for (const auto& entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) { - std::unique_ptr obj{xgboost::ObjFunction::Create(entry->name, &tparam)}; + Args args; if (entry->name.find("multi") != std::string::npos) { - obj->Configure(Args{{"num_class", "2"}}); - } - if (entry->name.find("quantile") != std::string::npos) { - obj->Configure(Args{{"quantile_alpha", "0.5"}}); - } - if (entry->name.find("expectile") != std::string::npos) { - obj->Configure(Args{{"expectile_alpha", "0.5"}}); + args = Args{{"num_class", "2"}}; + } else if (entry->name.find("quantile") != std::string::npos) { + args = Args{{"quantile_alpha", "0.5"}}; + } else if (entry->name.find("expectile") != std::string::npos) { + args = Args{{"expectile_alpha", "0.5"}}; } + std::unique_ptr obj{ + xgboost::ObjFunction::Create(entry->name, &tparam, args)}; HostDeviceVector predts; predts.Resize(n, 3.14f); // prediction is performed on host. ASSERT_FALSE(predts.DeviceCanRead()); @@ -55,21 +55,20 @@ class TestDefaultObjConfig : public ::testing::TestWithParam { void Run(std::string objective) { auto Xy = MakeFmatForObjTest(objective, 10, 10, 3); std::unique_ptr learner{Learner::Create({Xy})}; - std::unique_ptr objfn{ObjFunction::Create(objective, &ctx_)}; + Args args; learner->SetParam("objective", objective); if (objective.find("multi") != std::string::npos) { learner->SetParam("num_class", "3"); - objfn->Configure(Args{{"num_class", "3"}}); + args = Args{{"num_class", "3"}}; } else if (objective.find("quantile") != std::string::npos) { learner->SetParam("quantile_alpha", "0.5"); - objfn->Configure(Args{{"quantile_alpha", "0.5"}}); + args = Args{{"quantile_alpha", "0.5"}}; } else if (objective.find("expectile") != std::string::npos) { learner->SetParam("expectile_alpha", "0.5"); - objfn->Configure(Args{{"expectile_alpha", "0.5"}}); - } else { - objfn->Configure(Args{}); + args = Args{{"expectile_alpha", "0.5"}}; } + std::unique_ptr objfn{ObjFunction::Create(objective, &ctx_, args)}; learner->Configure(); learner->UpdateOneIter(0, Xy); learner->EvalOneIter(0, {Xy}, {"train"}); diff --git a/tests/cpp/objective/test_quantile_obj.cc b/tests/cpp/objective/test_quantile_obj.cc index 21e488859455..7f6c709f552b 100644 --- a/tests/cpp/objective/test_quantile_obj.cc +++ b/tests/cpp/objective/test_quantile_obj.cc @@ -18,14 +18,12 @@ namespace xgboost { void TestQuantile(Context const* ctx) { { Args args{{"quantile_alpha", "[0.6, 0.8]"}}; - std::unique_ptr obj{ObjFunction::Create("reg:quantileerror", ctx)}; - obj->Configure(args); + std::unique_ptr obj{ObjFunction::Create("reg:quantileerror", ctx, args)}; CheckConfigReload(obj, "reg:quantileerror"); } Args args{{"quantile_alpha", "0.6"}}; - std::unique_ptr obj{ObjFunction::Create("reg:quantileerror", ctx)}; - obj->Configure(args); + std::unique_ptr obj{ObjFunction::Create("reg:quantileerror", ctx, args)}; CheckConfigReload(obj, "reg:quantileerror"); std::vector predts{1.0f, 2.0f, 3.0f}; @@ -38,8 +36,7 @@ void TestQuantile(Context const* ctx) { void TestQuantileIntercept(Context const* ctx) { Args args{{"quantile_alpha", "[0.6, 0.8]"}}; - std::unique_ptr obj{ObjFunction::Create("reg:quantileerror", ctx)}; - obj->Configure(args); + std::unique_ptr obj{ObjFunction::Create("reg:quantileerror", ctx, args)}; MetaInfo info; info.num_row_ = 10; diff --git a/tests/cpp/objective/test_regression_obj.cc b/tests/cpp/objective/test_regression_obj.cc index 517f71399c3c..4f8a37049120 100644 --- a/tests/cpp/objective/test_regression_obj.cc +++ b/tests/cpp/objective/test_regression_obj.cc @@ -38,9 +38,7 @@ void TestLinearRegressionGPair(const Context* ctx) { std::string obj_name = "reg:squarederror"; std::vector> args; - std::unique_ptr obj{ObjFunction::Create(obj_name, ctx)}; - - obj->Configure(args); + std::unique_ptr obj{ObjFunction::Create(obj_name, ctx, args)}; // clang-format off CheckObjFunction(obj, {0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, @@ -62,8 +60,7 @@ void TestSquaredLog(const Context* ctx) { std::string obj_name = "reg:squaredlogerror"; std::vector> args; - std::unique_ptr obj{ObjFunction::Create(obj_name, ctx)}; - obj->Configure(args); + std::unique_ptr obj{ObjFunction::Create(obj_name, ctx, args)}; CheckConfigReload(obj, obj_name); // clang-format off CheckObjFunction(obj, @@ -85,9 +82,7 @@ void TestSquaredLog(const Context* ctx) { void TestLogisticRegressionGPair(const Context* ctx) { std::string obj_name = "reg:logistic"; std::vector> args; - std::unique_ptr obj{ObjFunction::Create(obj_name, ctx)}; - - obj->Configure(args); + std::unique_ptr obj{ObjFunction::Create(obj_name, ctx, args)}; CheckConfigReload(obj, obj_name); // clang-format off CheckObjFunction(obj, @@ -102,9 +97,7 @@ void TestLogisticRegressionGPair(const Context* ctx) { void TestLogisticRegressionBasic(const Context* ctx) { std::string obj_name = "reg:logistic"; std::vector> args; - std::unique_ptr obj{ObjFunction::Create(obj_name, ctx)}; - - obj->Configure(args); + std::unique_ptr obj{ObjFunction::Create(obj_name, ctx, args)}; CheckConfigReload(obj, obj_name); // test label validation @@ -130,8 +123,7 @@ void TestLogisticRegressionBasic(const Context* ctx) { void TestsLogisticRawGPair(const Context* ctx) { std::string obj_name = "binary:logitraw"; std::vector> args; - std::unique_ptr obj{ObjFunction::Create(obj_name, ctx)}; - obj->Configure(args); + std::unique_ptr obj{ObjFunction::Create(obj_name, ctx, args)}; // clang-format off CheckObjFunction(obj, { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, @@ -144,10 +136,8 @@ void TestsLogisticRawGPair(const Context* ctx) { void TestPoissonRegressionGPair(const Context* ctx) { std::vector> args; - std::unique_ptr obj{ObjFunction::Create("count:poisson", ctx)}; - args.emplace_back("max_delta_step", "0.1f"); - obj->Configure(args); + std::unique_ptr obj{ObjFunction::Create("count:poisson", ctx, args)}; // clang-format off CheckObjFunction(obj, { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, @@ -166,9 +156,7 @@ void TestPoissonRegressionGPair(const Context* ctx) { void TestPoissonRegressionBasic(const Context* ctx) { std::vector> args; - std::unique_ptr obj{ObjFunction::Create("count:poisson", ctx)}; - - obj->Configure(args); + std::unique_ptr obj{ObjFunction::Create("count:poisson", ctx, args)}; CheckConfigReload(obj, "count:poisson"); // test label validation @@ -192,9 +180,7 @@ void TestPoissonRegressionBasic(const Context* ctx) { void TestGammaRegressionGPair(const Context* ctx) { std::vector> args; - std::unique_ptr obj{ObjFunction::Create("reg:gamma", ctx)}; - - obj->Configure(args); + std::unique_ptr obj{ObjFunction::Create("reg:gamma", ctx, args)}; // clang-format off CheckObjFunction(obj, {0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, @@ -213,9 +199,7 @@ void TestGammaRegressionGPair(const Context* ctx) { void TestGammaRegressionBasic(const Context* ctx) { std::vector> args; - std::unique_ptr obj{ObjFunction::Create("reg:gamma", ctx)}; - - obj->Configure(args); + std::unique_ptr obj{ObjFunction::Create("reg:gamma", ctx, args)}; CheckConfigReload(obj, "reg:gamma"); // test label validation @@ -241,10 +225,8 @@ void TestGammaRegressionBasic(const Context* ctx) { void TestTweedieRegressionGPair(const Context* ctx) { std::vector> args; - std::unique_ptr obj{ObjFunction::Create("reg:tweedie", ctx)}; - args.emplace_back("tweedie_variance_power", "1.1f"); - obj->Configure(args); + std::unique_ptr obj{ObjFunction::Create("reg:tweedie", ctx, args)}; // clang-format off CheckObjFunction(obj, { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, @@ -264,9 +246,7 @@ void TestTweedieRegressionGPair(const Context* ctx) { void TestTweedieRegressionBasic(const Context* ctx) { std::vector> args; - std::unique_ptr obj{ObjFunction::Create("reg:tweedie", ctx)}; - - obj->Configure(args); + std::unique_ptr obj{ObjFunction::Create("reg:tweedie", ctx, args)}; CheckConfigReload(obj, "reg:tweedie"); // test label validation @@ -290,9 +270,7 @@ void TestTweedieRegressionBasic(const Context* ctx) { void TestCoxRegressionGPair(const Context* ctx) { std::vector> args; - std::unique_ptr obj{ObjFunction::Create("survival:cox", ctx)}; - - obj->Configure(args); + std::unique_ptr obj{ObjFunction::Create("survival:cox", ctx, args)}; // clang-format off CheckObjFunction(obj, { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, @@ -304,8 +282,7 @@ void TestCoxRegressionGPair(const Context* ctx) { } void TestAbsoluteError(const Context* ctx) { - std::unique_ptr obj{ObjFunction::Create("reg:absoluteerror", ctx)}; - obj->Configure({}); + std::unique_ptr obj{ObjFunction::Create("reg:absoluteerror", ctx, Args{})}; CheckConfigReload(obj, "reg:absoluteerror"); MetaInfo info; @@ -344,8 +321,7 @@ void TestAbsoluteError(const Context* ctx) { void TestAbsoluteErrorLeaf(const Context* ctx) { bst_target_t constexpr kTargets = 3, kRows = 16; - std::unique_ptr obj{ObjFunction::Create("reg:absoluteerror", ctx)}; - obj->Configure({}); + std::unique_ptr obj{ObjFunction::Create("reg:absoluteerror", ctx, Args{})}; MetaInfo info; info.num_row_ = kRows; @@ -401,8 +377,7 @@ void TestAbsoluteErrorLeaf(const Context* ctx) { void TestVectorLeafObj(Context const* ctx, std::string name, Args const& args, bst_idx_t n_samples, bst_idx_t n_target_labels, std::vector const& sol_left, std::vector const& sol_right) { - std::unique_ptr obj{ObjFunction::Create(name, ctx)}; - obj->Configure(args); + std::unique_ptr obj{ObjFunction::Create(name, ctx, args)}; bst_target_t n_targets = 3; auto tree = MakeMtTreeForTest(n_targets); @@ -433,8 +408,7 @@ void TestVectorLeafObj(Context const* ctx, std::string name, Args const& args, b void TestExpectileRegressionGPair(const Context* ctx) { Args args{{"expectile_alpha", "0.8"}}; - std::unique_ptr obj{ObjFunction::Create("reg:expectileerror", ctx)}; - obj->Configure(args); + std::unique_ptr obj{ObjFunction::Create("reg:expectileerror", ctx, args)}; CheckConfigReload(obj, "reg:expectileerror"); std::vector predts{1.0f, 2.0f, 3.0f}; @@ -451,8 +425,7 @@ void TestExpectileRegressionGPair(const Context* ctx) { void TestExpectileRegressionMultiAlpha(const Context* ctx) { Args args{{"expectile_alpha", "[0.2, 0.8]"}}; - std::unique_ptr obj{ObjFunction::Create("reg:expectileerror", ctx)}; - obj->Configure(args); + std::unique_ptr obj{ObjFunction::Create("reg:expectileerror", ctx, args)}; CheckConfigReload(obj, "reg:expectileerror"); std::vector predts{0.0f, 0.0f, 0.0f, 0.0f}; @@ -464,8 +437,7 @@ void TestExpectileRegressionMultiAlpha(const Context* ctx) { void TestExpectileRegressionInitEstimation(const Context* ctx) { Args args{{"expectile_alpha", "[0.2, 0.8]"}}; - std::unique_ptr obj{ObjFunction::Create("reg:expectileerror", ctx)}; - obj->Configure(args); + std::unique_ptr obj{ObjFunction::Create("reg:expectileerror", ctx, args)}; MetaInfo info; info.num_row_ = 10; @@ -504,8 +476,7 @@ void TestExpectileRegressionInitEstimation(const Context* ctx) { void TestPseudoHuber(const Context* ctx) { Args args; - std::unique_ptr obj{ObjFunction::Create("reg:pseudohubererror", ctx)}; - obj->Configure(args); + std::unique_ptr obj{ObjFunction::Create("reg:pseudohubererror", ctx, args)}; CheckConfigReload(obj, "reg:pseudohubererror"); CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f}, // pred diff --git a/tests/cpp/objective/test_regression_obj_cpu.cc b/tests/cpp/objective/test_regression_obj_cpu.cc index d14c5e05f2ff..4658eb712a7b 100644 --- a/tests/cpp/objective/test_regression_obj_cpu.cc +++ b/tests/cpp/objective/test_regression_obj_cpu.cc @@ -84,7 +84,7 @@ TEST(Objective, DeclareUnifiedTest(TweedieRegressionGPair)) { TEST(Objective, CPU_vs_CUDA) { Context ctx = MakeCUDACtx(GPUIDX); - std::unique_ptr obj{ObjFunction::Create("reg:squarederror", &ctx)}; + std::unique_ptr obj{ObjFunction::Create("reg:squarederror", &ctx, Args{})}; linalg::Matrix cpu_out_preds; linalg::Matrix cuda_out_preds; diff --git a/tests/cpp/plugin/test_example_objective.cc b/tests/cpp/plugin/test_example_objective.cc index ccb83c781fc1..c32529a01cb3 100644 --- a/tests/cpp/plugin/test_example_objective.cc +++ b/tests/cpp/plugin/test_example_objective.cc @@ -1,12 +1,17 @@ +/** + * Copyright 2026, XGBoost contributors + */ #include #include + #include + #include "../helpers.h" namespace xgboost { TEST(Plugin, ExampleObjective) { xgboost::Context ctx = MakeCUDACtx(GPUIDX); - auto* obj = xgboost::ObjFunction::Create("mylogistic", &ctx); + auto* obj = xgboost::ObjFunction::Create("mylogistic", &ctx, Args{}); ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"logloss"}); delete obj; } diff --git a/tests/cpp/plugin/test_sycl_regression_obj.cc b/tests/cpp/plugin/test_sycl_regression_obj.cc index d80fc0fb03f9..4e76968c453a 100644 --- a/tests/cpp/plugin/test_sycl_regression_obj.cc +++ b/tests/cpp/plugin/test_sycl_regression_obj.cc @@ -108,13 +108,11 @@ TEST(SyclObjective, DeclareUnifiedTest(PseudoHuber)) { TEST(SyclObjective, CPUvsSycl) { Context ctx_sycl; ctx_sycl.UpdateAllowUnknown(Args{{"device", "sycl"}}); - ObjFunction * obj_sycl = - ObjFunction::Create("reg:squarederror", &ctx_sycl); + ObjFunction* obj_sycl = ObjFunction::Create("reg:squarederror", &ctx_sycl, Args{}); Context ctx_cpu; ctx_cpu.UpdateAllowUnknown(Args{{"device", "cpu"}}); - ObjFunction * obj_cpu = - ObjFunction::Create("reg:squarederror", &ctx_cpu); + ObjFunction* obj_cpu = ObjFunction::Create("reg:squarederror", &ctx_cpu, Args{}); linalg::Matrix cpu_out_preds; linalg::Matrix sycl_out_preds; @@ -133,7 +131,7 @@ TEST(SyclObjective, CPUvsSycl) { info.labels.Reshape(kRows, 1); auto& h_labels = info.labels.Data()->HostVector(); for (size_t i = 0; i < h_labels.size(); ++i) { - h_labels[i] = 1 / static_cast(i+1); + h_labels[i] = 1 / static_cast(i + 1); } { diff --git a/tests/cpp/predictor/test_shap.cc b/tests/cpp/predictor/test_shap.cc index b86cb498733f..34e0255f7fa2 100644 --- a/tests/cpp/predictor/test_shap.cc +++ b/tests/cpp/predictor/test_shap.cc @@ -87,8 +87,7 @@ std::unique_ptr LoadGBTreeModel(Learner* learner, Context cons break; } } - auto obj = std::unique_ptr(ObjFunction::Create(objective, ctx)); - obj->Configure(model_args); + auto obj = std::unique_ptr(ObjFunction::Create(objective, ctx, model_args)); obj->ProbToMargin(&base_score_vec); // Keep both host/device views readable, matching LearnerModelParam invariants. std::as_const(base_score_vec).HostView(); From 1dc5a4e1b421f1fcae4ed3b9d5528d96ee662250 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 17 Mar 2026 04:29:01 -0700 Subject: [PATCH 3/7] Construct objectives directly from config --- include/xgboost/objective.h | 1 + src/learner.cc | 6 ++---- src/objective/init_estimation.cc | 6 ++---- src/objective/objective.cc | 9 +++++++++ tests/cpp/objective/test_objective.cc | 14 ++++++++++++++ 5 files changed, 28 insertions(+), 8 deletions(-) diff --git a/include/xgboost/objective.h b/include/xgboost/objective.h index 4a29ff75faa3..0e82fa729548 100644 --- a/include/xgboost/objective.h +++ b/include/xgboost/objective.h @@ -137,6 +137,7 @@ class ObjFunction : public Configurable { * @param ctx Pointer to the context. */ static ObjFunction* Create(const std::string& name, Context const* ctx, Args const& args = {}); + static ObjFunction* Create(Context const* ctx, Json const& config); }; /*! diff --git a/src/learner.cc b/src/learner.cc index e1dc75e06aeb..211ed5c65046 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -587,9 +587,8 @@ class LearnerConfiguration : public Intercept { auto const& objective_fn = learner_parameters.at("objective"); if (!obj_) { CHECK_EQ(get(objective_fn["name"]), tparam_.objective); - obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_)); } - obj_->LoadConfig(objective_fn); + obj_.reset(ObjFunction::Create(&ctx_, objective_fn)); learner_model_param_.task = obj_->Task(); tparam_.booster = CanonicalizeBoosterName(get(gradient_booster["name"])); @@ -911,8 +910,7 @@ class LearnerIO : public LearnerConfiguration { std::string name = get(objective_fn["name"]); tparam_.UpdateAllowUnknown(Args{{"objective", name}}); - obj_.reset(ObjFunction::Create(name, &ctx_)); - obj_->LoadConfig(objective_fn); + obj_.reset(ObjFunction::Create(&ctx_, objective_fn)); auto const& gradient_booster = learner.at("gradient_booster"); name = get(gradient_booster["name"]); diff --git a/src/objective/init_estimation.cc b/src/objective/init_estimation.cc index f94d2f8ba286..55617fe43b8a 100644 --- a/src/objective/init_estimation.cc +++ b/src/objective/init_estimation.cc @@ -3,7 +3,7 @@ */ #include "init_estimation.h" -#include // unique_ptr +#include // unique_ptr #include "../common/stats.h" // Mean #include "../tree/fit_stump.h" // FitStump @@ -26,9 +26,7 @@ void FitIntercept::InitEstimation(MetaInfo const& info, linalg::Vector* b Json config{Object{}}; this->SaveConfig(&config); - std::unique_ptr new_obj{ - ObjFunction::Create(get(config["name"]), this->ctx_)}; - new_obj->LoadConfig(config); + std::unique_ptr new_obj{ObjFunction::Create(this->ctx_, config)}; new_obj->GetGradient(dummy_predt, info, 0, &gpair); bst_target_t n_targets = this->Targets(info); diff --git a/src/objective/objective.cc b/src/objective/objective.cc index e5ed37e4ac39..04113662c479 100644 --- a/src/objective/objective.cc +++ b/src/objective/objective.cc @@ -7,6 +7,7 @@ #include #include +#include #include // for stringstream #include // for string @@ -31,6 +32,14 @@ ObjFunction* ObjFunction::Create(const std::string& name, Context const* ctx, Ar return pobj; } +ObjFunction* ObjFunction::Create(Context const* ctx, Json const& config) { + auto const& obj = get(config); + auto objective = + std::unique_ptr{ObjFunction::Create(get(obj.at("name")), ctx)}; + objective->LoadConfig(config); + return objective.release(); +} + void ObjFunction::InitEstimation(MetaInfo const& info, linalg::Vector* base_score) const { CHECK(base_score); auto n_targets = this->Targets(info); diff --git a/tests/cpp/objective/test_objective.cc b/tests/cpp/objective/test_objective.cc index 38a2a1f7b5df..f01d80478bf8 100644 --- a/tests/cpp/objective/test_objective.cc +++ b/tests/cpp/objective/test_objective.cc @@ -3,6 +3,7 @@ */ #include #include +#include #include #include "../helpers.h" @@ -21,6 +22,19 @@ TEST(Objective, UnknownFunction) { } } +TEST(Objective, LoadConfigFactory) { + xgboost::Context ctx; + std::unique_ptr obj{ + xgboost::ObjFunction::Create("reg:quantileerror", &ctx, {{"quantile_alpha", "0.8"}})}; + xgboost::Json config{xgboost::Object{}}; + obj->SaveConfig(&config); + + std::unique_ptr loaded{xgboost::ObjFunction::Create(&ctx, config)}; + xgboost::Json loaded_config{xgboost::Object{}}; + loaded->SaveConfig(&loaded_config); + ASSERT_EQ(config, loaded_config); +} + namespace xgboost { TEST(Objective, PredTransform) { // Test that show PredTransform uses the same device with predictor. From c8c0a5ef4fa91a6a6ae5c65888581160caa8f5bd Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 17 Mar 2026 04:36:27 -0700 Subject: [PATCH 4/7] Recreate objectives during learner configure --- src/learner.cc | 11 +++++++---- tests/cpp/test_learner.cc | 30 ++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/src/learner.cc b/src/learner.cc index 211ed5c65046..cc33d77ce83e 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -831,6 +831,7 @@ class LearnerConfiguration : public Intercept { } void ConfigureObjective(LearnerTrainParam const& old, Args* p_args) { + auto preserve_loaded_objective = cfg_.empty() && obj_ != nullptr; // Once binary IO is gone, NONE of these config is useful. if (cfg_.find("num_class") != cfg_.cend() && cfg_.at("num_class") != "0" && tparam_.objective != "multi:softprob") { @@ -852,11 +853,13 @@ class LearnerConfiguration : public Intercept { cfg_["num_class"] = std::to_string(mparam_.num_class); auto& args = *p_args; args = {cfg_.cbegin(), cfg_.cend()}; // renew - if (obj_ == nullptr || tparam_.objective != old.objective) { - obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_, args)); - } else { - obj_->Configure(args); + if (preserve_loaded_objective && tparam_.objective == old.objective) { + if (!has_nc) { + cfg_.erase("num_class"); + } + return; } + obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_, args)); if (!has_nc) { cfg_.erase("num_class"); } diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index e493532f3633..3b6372bc74f1 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -245,6 +245,36 @@ TEST(Learner, ConfigIO) { ASSERT_EQ(eval_res_0, eval_res_1); } +TEST(Learner, RecreateConfigurableObjective) { + auto m = MakeFmatForObjTest("reg:quantileerror", 16, 4, 1); + std::unique_ptr learner{Learner::Create({m})}; + learner->SetParams({{"objective", "reg:quantileerror"}, {"quantile_alpha", "0.8"}}); + learner->Configure(); + + Json config{Object{}}; + learner->SaveConfig(&config); + auto obj_config_0 = config["learner"]["objective"]; + + learner->SetParam("quantile_alpha", "0.2"); + learner->Configure(); + learner->SaveConfig(&config); + auto obj_config_1 = config["learner"]["objective"]; + ASSERT_FALSE(obj_config_0 == obj_config_1); + + std::unique_ptr expected{ + ObjFunction::Create("reg:quantileerror", learner->Ctx(), {{"quantile_alpha", "0.2"}})}; + Json expected_config{Object{}}; + expected->SaveConfig(&expected_config); + ASSERT_EQ(obj_config_1, expected_config); + + std::unique_ptr loaded{Learner::Create({m})}; + loaded->LoadConfig(config); + loaded->Configure(); + Json loaded_config{Object{}}; + loaded->SaveConfig(&loaded_config); + ASSERT_EQ(loaded_config["learner"]["objective"], obj_config_1); +} + // Crashes the test runner if there are race condiditions. // // Build with additional cmake flags to enable thread sanitizer From 80a604ac2764c9bf46593d45c8a2745b5b1ad91e Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 17 Mar 2026 04:42:47 -0700 Subject: [PATCH 5/7] Fix objective constructor initialization --- include/xgboost/objective.h | 6 +++--- plugin/example/custom_obj.cc | 5 ++++- src/objective/quantile_obj.cu | 14 ++++++++------ src/objective/regression_obj.cu | 34 ++++++++++++++++----------------- 4 files changed, 32 insertions(+), 27 deletions(-) diff --git a/include/xgboost/objective.h b/include/xgboost/objective.h index 0e82fa729548..639e90daaefa 100644 --- a/include/xgboost/objective.h +++ b/include/xgboost/objective.h @@ -152,10 +152,10 @@ struct ObjFunctionReg * * \code * // example of registering a objective - * XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:squarederror") - * .describe("Linear regression objective") + * XGBOOST_REGISTER_OBJECTIVE(MyObjective, "my:objective") + * .describe("A custom objective") * .set_body([](Args const&) { - * return new RegLossObj(LossType::kLinearSquare); + * return new MyObjective(); * }); * \endcode */ diff --git a/plugin/example/custom_obj.cc b/plugin/example/custom_obj.cc index 7b278a9d00cf..955096f41c5c 100644 --- a/plugin/example/custom_obj.cc +++ b/plugin/example/custom_obj.cc @@ -32,6 +32,9 @@ DMLC_REGISTER_PARAMETER(MyLogisticParam); // Implement the interface. class MyLogistic : public ObjFunction { public: + explicit MyLogistic(Args const& args) { param_.UpdateAllowUnknown(args); } + MyLogistic() = default; + void Configure(const Args& args) override { param_.UpdateAllowUnknown(args); } [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } @@ -87,6 +90,6 @@ class MyLogistic : public ObjFunction { // After it succeeds you can try use xgboost with objective=mylogistic XGBOOST_REGISTER_OBJECTIVE(MyLogistic, "mylogistic") .describe("User defined logistic regression plugin") - .set_body([](Args const&) { return new MyLogistic(); }); + .set_body([](Args const& args) { return new MyLogistic{args}; }); } // namespace xgboost::obj diff --git a/src/objective/quantile_obj.cu b/src/objective/quantile_obj.cu index ed717c09a521..7dc94a1a07db 100644 --- a/src/objective/quantile_obj.cu +++ b/src/objective/quantile_obj.cu @@ -29,6 +29,12 @@ class QuantileRegression : public ObjFunction { common::QuantileLossParam param_; HostDeviceVector alpha_; + void UpdateConfig(Args const& args) { + param_.UpdateAllowUnknown(args); + param_.Validate(); + alpha_.HostVector() = param_.quantile_alpha.Get(); + } + [[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override { auto const& alpha = param_.quantile_alpha.Get(); CHECK_EQ(alpha.size(), alpha_.Size()) << "The objective is not yet configured."; @@ -47,7 +53,7 @@ class QuantileRegression : public ObjFunction { public: explicit QuantileRegression(Args const& args) { if (!args.empty()) { - this->Configure(args); + this->UpdateConfig(args); } } QuantileRegression() = default; @@ -183,11 +189,7 @@ class QuantileRegression : public ObjFunction { } } - void Configure(Args const& args) override { - param_.UpdateAllowUnknown(args); - param_.Validate(); - this->alpha_.HostVector() = param_.quantile_alpha.Get(); - } + void Configure(Args const& args) override { this->UpdateConfig(args); } [[nodiscard]] ObjInfo Task() const override { return {ObjInfo::kRegression, true, true}; } static char const* Name() { return "reg:quantileerror"; } diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index b6ba8e508801..eea565151e50 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -400,6 +400,12 @@ class ExpectileRegression : public FitIntercept { common::ExpectileLossParam param_; HostDeviceVector alpha_; + void UpdateConfig(Args const& args) { + param_.UpdateAllowUnknown(args); + param_.Validate(); + alpha_.HostVector() = param_.expectile_alpha.Get(); + } + [[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override { auto const& alpha = param_.expectile_alpha.Get(); CHECK_EQ(alpha.size(), alpha_.Size()) << "The objective is not yet configured."; @@ -415,16 +421,12 @@ class ExpectileRegression : public FitIntercept { public: explicit ExpectileRegression(Args const& args) { if (!args.empty()) { - this->Configure(args); + this->UpdateConfig(args); } } ExpectileRegression() = default; - void Configure(Args const& args) override { - param_.UpdateAllowUnknown(args); - param_.Validate(); - alpha_.HostVector() = param_.expectile_alpha.Get(); - } + void Configure(Args const& args) override { this->UpdateConfig(args); } [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } @@ -749,22 +751,20 @@ struct TweedieRegressionParam : public XGBoostParameter // tweedie regression class TweedieRegression : public FitInterceptGlmLike { - public: - explicit TweedieRegression(Args const& args) { - if (!args.empty()) { - this->Configure(args); - } - } - TweedieRegression() = default; - - // declare functions - void Configure(Args const& args) override { + void UpdateConfig(Args const& args) { param_.UpdateAllowUnknown(args); std::ostringstream os; os << "tweedie-nloglik@" << param_.tweedie_variance_power; metric_ = os.str(); } + public: + explicit TweedieRegression(Args const& args) { this->UpdateConfig(args); } + TweedieRegression() = default; + + // declare functions + void Configure(Args const& args) override { this->UpdateConfig(args); } + [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } [[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override { @@ -930,6 +930,6 @@ class MeanAbsoluteError : public ObjFunction { }; XGBOOST_REGISTER_OBJECTIVE(MeanAbsoluteError, "reg:absoluteerror") - .describe("Mean absoluate error.") + .describe("Mean absolute error.") .set_body([](Args const&) { return new MeanAbsoluteError(); }); } // namespace xgboost::obj From 5962f202be1ecdb04dccc6d85642d00da9a43887 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 17 Mar 2026 05:03:26 -0700 Subject: [PATCH 6/7] Construct objectives directly from JSON --- include/xgboost/objective.h | 28 +++-- plugin/example/custom_obj.cc | 8 +- src/objective/aft_obj.cu | 7 +- src/objective/hinge.cu | 7 +- src/objective/lambdarank_obj.cc | 34 +++--- src/objective/multiclass_obj.cu | 13 +- src/objective/objective.cc | 26 ++-- src/objective/quantile_obj.cu | 15 ++- src/objective/regression_obj.cu | 133 ++++++++++----------- tests/cpp/helpers.cc | 21 ++++ tests/cpp/helpers.h | 68 +++++------ tests/cpp/objective/test_aft_obj.cc | 38 +++--- tests/cpp/objective/test_lambdarank_obj.cc | 62 ++++------ tests/cpp/objective/test_lambdarank_obj.h | 24 ++-- tests/cpp/objective/test_regression_obj.cc | 2 +- 15 files changed, 243 insertions(+), 243 deletions(-) diff --git a/include/xgboost/objective.h b/include/xgboost/objective.h index 639e90daaefa..ac729f952561 100644 --- a/include/xgboost/objective.h +++ b/include/xgboost/objective.h @@ -25,7 +25,7 @@ class RegTree; struct Context; /** @brief The interface of objective function */ -class ObjFunction : public Configurable { +class ObjFunction { protected: Context const* ctx_{nullptr}; @@ -33,13 +33,7 @@ class ObjFunction : public Configurable { static constexpr float DefaultBaseScore() { return 0.5f; } public: - ~ObjFunction() override = default; - /** - * @brief Configure the objective with the specified parameters. - * - * @param args arguments to the objective function. - */ - virtual void Configure(Args const& args) = 0; + virtual ~ObjFunction() = default; /** * @brief Get gradient over each of predictions, given existing information. * @@ -130,6 +124,11 @@ class ObjFunction : public Configurable { MetaInfo const& /*info*/, float /*learning_rate*/, HostDeviceVector const& /*prediction*/, bst_target_t /*group_idx*/, RegTree* /*p_tree*/) const {} + /** + * @brief Save configuration to JSON object. + * @param out pointer to output JSON object. + */ + virtual void SaveConfig(Json* out) const = 0; /** * @brief Create an objective function according to the name. * @@ -144,7 +143,13 @@ class ObjFunction : public Configurable { * \brief Registry entry for objective factory functions. */ struct ObjFunctionReg - : public dmlc::FunctionRegEntryBase > { + : public dmlc::FunctionRegEntryBase> { + std::function json_body; + + inline ObjFunctionReg& set_body_json(std::function body) { + json_body = std::move(body); + return *this; + } }; /*! @@ -156,12 +161,15 @@ struct ObjFunctionReg * .describe("A custom objective") * .set_body([](Args const&) { * return new MyObjective(); + * }) + * .set_body_json([](Json const& config) { + * return new MyObjective(config); * }); * \endcode */ #define XGBOOST_REGISTER_OBJECTIVE(UniqueId, Name) \ static DMLC_ATTRIBUTE_UNUSED ::xgboost::ObjFunctionReg& \ __make_##ObjFunctionReg##_##UniqueId##__ = \ - ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->__REGISTER__(Name) + ::dmlc::Registry<::xgboost::ObjFunctionReg>::Get()->__REGISTER__(Name) } // namespace xgboost #endif // XGBOOST_OBJECTIVE_H_ diff --git a/plugin/example/custom_obj.cc b/plugin/example/custom_obj.cc index 955096f41c5c..5a08b55dda98 100644 --- a/plugin/example/custom_obj.cc +++ b/plugin/example/custom_obj.cc @@ -33,10 +33,9 @@ DMLC_REGISTER_PARAMETER(MyLogisticParam); class MyLogistic : public ObjFunction { public: explicit MyLogistic(Args const& args) { param_.UpdateAllowUnknown(args); } + explicit MyLogistic(Json const& in) { FromJson(in["my_logistic_param"], ¶m_); } MyLogistic() = default; - void Configure(const Args& args) override { param_.UpdateAllowUnknown(args); } - [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } void GetGradient(const HostDeviceVector& preds, MetaInfo const& info, @@ -80,8 +79,6 @@ class MyLogistic : public ObjFunction { out["my_logistic_param"] = ToJson(param_); } - void LoadConfig(Json const& in) override { FromJson(in["my_logistic_param"], ¶m_); } - private: MyLogisticParam param_; }; @@ -90,6 +87,7 @@ class MyLogistic : public ObjFunction { // After it succeeds you can try use xgboost with objective=mylogistic XGBOOST_REGISTER_OBJECTIVE(MyLogistic, "mylogistic") .describe("User defined logistic regression plugin") - .set_body([](Args const& args) { return new MyLogistic{args}; }); + .set_body([](Args const& args) { return new MyLogistic{args}; }) + .set_body_json([](Json const& config) { return new MyLogistic{config}; }); } // namespace xgboost::obj diff --git a/src/objective/aft_obj.cu b/src/objective/aft_obj.cu index f09719bbae18..849a5ff54722 100644 --- a/src/objective/aft_obj.cu +++ b/src/objective/aft_obj.cu @@ -32,10 +32,9 @@ DMLC_REGISTRY_FILE_TAG(aft_obj_gpu); class AFTObj : public ObjFunction { public: explicit AFTObj(Args const& args) { param_.UpdateAllowUnknown(args); } + explicit AFTObj(Json const& in) { FromJson(in["aft_loss_param"], ¶m_); } AFTObj() = default; - void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } - ObjInfo Task() const override { return ObjInfo::kSurvival; } template @@ -126,7 +125,6 @@ class AFTObj : public ObjFunction { out["aft_loss_param"] = ToJson(param_); } - void LoadConfig(Json const& in) override { FromJson(in["aft_loss_param"], ¶m_); } Json DefaultMetricConfig() const override { Json config{Object{}}; config["name"] = String{this->DefaultEvalMetric()}; @@ -141,7 +139,8 @@ class AFTObj : public ObjFunction { // register the objective functions XGBOOST_REGISTER_OBJECTIVE(AFTObj, "survival:aft") .describe("AFT loss function") - .set_body([](Args const& args) { return new AFTObj{args}; }); + .set_body([](Args const& args) { return new AFTObj{args}; }) + .set_body_json([](Json const& config) { return new AFTObj{config}; }); } // namespace obj } // namespace xgboost diff --git a/src/objective/hinge.cu b/src/objective/hinge.cu index fee7c4d2525f..b54483431397 100644 --- a/src/objective/hinge.cu +++ b/src/objective/hinge.cu @@ -27,8 +27,7 @@ DMLC_REGISTRY_FILE_TAG(hinge_obj_gpu); class HingeObj : public FitIntercept { public: HingeObj() = default; - - void Configure(Args const &) override {} + explicit HingeObj(Json const &) {} ObjInfo Task() const override { return ObjInfo::kRegression; } [[nodiscard]] bst_target_t Targets(MetaInfo const &info) const override { @@ -93,12 +92,12 @@ class HingeObj : public FitIntercept { auto &out = *p_out; out["name"] = String("binary:hinge"); } - void LoadConfig(Json const &) override {} }; // register the objective functions XGBOOST_REGISTER_OBJECTIVE(HingeObj, "binary:hinge") .describe("Hinge loss. Expects labels to be in [0,1f]") - .set_body([](Args const &) { return new HingeObj(); }); + .set_body([](Args const &) { return new HingeObj(); }) + .set_body_json([](Json const &config) { return new HingeObj(config); }); } // namespace xgboost::obj diff --git a/src/objective/lambdarank_obj.cc b/src/objective/lambdarank_obj.cc index eb6256f27242..d0f8a7e818a7 100644 --- a/src/objective/lambdarank_obj.cc +++ b/src/objective/lambdarank_obj.cc @@ -251,9 +251,18 @@ class LambdaRankObj : public FitIntercept { public: explicit LambdaRankObj(Args const& args) { param_.UpdateAllowUnknown(args); } - LambdaRankObj() = default; + explicit LambdaRankObj(Json const& in) { + auto const& obj = get(in); + if (obj.find("lambdarank_param") != obj.cend()) { + FromJson(in["lambdarank_param"], ¶m_); + } - void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } + if (param_.lambdarank_unbiased) { + linalg::LoadVector(in["ti+"], &ti_plus_); + linalg::LoadVector(in["tj-"], &tj_minus_); + } + } + LambdaRankObj() = default; void SaveConfig(Json* p_out) const override { auto& out = *p_out; out["name"] = String(Loss::Name()); @@ -266,18 +275,6 @@ class LambdaRankObj : public FitIntercept { linalg::SaveVector(tj_minus_, &out["tj-"]); } } - void LoadConfig(Json const& in) override { - auto const& obj = get(in); - if (obj.find("lambdarank_param") != obj.cend()) { - FromJson(in["lambdarank_param"], ¶m_); - } - - if (param_.lambdarank_unbiased) { - linalg::LoadVector(in["ti+"], &ti_plus_); - linalg::LoadVector(in["tj-"], &tj_minus_); - } - } - [[nodiscard]] ObjInfo Task() const override { return ObjInfo{ObjInfo::kRanking}; } [[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override { @@ -666,15 +663,18 @@ void LambdaRankGetGradientPairwise(Context const*, std::int32_t, HostDeviceVecto XGBOOST_REGISTER_OBJECTIVE(LambdaRankNDCG, LambdaRankNDCG::Name()) .describe("LambdaRank with NDCG loss as objective") - .set_body([](Args const& args) { return new LambdaRankNDCG{args}; }); + .set_body([](Args const& args) { return new LambdaRankNDCG{args}; }) + .set_body_json([](Json const& config) { return new LambdaRankNDCG{config}; }); XGBOOST_REGISTER_OBJECTIVE(LambdaRankPairwise, LambdaRankPairwise::Name()) .describe("LambdaRank with RankNet loss as objective") - .set_body([](Args const& args) { return new LambdaRankPairwise{args}; }); + .set_body([](Args const& args) { return new LambdaRankPairwise{args}; }) + .set_body_json([](Json const& config) { return new LambdaRankPairwise{config}; }); XGBOOST_REGISTER_OBJECTIVE(LambdaRankMAP, LambdaRankMAP::Name()) .describe("LambdaRank with MAP loss as objective.") - .set_body([](Args const& args) { return new LambdaRankMAP{args}; }); + .set_body([](Args const& args) { return new LambdaRankMAP{args}; }) + .set_body_json([](Json const& config) { return new LambdaRankMAP{config}; }); DMLC_REGISTRY_FILE_TAG(lambdarank_obj); } // namespace xgboost::obj diff --git a/src/objective/multiclass_obj.cu b/src/objective/multiclass_obj.cu index d181f67e94bf..ff58c8767e08 100644 --- a/src/objective/multiclass_obj.cu +++ b/src/objective/multiclass_obj.cu @@ -74,8 +74,9 @@ class SoftmaxMultiClassObj : public ObjFunction { param_.UpdateAllowUnknown(args); } } - - void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } + SoftmaxMultiClassObj(bool output_prob, Json const& in) : output_prob_(output_prob) { + FromJson(in["softmax_multiclass_param"], ¶m_); + } ObjInfo Task() const override { return ObjInfo::kClassification; } @@ -194,8 +195,6 @@ class SoftmaxMultiClassObj : public ObjFunction { out["softmax_multiclass_param"] = ToJson(param_); } - void LoadConfig(Json const& in) override { FromJson(in["softmax_multiclass_param"], ¶m_); } - void InitEstimation(MetaInfo const& info, linalg::Vector* base_score) const override { std::int64_t n_classes = this->param_.num_class; ValidateLabel(this->ctx_, info, n_classes); @@ -238,9 +237,11 @@ DMLC_REGISTER_PARAMETER(SoftmaxMultiClassParam); XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClass, "multi:softmax") .describe("Softmax for multi-class classification, output class index.") - .set_body([](Args const& args) { return new SoftmaxMultiClassObj(false, args); }); + .set_body([](Args const& args) { return new SoftmaxMultiClassObj(false, args); }) + .set_body_json([](Json const& config) { return new SoftmaxMultiClassObj(false, config); }); XGBOOST_REGISTER_OBJECTIVE(SoftprobMultiClass, "multi:softprob") .describe("Softmax for multi-class classification, output probability distribution.") - .set_body([](Args const& args) { return new SoftmaxMultiClassObj(true, args); }); + .set_body([](Args const& args) { return new SoftmaxMultiClassObj(true, args); }) + .set_body_json([](Json const& config) { return new SoftmaxMultiClassObj(true, config); }); } // namespace xgboost::obj diff --git a/src/objective/objective.cc b/src/objective/objective.cc index 04113662c479..887680aac365 100644 --- a/src/objective/objective.cc +++ b/src/objective/objective.cc @@ -16,17 +16,23 @@ DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg); } // namespace dmlc namespace xgboost { -// implement factory functions -ObjFunction* ObjFunction::Create(const std::string& name, Context const* ctx, Args const& args) { - std::string obj_name = name; - auto* e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(obj_name); +namespace { +ObjFunctionReg const* GetObjRegistryEntry(std::string const& name) { + auto* e = ::dmlc::Registry<::xgboost::ObjFunctionReg>::Get()->Find(name); if (e == nullptr) { std::stringstream ss; - for (const auto& entry : ::dmlc::Registry< ::xgboost::ObjFunctionReg>::List()) { + for (const auto& entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) { ss << "Objective candidate: " << entry->name << "\n"; } LOG(FATAL) << "Unknown objective function: `" << name << "`\n" << ss.str(); } + return e; +} +} // namespace + +// implement factory functions +ObjFunction* ObjFunction::Create(const std::string& name, Context const* ctx, Args const& args) { + auto* e = GetObjRegistryEntry(name); auto pobj = (e->body)(args); pobj->ctx_ = ctx; return pobj; @@ -34,10 +40,12 @@ ObjFunction* ObjFunction::Create(const std::string& name, Context const* ctx, Ar ObjFunction* ObjFunction::Create(Context const* ctx, Json const& config) { auto const& obj = get(config); - auto objective = - std::unique_ptr{ObjFunction::Create(get(obj.at("name")), ctx)}; - objective->LoadConfig(config); - return objective.release(); + auto name = get(obj.at("name")); + auto* e = GetObjRegistryEntry(name); + CHECK(e->json_body) << "JSON factory is not defined for objective `" << name << "`."; + auto pobj = (e->json_body)(config); + pobj->ctx_ = ctx; + return pobj; } void ObjFunction::InitEstimation(MetaInfo const& info, linalg::Vector* base_score) const { diff --git a/src/objective/quantile_obj.cu b/src/objective/quantile_obj.cu index 7dc94a1a07db..ff87f1e36be6 100644 --- a/src/objective/quantile_obj.cu +++ b/src/objective/quantile_obj.cu @@ -56,6 +56,11 @@ class QuantileRegression : public ObjFunction { this->UpdateConfig(args); } } + explicit QuantileRegression(Json const& in) { + CHECK_EQ(get(in["name"]), Name()); + FromJson(in["quantile_loss_param"], ¶m_); + alpha_.HostVector() = param_.quantile_alpha.Get(); + } QuantileRegression() = default; void GetGradient(HostDeviceVector const& preds, const MetaInfo& info, std::int32_t iter, @@ -189,7 +194,6 @@ class QuantileRegression : public ObjFunction { } } - void Configure(Args const& args) override { this->UpdateConfig(args); } [[nodiscard]] ObjInfo Task() const override { return {ObjInfo::kRegression, true, true}; } static char const* Name() { return "reg:quantileerror"; } @@ -198,12 +202,6 @@ class QuantileRegression : public ObjFunction { out["name"] = String(Name()); out["quantile_loss_param"] = ToJson(param_); } - void LoadConfig(Json const& in) override { - CHECK_EQ(get(in["name"]), Name()); - FromJson(in["quantile_loss_param"], ¶m_); - alpha_.HostVector() = param_.quantile_alpha.Get(); - } - [[nodiscard]] const char* DefaultEvalMetric() const override { return "quantile"; } [[nodiscard]] Json DefaultMetricConfig() const override { CHECK(param_.GetInitialised()); @@ -216,7 +214,8 @@ class QuantileRegression : public ObjFunction { XGBOOST_REGISTER_OBJECTIVE(QuantileRegression, QuantileRegression::Name()) .describe("Regression with quantile loss.") - .set_body([](Args const& args) { return new QuantileRegression{args}; }); + .set_body([](Args const& args) { return new QuantileRegression{args}; }) + .set_body_json([](Json const& config) { return new QuantileRegression{config}; }); #if defined(XGBOOST_USE_CUDA) DMLC_REGISTRY_FILE_TAG(quantile_obj_gpu); diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index eea565151e50..6427bbba24e8 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -132,8 +132,13 @@ class RegLossObj : public FitInterceptGlmLike { // 0 - scale_pos_weight, 1 - is_null_weight RegLossObj() : additional_input_(2) {} explicit RegLossObj(Args const& args) : additional_input_(2) { param_.UpdateAllowUnknown(args); } - - void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } + explicit RegLossObj(Json const& in) : additional_input_(2) { + auto obj = get(in); + auto it = obj.find("reg_loss_param"); + if (it != obj.cend()) { + FromJson(it->second, ¶m_); + } + } [[nodiscard]] ObjInfo Task() const override { return Loss::Info(); } @@ -231,14 +236,6 @@ class RegLossObj : public FitInterceptGlmLike { out["reg_loss_param"] = ToJson(param_); } - void LoadConfig(Json const& in) override { - auto obj = get(in); - auto it = obj.find("reg_loss_param"); - if (it != obj.cend()) { - FromJson(it->second, ¶m_); - } - } - protected: RegLossParam param_; }; @@ -248,25 +245,32 @@ DMLC_REGISTER_PARAMETER(RegLossParam); XGBOOST_REGISTER_OBJECTIVE(SquaredLossRegression, LinearSquareLoss::Name()) .describe("Regression with squared error.") - .set_body([](Args const& args) { return new RegLossObj{args}; }); + .set_body([](Args const& args) { return new RegLossObj{args}; }) + .set_body_json([](Json const& config) { return new RegLossObj{config}; }); XGBOOST_REGISTER_OBJECTIVE(LogisticRegression, LogisticRegression::Name()) .describe("Logistic regression for probability regression task.") - .set_body([](Args const& args) { return new RegLossObj{args}; }); + .set_body([](Args const& args) { return new RegLossObj{args}; }) + .set_body_json([](Json const& config) { return new RegLossObj{config}; }); XGBOOST_REGISTER_OBJECTIVE(LogisticClassification, LogisticClassification::Name()) .describe("Logistic regression for binary classification task.") - .set_body([](Args const& args) { return new RegLossObj{args}; }); + .set_body([](Args const& args) { return new RegLossObj{args}; }) + .set_body_json([](Json const& config) { + return new RegLossObj{config}; + }); XGBOOST_REGISTER_OBJECTIVE(LogisticRaw, LogisticRaw::Name()) .describe( "Logistic regression for classification, output score " "before logistic transformation.") - .set_body([](Args const& args) { return new RegLossObj{args}; }); + .set_body([](Args const& args) { return new RegLossObj{args}; }) + .set_body_json([](Json const& config) { return new RegLossObj{config}; }); XGBOOST_REGISTER_OBJECTIVE(GammaRegression, GammaDeviance::Name()) .describe("Gamma regression using the gamma deviance loss with log link.") - .set_body([](Args const& args) { return new RegLossObj{args}; }); + .set_body([](Args const& args) { return new RegLossObj{args}; }) + .set_body_json([](Json const& config) { return new RegLossObj{config}; }); // Deprecated functions XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear") @@ -274,14 +278,16 @@ XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear") .set_body([](Args const& args) { LOG(WARNING) << "reg:linear is now deprecated in favor of reg:squarederror."; return new RegLossObj{args}; - }); + }) + .set_body_json([](Json const& config) { return new RegLossObj{config}; }); // End deprecated class SquaredLogErrorRegression : public FitIntercept { public: static auto Name() { return SquaredLogError::Name(); } + SquaredLogErrorRegression() = default; + explicit SquaredLogErrorRegression(Json const&) {} - void Configure(Args const&) override {} [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } [[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override { return std::max(static_cast(1), info.labels.Shape(1)); @@ -318,21 +324,26 @@ class SquaredLogErrorRegression : public FitIntercept { auto& out = *p_out; out["name"] = String(Name()); } - void LoadConfig(Json const&) override {} }; XGBOOST_REGISTER_OBJECTIVE(SquaredLogErrorRegression, SquaredLogErrorRegression::Name()) .describe("Root mean squared log error.") - .set_body([](Args const&) { return new SquaredLogErrorRegression(); }); + .set_body([](Args const&) { return new SquaredLogErrorRegression(); }) + .set_body_json([](Json const& config) { return new SquaredLogErrorRegression{config}; }); class PseudoHuberRegression : public FitIntercept { PseudoHuberParam param_; public: explicit PseudoHuberRegression(Args const& args) { param_.UpdateAllowUnknown(args); } + explicit PseudoHuberRegression(Json const& in) { + auto const& config = get(in); + auto it = config.find("pseudo_huber_param"); + if (it != config.cend()) { + FromJson(it->second, ¶m_); + } + } PseudoHuberRegression() = default; - - void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } [[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override { return std::max(static_cast(1), info.labels.Shape(1)); @@ -375,14 +386,6 @@ class PseudoHuberRegression : public FitIntercept { out["pseudo_huber_param"] = ToJson(param_); } - void LoadConfig(Json const& in) override { - auto const& config = get(in); - if (config.find("pseudo_huber_param") == config.cend()) { - // The parameter is added in 1.6. - return; - } - FromJson(in["pseudo_huber_param"], ¶m_); - } [[nodiscard]] Json DefaultMetricConfig() const override { CHECK(param_.GetInitialised()); Json config{Object{}}; @@ -394,7 +397,8 @@ class PseudoHuberRegression : public FitIntercept { XGBOOST_REGISTER_OBJECTIVE(PseudoHuberRegression, "reg:pseudohubererror") .describe("Regression Pseudo Huber error.") - .set_body([](Args const& args) { return new PseudoHuberRegression{args}; }); + .set_body([](Args const& args) { return new PseudoHuberRegression{args}; }) + .set_body_json([](Json const& config) { return new PseudoHuberRegression{config}; }); class ExpectileRegression : public FitIntercept { common::ExpectileLossParam param_; @@ -424,10 +428,16 @@ class ExpectileRegression : public FitIntercept { this->UpdateConfig(args); } } + explicit ExpectileRegression(Json const& in) { + auto const& obj = get(in); + auto it = obj.find("expectile_loss_param"); + if (it != obj.cend()) { + FromJson(it->second, ¶m_); + alpha_.HostVector() = param_.expectile_alpha.Get(); + } + } ExpectileRegression() = default; - void Configure(Args const& args) override { this->UpdateConfig(args); } - [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } void GetGradient(HostDeviceVector const& preds, const MetaInfo& info, std::int32_t iter, @@ -540,21 +550,12 @@ class ExpectileRegression : public FitIntercept { out["name"] = String("reg:expectileerror"); out["expectile_loss_param"] = ToJson(param_); } - - void LoadConfig(Json const& in) override { - CHECK_EQ(get(in["name"]), "reg:expectileerror"); - auto const& obj = get(in); - auto it = obj.find("expectile_loss_param"); - if (it != obj.cend()) { - FromJson(it->second, ¶m_); - alpha_.HostVector() = param_.expectile_alpha.Get(); - } - } }; XGBOOST_REGISTER_OBJECTIVE(ExpectileRegression, "reg:expectileerror") .describe("Regression with expectile loss.") - .set_body([](Args const& args) { return new ExpectileRegression{args}; }); + .set_body([](Args const& args) { return new ExpectileRegression{args}; }) + .set_body_json([](Json const& config) { return new ExpectileRegression{config}; }); // declare parameter struct PoissonRegressionParam : public XGBoostParameter { @@ -573,11 +574,9 @@ struct PoissonRegressionParam : public XGBoostParameter class PoissonRegression : public FitInterceptGlmLike { public: explicit PoissonRegression(Args const& args) { param_.UpdateAllowUnknown(args); } + explicit PoissonRegression(Json const& in) { FromJson(in["poisson_regression_param"], ¶m_); } PoissonRegression() = default; - // declare functions - void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } - [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } [[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override { @@ -632,8 +631,6 @@ class PoissonRegression : public FitInterceptGlmLike { out["poisson_regression_param"] = ToJson(param_); } - void LoadConfig(Json const& in) override { FromJson(in["poisson_regression_param"], ¶m_); } - private: PoissonRegressionParam param_; }; @@ -643,12 +640,14 @@ DMLC_REGISTER_PARAMETER(PoissonRegressionParam); XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson") .describe("Poisson regression for count data.") - .set_body([](Args const& args) { return new PoissonRegression{args}; }); + .set_body([](Args const& args) { return new PoissonRegression{args}; }) + .set_body_json([](Json const& config) { return new PoissonRegression{config}; }); // cox regression for survival data (negative values mean they are censored) class CoxRegression : public FitIntercept { public: - void Configure(Args const&) override {} + CoxRegression() = default; + explicit CoxRegression(Json const&) {} [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } void GetGradient(const HostDeviceVector& preds, const MetaInfo& info, int, @@ -729,14 +728,14 @@ class CoxRegression : public FitIntercept { auto& out = *p_out; out["name"] = String("survival:cox"); } - void LoadConfig(Json const&) override {} }; // register the objective function XGBOOST_REGISTER_OBJECTIVE(CoxRegression, "survival:cox") .describe( "Cox regression for censored survival data (negative labels are considered censored).") - .set_body([](Args const&) { return new CoxRegression(); }); + .set_body([](Args const&) { return new CoxRegression(); }) + .set_body_json([](Json const& config) { return new CoxRegression{config}; }); // declare parameter struct TweedieRegressionParam : public XGBoostParameter { @@ -760,11 +759,14 @@ class TweedieRegression : public FitInterceptGlmLike { public: explicit TweedieRegression(Args const& args) { this->UpdateConfig(args); } + explicit TweedieRegression(Json const& in) { + FromJson(in["tweedie_regression_param"], ¶m_); + std::ostringstream os; + os << "tweedie-nloglik@" << param_.tweedie_variance_power; + metric_ = os.str(); + } TweedieRegression() = default; - // declare functions - void Configure(Args const& args) override { this->UpdateConfig(args); } - [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } [[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override { @@ -819,12 +821,6 @@ class TweedieRegression : public FitInterceptGlmLike { out["name"] = String("reg:tweedie"); out["tweedie_regression_param"] = ToJson(param_); } - void LoadConfig(Json const& in) override { - FromJson(in["tweedie_regression_param"], ¶m_); - std::ostringstream os; - os << "tweedie-nloglik@" << param_.tweedie_variance_power; - metric_ = os.str(); - } private: std::string metric_; @@ -836,11 +832,15 @@ DMLC_REGISTER_PARAMETER(TweedieRegressionParam); XGBOOST_REGISTER_OBJECTIVE(TweedieRegression, "reg:tweedie") .describe("Tweedie regression for insurance data.") - .set_body([](Args const& args) { return new TweedieRegression{args}; }); + .set_body([](Args const& args) { return new TweedieRegression{args}; }) + .set_body_json([](Json const& config) { return new TweedieRegression{config}; }); class MeanAbsoluteError : public ObjFunction { public: - void Configure(Args const&) override {} + MeanAbsoluteError() = default; + explicit MeanAbsoluteError(Json const& in) { + CHECK_EQ(StringView{get(in["name"])}, StringView{"reg:absoluteerror"}); + } [[nodiscard]] ObjInfo Task() const override { return {ObjInfo::kRegression, true, true}; } [[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override { return std::max(static_cast(1), info.labels.Shape(1)); @@ -923,13 +923,10 @@ class MeanAbsoluteError : public ObjFunction { auto& out = *p_out; out["name"] = String("reg:absoluteerror"); } - - void LoadConfig(Json const& in) override { - CHECK_EQ(StringView{get(in["name"])}, StringView{"reg:absoluteerror"}); - } }; XGBOOST_REGISTER_OBJECTIVE(MeanAbsoluteError, "reg:absoluteerror") .describe("Mean absolute error.") - .set_body([](Args const&) { return new MeanAbsoluteError(); }); + .set_body([](Args const&) { return new MeanAbsoluteError(); }) + .set_body_json([](Json const& config) { return new MeanAbsoluteError{config}; }); } // namespace xgboost::obj diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 3cb357845a0b..4773946b3c2d 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -155,6 +155,27 @@ xgboost::Json CheckConfigReloadImpl(xgboost::Configurable* const configurable, s return config_1; } +xgboost::Json CheckConfigReload(std::unique_ptr const& obj, + std::string name) { + xgboost::Json config_0{xgboost::Object()}; + obj->SaveConfig(&config_0); + auto loaded = + std::unique_ptr{xgboost::ObjFunction::Create(obj->Ctx(), config_0)}; + + xgboost::Json config_1{xgboost::Object()}; + loaded->SaveConfig(&config_1); + + std::string str_0, str_1; + xgboost::Json::Dump(config_0, &str_0); + xgboost::Json::Dump(config_1, &str_1); + EXPECT_EQ(str_0, str_1); + + if (!name.empty()) { + EXPECT_EQ(xgboost::get(config_1["name"]), name); + } + return config_1; +} + void CheckRankingObjFunction(std::unique_ptr const& obj, std::vector preds, std::vector labels, diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index 747d7c450bbf..712bd1c28cef 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -19,17 +19,16 @@ #include #include - #if defined(__CUDACC__) #include "../../src/collective/communicator-inl.h" // for GetRank #include "../../src/common/cuda_rt_utils.h" // for AllVisibleGPUs -#endif // defined(__CUDACC__) +#endif // defined(__CUDACC__) #include "filesystem.h" // for TemporaryDirectory #include "xgboost/linalg.h" #if defined(__CUDACC__) -#define DeclareUnifiedTest(name) GPU ## name +#define DeclareUnifiedTest(name) GPU##name #else #define DeclareUnifiedTest(name) name #endif @@ -41,7 +40,7 @@ #endif #if defined(__CUDACC__) -#define DeclareUnifiedDistributedTest(name) MGPU ## name +#define DeclareUnifiedDistributedTest(name) MGPU##name #else #define DeclareUnifiedDistributedTest(name) name #endif @@ -51,7 +50,7 @@ class ObjFunction; class Metric; struct LearnerModelParam; class GradientBooster; -} +} // namespace xgboost template Float RelError(Float l, Float r) { @@ -82,20 +81,18 @@ void CreateBigTestData(const std::string& filename, size_t n_entries, bool zero_ void CreateTestCSV(std::string const& path, size_t rows, size_t cols); void CheckObjFunction(std::unique_ptr const& obj, - std::vector preds, - std::vector labels, + std::vector preds, std::vector labels, std::vector weights, std::vector out_grad, std::vector out_hess); -xgboost::Json CheckConfigReloadImpl(xgboost::Configurable* const configurable, - std::string name); +xgboost::Json CheckConfigReloadImpl(xgboost::Configurable* const configurable, std::string name); +xgboost::Json CheckConfigReload(std::unique_ptr const& obj, + std::string name = ""); template -xgboost::Json CheckConfigReload(std::unique_ptr const& configurable, - std::string name = "") { - return CheckConfigReloadImpl(dynamic_cast(configurable.get()), - name); +xgboost::Json CheckConfigReload(std::unique_ptr const& configurable, std::string name = "") { + return CheckConfigReloadImpl(dynamic_cast(configurable.get()), name); } void CheckRankingObjFunction(std::unique_ptr const& obj, @@ -107,12 +104,11 @@ void CheckRankingObjFunction(std::unique_ptr const& obj, std::vector out_hess); xgboost::bst_float GetMetricEval( - xgboost::Metric * metric, - xgboost::HostDeviceVector const& preds, - std::vector labels, - std::vector weights = std::vector(), - std::vector groups = std::vector(), - xgboost::DataSplitMode data_split_Mode = xgboost::DataSplitMode::kRow); + xgboost::Metric* metric, xgboost::HostDeviceVector const& preds, + std::vector labels, + std::vector weights = std::vector(), + std::vector groups = std::vector(), + xgboost::DataSplitMode data_split_Mode = xgboost::DataSplitMode::kRow); double GetMultiMetricEval(xgboost::Metric* metric, xgboost::HostDeviceVector const& preds, @@ -179,8 +175,8 @@ class SimpleRealUniformDistribution { template ResultT GenerateCanonical(GeneratorT* rng) const { static_assert(std::is_floating_point_v, "Result type must be floating point."); - long double const r = (static_cast(rng->Max()) - - static_cast(rng->Min())) + 1.0L; + long double const r = + (static_cast(rng->Max()) - static_cast(rng->Min())) + 1.0L; auto const log2r = static_cast(std::log(r) / std::log(2.0L)); size_t m = std::max(1UL, (Bits + log2r - 1UL) / log2r); ResultT sum_value = 0, r_k = 1; @@ -195,13 +191,11 @@ class SimpleRealUniformDistribution { } public: - SimpleRealUniformDistribution(ResultT l, ResultT u) : - lower_{l}, upper_{u} {} + SimpleRealUniformDistribution(ResultT l, ResultT u) : lower_{l}, upper_{u} {} template ResultT operator()(GeneratorT* rng) const { - ResultT tmp = GenerateCanonical::digits, - GeneratorT>(rng); + ResultT tmp = GenerateCanonical::digits, GeneratorT>(rng); auto ret = (tmp * (upper_ - lower_)) + lower_; // Correct floating point error. return std::max(ret, lower_); @@ -225,7 +219,7 @@ Json GetArrayInterface(HostDeviceVector const* storage, size_t rows, size_t c array_interface["shape"][1] = cols; char t = linalg::detail::ArrayInterfaceHandler::TypeChar(); - array_interface["typestr"] = String(std::string{"<"} + t + std::to_string(sizeof(T))); + array_interface["typestr"] = String(std::string("<") + t + std::to_string(sizeof(T))); array_interface["version"] = 3; return array_interface; } @@ -410,12 +404,13 @@ inline auto GenerateRandomGradients(Context const* ctx, bst_idx_t n_rows, bst_ta float lower = 0.0f, float upper = 1.0f) { auto g = GenerateRandomGradients(n_rows * n_targets, lower, upper); GradientContainer gpair; - gpair.gpair = linalg::Matrix{{n_rows, static_cast(n_targets)}, ctx->Device()}; + gpair.gpair = + linalg::Matrix{{n_rows, static_cast(n_targets)}, ctx->Device()}; gpair.gpair.Data()->Copy(g); return gpair; } -typedef void *DMatrixHandle; // NOLINT(*); +typedef void* DMatrixHandle; // NOLINT(*); class ArrayIterForTest { protected: @@ -476,19 +471,14 @@ class NumpyArrayIterForTest : public ArrayIterForTest { ~NumpyArrayIterForTest() override = default; }; -void DMatrixToCSR(DMatrix *dmat, std::vector *p_data, - std::vector *p_row_ptr, - std::vector *p_cids); +void DMatrixToCSR(DMatrix* dmat, std::vector* p_data, std::vector* p_row_ptr, + std::vector* p_cids); -typedef void *DataIterHandle; // NOLINT(*) +typedef void* DataIterHandle; // NOLINT(*) -inline void Reset(DataIterHandle self) { - static_cast(self)->Reset(); -} +inline void Reset(DataIterHandle self) { static_cast(self)->Reset(); } -inline int Next(DataIterHandle self) { - return static_cast(self)->Next(); -} +inline int Next(DataIterHandle self) { return static_cast(self)->Next(); } /** * @brief Create an array interface for host vector. @@ -501,7 +491,7 @@ char const* Make1dInterfaceTest(T const* vec, std::size_t len) { } class RMMAllocator; -using RMMAllocatorPtr = std::unique_ptr; +using RMMAllocatorPtr = std::unique_ptr; RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv); /* diff --git a/tests/cpp/objective/test_aft_obj.cc b/tests/cpp/objective/test_aft_obj.cc index 23cad514e49f..1b232232a615 100644 --- a/tests/cpp/objective/test_aft_obj.cc +++ b/tests/cpp/objective/test_aft_obj.cc @@ -36,7 +36,7 @@ void TestAFTObjConfiguration(const Context* ctx) { // Generate prediction value ranging from 2**1 to 2**15, using grid points in log scale // Then check prediction against the reference values -static inline void CheckGPairOverGridPoints(ObjFunction* obj, bst_float true_label_lower_bound, +static inline void CheckGPairOverGridPoints(Context const* ctx, bst_float true_label_lower_bound, bst_float true_label_upper_bound, const std::string& dist_type, const std::vector& expected_grad, @@ -46,7 +46,9 @@ static inline void CheckGPairOverGridPoints(ObjFunction* obj, bst_float true_lab const double log_y_low = 1.0; const double log_y_high = 15.0; - obj->Configure({{"aft_loss_distribution", dist_type}, {"aft_loss_distribution_scale", "1"}}); + std::unique_ptr obj{ObjFunction::Create( + "survival:aft", ctx, + {{"aft_loss_distribution", dist_type}, {"aft_loss_distribution_scale", "1"}})}; MetaInfo info; info.num_row_ = num_point; @@ -70,23 +72,21 @@ static inline void CheckGPairOverGridPoints(ObjFunction* obj, bst_float true_lab } void TestAFTObjGPairUncensoredLabels(const Context* ctx) { - std::unique_ptr obj(ObjFunction::Create("survival:aft", ctx)); - CheckGPairOverGridPoints( - obj.get(), 100.0f, 100.0f, "normal", + ctx, 100.0f, 100.0f, "normal", {-3.9120f, -3.4013f, -2.8905f, -2.3798f, -1.8691f, -1.3583f, -0.8476f, -0.3368f, 0.1739f, 0.6846f, 1.1954f, 1.7061f, 2.2169f, 2.7276f, 3.2383f, 3.7491f, 4.2598f, 4.7706f, 5.2813f, 5.7920f}, {1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f}); CheckGPairOverGridPoints( - obj.get(), 100.0f, 100.0f, "logistic", + ctx, 100.0f, 100.0f, "logistic", {-0.9608f, -0.9355f, -0.8948f, -0.8305f, -0.7327f, -0.5910f, -0.4001f, -0.1668f, 0.0867f, 0.3295f, 0.5354f, 0.6927f, 0.8035f, 0.8773f, 0.9245f, 0.9540f, 0.9721f, 0.9832f, 0.9899f, 0.9939f}, {0.0384f, 0.0624f, 0.0997f, 0.1551f, 0.2316f, 0.3254f, 0.4200f, 0.4861f, 0.4962f, 0.4457f, 0.3567f, 0.2601f, 0.1772f, 0.1152f, 0.0726f, 0.0449f, 0.0275f, 0.0167f, 0.0101f, 0.0061f}); - CheckGPairOverGridPoints(obj.get(), 100.0f, 100.0f, "extreme", + CheckGPairOverGridPoints(ctx, 100.0f, 100.0f, "extreme", {-15.0000f, -15.0000f, -15.0000f, -9.8028f, -5.4822f, -2.8897f, -1.3340f, -0.4005f, 0.1596f, 0.4957f, 0.6974f, 0.8184f, 0.8910f, 0.9346f, 0.9608f, 0.9765f, 0.9859f, 0.9915f, 0.9949f, 0.9969f}, @@ -96,22 +96,20 @@ void TestAFTObjGPairUncensoredLabels(const Context* ctx) { } void TestAFTObjGPairLeftCensoredLabels(const Context* ctx) { - std::unique_ptr obj(ObjFunction::Create("survival:aft", ctx)); - CheckGPairOverGridPoints( - obj.get(), 0.0f, 20.0f, "normal", + ctx, 0.0f, 20.0f, "normal", {0.0285f, 0.0832f, 0.1951f, 0.3804f, 0.6403f, 0.9643f, 1.3379f, 1.7475f, 2.1828f, 2.6361f, 3.1023f, 3.5779f, 4.0603f, 4.5479f, 5.0394f, 5.5340f, 6.0309f, 6.5298f, 7.0303f, 7.5326f}, {0.0663f, 0.1559f, 0.2881f, 0.4378f, 0.5762f, 0.6878f, 0.7707f, 0.8300f, 0.8719f, 0.9016f, 0.9229f, 0.9385f, 0.9501f, 0.9588f, 0.9656f, 0.9709f, 0.9751f, 0.9785f, 0.9813f, 0.9877f}); CheckGPairOverGridPoints( - obj.get(), 0.0f, 20.0f, "logistic", + ctx, 0.0f, 20.0f, "logistic", {0.0909f, 0.1428f, 0.2174f, 0.3164f, 0.4355f, 0.5625f, 0.6818f, 0.7812f, 0.8561f, 0.9084f, 0.9429f, 0.9650f, 0.9787f, 0.9871f, 0.9922f, 0.9953f, 0.9972f, 0.9983f, 0.9990f, 0.9994f}, {0.0826f, 0.1224f, 0.1701f, 0.2163f, 0.2458f, 0.2461f, 0.2170f, 0.1709f, 0.1232f, 0.0832f, 0.0538f, 0.0338f, 0.0209f, 0.0127f, 0.0077f, 0.0047f, 0.0028f, 0.0017f, 0.0010f, 0.0006f}); CheckGPairOverGridPoints( - obj.get(), 0.0f, 20.0f, "extreme", + ctx, 0.0f, 20.0f, "extreme", {0.0005f, 0.0149f, 0.1011f, 0.2815f, 0.4881f, 0.6610f, 0.7847f, 0.8665f, 0.9183f, 0.9504f, 0.9700f, 0.9820f, 0.9891f, 0.9935f, 0.9961f, 0.9976f, 0.9986f, 0.9992f, 0.9995f, 0.9997f}, {0.0041f, 0.0747f, 0.2731f, 0.4059f, 0.3829f, 0.2901f, 0.1973f, 0.1270f, 0.0793f, 0.0487f, @@ -119,23 +117,21 @@ void TestAFTObjGPairLeftCensoredLabels(const Context* ctx) { } void TestAFTObjGPairRightCensoredLabels(const Context* ctx) { - std::unique_ptr obj(ObjFunction::Create("survival:aft", ctx)); - CheckGPairOverGridPoints( - obj.get(), 60.0f, std::numeric_limits::infinity(), "normal", + ctx, 60.0f, std::numeric_limits::infinity(), "normal", {-3.6583f, -3.1815f, -2.7135f, -2.2577f, -1.8190f, -1.4044f, -1.0239f, -0.6905f, -0.4190f, -0.2209f, -0.0973f, -0.0346f, -0.0097f, -0.0021f, -0.0004f, -0.0000f, -0.0000f, -0.0000f, -0.0000f, -0.0000f}, {0.9407f, 0.9259f, 0.9057f, 0.8776f, 0.8381f, 0.7821f, 0.7036f, 0.5970f, 0.4624f, 0.3128f, 0.1756f, 0.0780f, 0.0265f, 0.0068f, 0.0013f, 0.0002f, 0.0000f, 0.0000f, 0.0000f, 0.0000f}); CheckGPairOverGridPoints( - obj.get(), 60.0f, std::numeric_limits::infinity(), "logistic", + ctx, 60.0f, std::numeric_limits::infinity(), "logistic", {-0.9677f, -0.9474f, -0.9153f, -0.8663f, -0.7955f, -0.7000f, -0.5834f, -0.4566f, -0.3352f, -0.2323f, -0.1537f, -0.0982f, -0.0614f, -0.0377f, -0.0230f, -0.0139f, -0.0084f, -0.0051f, -0.0030f, -0.0018f}, {0.0312f, 0.0499f, 0.0776f, 0.1158f, 0.1627f, 0.2100f, 0.2430f, 0.2481f, 0.2228f, 0.1783f, 0.1300f, 0.0886f, 0.0576f, 0.0363f, 0.0225f, 0.0137f, 0.0083f, 0.0050f, 0.0030f, 0.0018f}); - CheckGPairOverGridPoints(obj.get(), 60.0f, std::numeric_limits::infinity(), "extreme", + CheckGPairOverGridPoints(ctx, 60.0f, std::numeric_limits::infinity(), "extreme", {-15.0000f, -15.0000f, -10.8018f, -6.4817f, -3.8893f, -2.3338f, -1.4004f, -0.8403f, -0.5042f, -0.3026f, -0.1816f, -0.1089f, -0.0654f, -0.0392f, -0.0235f, -0.0141f, -0.0085f, -0.0051f, -0.0031f, -0.0018f}, @@ -145,24 +141,22 @@ void TestAFTObjGPairRightCensoredLabels(const Context* ctx) { } void TestAFTObjGPairIntervalCensoredLabels(const Context* ctx) { - std::unique_ptr obj(ObjFunction::Create("survival:aft", ctx)); - CheckGPairOverGridPoints( - obj.get(), 16.0f, 200.0f, "normal", + ctx, 16.0f, 200.0f, "normal", {-2.4435f, -1.9965f, -1.5691f, -1.1679f, -0.7990f, -0.4649f, -0.1596f, 0.1336f, 0.4370f, 0.7682f, 1.1340f, 1.5326f, 1.9579f, 2.4035f, 2.8639f, 3.3351f, 3.8143f, 4.2995f, 4.7891f, 5.2822f}, {0.8909f, 0.8579f, 0.8134f, 0.7557f, 0.6880f, 0.6221f, 0.5789f, 0.5769f, 0.6171f, 0.6818f, 0.7500f, 0.8088f, 0.8545f, 0.8884f, 0.9131f, 0.9312f, 0.9446f, 0.9547f, 0.9624f, 0.9684f}); CheckGPairOverGridPoints( - obj.get(), 16.0f, 200.0f, "logistic", + ctx, 16.0f, 200.0f, "logistic", {-0.8790f, -0.8112f, -0.7153f, -0.5893f, -0.4375f, -0.2697f, -0.0955f, 0.0800f, 0.2545f, 0.4232f, 0.5768f, 0.7054f, 0.8040f, 0.8740f, 0.9210f, 0.9513f, 0.9703f, 0.9820f, 0.9891f, 0.9934f}, {0.1086f, 0.1588f, 0.2176f, 0.2745f, 0.3164f, 0.3374f, 0.3433f, 0.3434f, 0.3384f, 0.3191f, 0.2789f, 0.2229f, 0.1637f, 0.1125f, 0.0737f, 0.0467f, 0.0290f, 0.0177f, 0.0108f, 0.0065f}); CheckGPairOverGridPoints( - obj.get(), 16.0f, 200.0f, "extreme", + ctx, 16.0f, 200.0f, "extreme", {-8.0000f, -4.8004f, -2.8805f, -1.7284f, -1.0371f, -0.6168f, -0.3140f, -0.0121f, 0.2841f, 0.5261f, 0.6989f, 0.8132f, 0.8857f, 0.9306f, 0.9581f, 0.9747f, 0.9848f, 0.9909f, 0.9945f, 0.9967f}, diff --git a/tests/cpp/objective/test_lambdarank_obj.cc b/tests/cpp/objective/test_lambdarank_obj.cc index 36dcab5bc065..9d0aac8a8aa2 100644 --- a/tests/cpp/objective/test_lambdarank_obj.cc +++ b/tests/cpp/objective/test_lambdarank_obj.cc @@ -32,46 +32,38 @@ TEST(LambdaRank, NDCGJsonIO) { void TestNDCGGPair(Context const* ctx) { { - std::unique_ptr obj{xgboost::ObjFunction::Create("rank:ndcg", ctx)}; - obj->Configure(Args{{"lambdarank_pair_method", "topk"}}); + std::unique_ptr obj{ + xgboost::ObjFunction::Create("rank:ndcg", ctx, {{"lambdarank_pair_method", "topk"}})}; CheckConfigReload(obj, "rank:ndcg"); // No gain in swapping 2 documents. - CheckRankingObjFunction(obj, - {1, 1, 1, 1}, - {1, 1, 1, 1}, - {1.0f, 1.0f}, - {0, 2, 4}, - {0.0f, -0.0f, 0.0f, 0.0f}, - {0.0f, 0.0f, 0.0f, 0.0f}); + CheckRankingObjFunction(obj, {1, 1, 1, 1}, {1, 1, 1, 1}, {1.0f, 1.0f}, {0, 2, 4}, + {0.0f, -0.0f, 0.0f, 0.0f}, {0.0f, 0.0f, 0.0f, 0.0f}); } { - std::unique_ptr obj{xgboost::ObjFunction::Create("rank:ndcg", ctx)}; - obj->Configure(Args{{"lambdarank_pair_method", "topk"}}); + std::unique_ptr obj{ + xgboost::ObjFunction::Create("rank:ndcg", ctx, {{"lambdarank_pair_method", "topk"}})}; // Test with setting sample weight to second query group - CheckRankingObjFunction(obj, - {0, 0.1f, 0, 0.1f}, - {0, 1, 0, 1}, - {2.0f, 0.0f}, - {0, 2, 4}, - {2.06611f, -2.06611f, 0.0f, 0.0f}, - {2.169331f, 2.169331f, 0.0f, 0.0f}); + CheckRankingObjFunction(obj, {0, 0.1f, 0, 0.1f}, {0, 1, 0, 1}, {2.0f, 0.0f}, {0, 2, 4}, + {2.06611f, -2.06611f, 0.0f, 0.0f}, {2.169331f, 2.169331f, 0.0f, 0.0f}); } { - std::unique_ptr obj{xgboost::ObjFunction::Create("rank:ndcg", ctx)}; - obj->Configure(Args{{"lambdarank_pair_method", "topk"}}); + std::unique_ptr obj{ + xgboost::ObjFunction::Create("rank:ndcg", ctx, {{"lambdarank_pair_method", "topk"}})}; float weight_norm = 0.5; // n_groups / sum_weights std::vector out_grad{2.06611f, -2.06611f, 2.06611f, -2.06611f}; std::vector out_hess{2.169331f, 2.169331f, 2.169331f, 2.169331f}; - auto norm = [=](auto v) { return v * weight_norm; }; + auto norm = [=](auto v) { + return v * weight_norm; + }; std::transform(out_grad.begin(), out_grad.end(), out_grad.begin(), norm); std::transform(out_hess.begin(), out_hess.end(), out_hess.begin(), norm); CheckRankingObjFunction(obj, {0, 0.1f, 0, 0.1f}, {0, 1, 0, 1}, {2.0f, 2.0f}, {0, 2, 4}, out_grad, out_hess); } - std::unique_ptr obj{xgboost::ObjFunction::Create("rank:ndcg", ctx)}; - obj->Configure(Args{{"lambdarank_pair_method", "topk"}}); + std::unique_ptr obj{ + xgboost::ObjFunction::Create("rank:ndcg", ctx, {{"lambdarank_pair_method", "topk"}})}; HostDeviceVector predts{0, 1, 0, 1}; MetaInfo info; @@ -112,8 +104,8 @@ void TestNDCGGPair(Context const* ctx) { { // Test empty input - std::unique_ptr obj{xgboost::ObjFunction::Create("rank:ndcg", ctx)}; - obj->Configure(Args{{"lambdarank_pair_method", "topk"}}); + std::unique_ptr obj{ + xgboost::ObjFunction::Create("rank:ndcg", ctx, {{"lambdarank_pair_method", "topk"}})}; HostDeviceVector predts; MetaInfo info; @@ -133,10 +125,11 @@ TEST(LambdaRank, NDCGGPair) { } void TestUnbiasedNDCG(Context const* ctx) { - std::unique_ptr obj{xgboost::ObjFunction::Create("rank:ndcg", ctx)}; - obj->Configure(Args{{"lambdarank_pair_method", "topk"}, - {"lambdarank_unbiased", "true"}, - {"lambdarank_bias_norm", "0"}}); + std::unique_ptr obj{ + xgboost::ObjFunction::Create("rank:ndcg", ctx, + {{"lambdarank_pair_method", "topk"}, + {"lambdarank_unbiased", "true"}, + {"lambdarank_bias_norm", "0"}})}; std::shared_ptr p_fmat{ RandomDataGenerator{10, 1, 0.0f}.Classes(2).GenerateDMatrix(true)}; auto h_label = p_fmat->Info().labels.HostView().Values(); @@ -323,8 +316,7 @@ TEST(LambdaRank, MAPStat) { } void TestMAPGPair(Context const* ctx) { - std::unique_ptr obj{xgboost::ObjFunction::Create("rank:map", ctx)}; - obj->Configure({}); + std::unique_ptr obj{xgboost::ObjFunction::Create("rank:map", ctx, {})}; CheckConfigReload(obj, "rank:map"); @@ -336,8 +328,7 @@ void TestMAPGPair(Context const* ctx) { {1.2054923f, -1.2054923f, 1.2054923f, -1.2054923f}, // out grad {1.2657166f, 1.2657166f, 1.2657166f, 1.2657166f}); - obj.reset(xgboost::ObjFunction::Create("rank:map", ctx)); - obj->Configure({}); + obj.reset(xgboost::ObjFunction::Create("rank:map", ctx, {})); // disable the second query group with 0 weight auto w = 2.0f; // weight for the first group @@ -358,10 +349,7 @@ TEST(LambdaRank, MAPGPair) { void TestPairWiseGPair(Context const* ctx) { std::unique_ptr obj{xgboost::ObjFunction::Create("rank:pairwise", ctx)}; - Args args; - obj->Configure(args); - - args.emplace_back("lambdarank_unbiased", "true"); + obj.reset(xgboost::ObjFunction::Create("rank:pairwise", ctx, {{"lambdarank_unbiased", "true"}})); } TEST(LambdaRank, Pairwise) { diff --git a/tests/cpp/objective/test_lambdarank_obj.h b/tests/cpp/objective/test_lambdarank_obj.h index 4383a44d1a75..78b7dcc43f8e 100644 --- a/tests/cpp/objective/test_lambdarank_obj.h +++ b/tests/cpp/objective/test_lambdarank_obj.h @@ -1,26 +1,24 @@ /** * Copyright 2023-2025, XGBoost Contributors */ -#ifndef XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_ -#define XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_ +#ifndef TESTS_CPP_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_ +#define TESTS_CPP_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_ #include -#include // for MetaInfo -#include // for HostDeviceVector -#include // for All -#include // for ObjFunction +#include // for MetaInfo +#include // for HostDeviceVector +#include // for All +#include // for ObjFunction -#include // for shared_ptr, make_shared +#include // for shared_ptr, make_shared -#include "../../../src/common/ranking_utils.h" // for LambdaRankParam, MAPCache -#include "../helpers.h" // for EmptyDMatrix +#include "../../../src/common/ranking_utils.h" // for LambdaRankParam, MAPCache +#include "../helpers.h" // for EmptyDMatrix namespace xgboost::obj { void TestMAPStat(Context const* ctx); inline void TestNDCGJsonIO(Context const* ctx) { - std::unique_ptr obj{ObjFunction::Create("rank:ndcg", ctx)}; - - obj->Configure(Args{}); + std::unique_ptr obj{ObjFunction::Create("rank:ndcg", ctx, {})}; Json j_obj{Object()}; obj->SaveConfig(&j_obj); @@ -43,4 +41,4 @@ void TestMAPGPair(Context const* ctx); */ void InitMakePairTest(Context const* ctx, MetaInfo* out_info, HostDeviceVector* out_predt); } // namespace xgboost::obj -#endif // XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_ +#endif // TESTS_CPP_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_ diff --git a/tests/cpp/objective/test_regression_obj.cc b/tests/cpp/objective/test_regression_obj.cc index 4f8a37049120..f3309a593677 100644 --- a/tests/cpp/objective/test_regression_obj.cc +++ b/tests/cpp/objective/test_regression_obj.cc @@ -491,7 +491,7 @@ void TestPseudoHuber(const Context* ctx) { {0.410660f, 0.476140f, 0.630510f, 0.9428660f, 0.630510f}); // out_hess ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"mphe"}); - obj->Configure({{"huber_slope", "0.1"}}); + obj.reset(ObjFunction::Create("reg:pseudohubererror", ctx, {{"huber_slope", "0.1"}})); CheckConfigReload(obj, "reg:pseudohubererror"); CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f}, // pred {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, // labels From 5da1e058d78f6f38da9cc9b22bf629f8d71bb7f4 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 17 Mar 2026 06:58:25 -0700 Subject: [PATCH 7/7] Remove no-arg objective factory path --- include/xgboost/objective.h | 2 +- plugin/example/custom_obj.cc | 1 - src/objective/aft_obj.cu | 1 - src/objective/hinge.cu | 4 ++-- src/objective/lambdarank_obj.cc | 1 - src/objective/quantile_obj.cu | 1 - src/objective/regression_obj.cu | 16 ++++++---------- tests/cpp/objective/test_lambdarank_obj.cc | 2 +- tests/cpp/objective/test_objective.cc | 4 ++-- 9 files changed, 12 insertions(+), 20 deletions(-) diff --git a/include/xgboost/objective.h b/include/xgboost/objective.h index ac729f952561..77249f7da070 100644 --- a/include/xgboost/objective.h +++ b/include/xgboost/objective.h @@ -135,7 +135,7 @@ class ObjFunction { * @param name Name of the objective. * @param ctx Pointer to the context. */ - static ObjFunction* Create(const std::string& name, Context const* ctx, Args const& args = {}); + static ObjFunction* Create(const std::string& name, Context const* ctx, Args const& args); static ObjFunction* Create(Context const* ctx, Json const& config); }; diff --git a/plugin/example/custom_obj.cc b/plugin/example/custom_obj.cc index 5a08b55dda98..30e7155446e5 100644 --- a/plugin/example/custom_obj.cc +++ b/plugin/example/custom_obj.cc @@ -34,7 +34,6 @@ class MyLogistic : public ObjFunction { public: explicit MyLogistic(Args const& args) { param_.UpdateAllowUnknown(args); } explicit MyLogistic(Json const& in) { FromJson(in["my_logistic_param"], ¶m_); } - MyLogistic() = default; [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } diff --git a/src/objective/aft_obj.cu b/src/objective/aft_obj.cu index 849a5ff54722..d5697c06d82f 100644 --- a/src/objective/aft_obj.cu +++ b/src/objective/aft_obj.cu @@ -33,7 +33,6 @@ class AFTObj : public ObjFunction { public: explicit AFTObj(Args const& args) { param_.UpdateAllowUnknown(args); } explicit AFTObj(Json const& in) { FromJson(in["aft_loss_param"], ¶m_); } - AFTObj() = default; ObjInfo Task() const override { return ObjInfo::kSurvival; } diff --git a/src/objective/hinge.cu b/src/objective/hinge.cu index b54483431397..0c81fe3aabd0 100644 --- a/src/objective/hinge.cu +++ b/src/objective/hinge.cu @@ -26,7 +26,7 @@ DMLC_REGISTRY_FILE_TAG(hinge_obj_gpu); class HingeObj : public FitIntercept { public: - HingeObj() = default; + explicit HingeObj(Args const &) {} explicit HingeObj(Json const &) {} ObjInfo Task() const override { return ObjInfo::kRegression; } @@ -97,7 +97,7 @@ class HingeObj : public FitIntercept { // register the objective functions XGBOOST_REGISTER_OBJECTIVE(HingeObj, "binary:hinge") .describe("Hinge loss. Expects labels to be in [0,1f]") - .set_body([](Args const &) { return new HingeObj(); }) + .set_body([](Args const &args) { return new HingeObj(args); }) .set_body_json([](Json const &config) { return new HingeObj(config); }); } // namespace xgboost::obj diff --git a/src/objective/lambdarank_obj.cc b/src/objective/lambdarank_obj.cc index d0f8a7e818a7..150c6c645b76 100644 --- a/src/objective/lambdarank_obj.cc +++ b/src/objective/lambdarank_obj.cc @@ -262,7 +262,6 @@ class LambdaRankObj : public FitIntercept { linalg::LoadVector(in["tj-"], &tj_minus_); } } - LambdaRankObj() = default; void SaveConfig(Json* p_out) const override { auto& out = *p_out; out["name"] = String(Loss::Name()); diff --git a/src/objective/quantile_obj.cu b/src/objective/quantile_obj.cu index ff87f1e36be6..a4897fc58bb9 100644 --- a/src/objective/quantile_obj.cu +++ b/src/objective/quantile_obj.cu @@ -61,7 +61,6 @@ class QuantileRegression : public ObjFunction { FromJson(in["quantile_loss_param"], ¶m_); alpha_.HostVector() = param_.quantile_alpha.Get(); } - QuantileRegression() = default; void GetGradient(HostDeviceVector const& preds, const MetaInfo& info, std::int32_t iter, linalg::Matrix* out_gpair) override { diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index 6427bbba24e8..e90b3410296d 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -285,7 +285,7 @@ XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear") class SquaredLogErrorRegression : public FitIntercept { public: static auto Name() { return SquaredLogError::Name(); } - SquaredLogErrorRegression() = default; + explicit SquaredLogErrorRegression(Args const&) {} explicit SquaredLogErrorRegression(Json const&) {} [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } @@ -328,7 +328,7 @@ class SquaredLogErrorRegression : public FitIntercept { XGBOOST_REGISTER_OBJECTIVE(SquaredLogErrorRegression, SquaredLogErrorRegression::Name()) .describe("Root mean squared log error.") - .set_body([](Args const&) { return new SquaredLogErrorRegression(); }) + .set_body([](Args const& args) { return new SquaredLogErrorRegression{args}; }) .set_body_json([](Json const& config) { return new SquaredLogErrorRegression{config}; }); class PseudoHuberRegression : public FitIntercept { @@ -343,7 +343,6 @@ class PseudoHuberRegression : public FitIntercept { FromJson(it->second, ¶m_); } } - PseudoHuberRegression() = default; [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } [[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override { return std::max(static_cast(1), info.labels.Shape(1)); @@ -436,7 +435,6 @@ class ExpectileRegression : public FitIntercept { alpha_.HostVector() = param_.expectile_alpha.Get(); } } - ExpectileRegression() = default; [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } @@ -575,7 +573,6 @@ class PoissonRegression : public FitInterceptGlmLike { public: explicit PoissonRegression(Args const& args) { param_.UpdateAllowUnknown(args); } explicit PoissonRegression(Json const& in) { FromJson(in["poisson_regression_param"], ¶m_); } - PoissonRegression() = default; [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } @@ -646,7 +643,7 @@ XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson") // cox regression for survival data (negative values mean they are censored) class CoxRegression : public FitIntercept { public: - CoxRegression() = default; + explicit CoxRegression(Args const&) {} explicit CoxRegression(Json const&) {} [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } @@ -734,7 +731,7 @@ class CoxRegression : public FitIntercept { XGBOOST_REGISTER_OBJECTIVE(CoxRegression, "survival:cox") .describe( "Cox regression for censored survival data (negative labels are considered censored).") - .set_body([](Args const&) { return new CoxRegression(); }) + .set_body([](Args const& args) { return new CoxRegression{args}; }) .set_body_json([](Json const& config) { return new CoxRegression{config}; }); // declare parameter @@ -765,7 +762,6 @@ class TweedieRegression : public FitInterceptGlmLike { os << "tweedie-nloglik@" << param_.tweedie_variance_power; metric_ = os.str(); } - TweedieRegression() = default; [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; } @@ -837,7 +833,7 @@ XGBOOST_REGISTER_OBJECTIVE(TweedieRegression, "reg:tweedie") class MeanAbsoluteError : public ObjFunction { public: - MeanAbsoluteError() = default; + explicit MeanAbsoluteError(Args const&) {} explicit MeanAbsoluteError(Json const& in) { CHECK_EQ(StringView{get(in["name"])}, StringView{"reg:absoluteerror"}); } @@ -927,6 +923,6 @@ class MeanAbsoluteError : public ObjFunction { XGBOOST_REGISTER_OBJECTIVE(MeanAbsoluteError, "reg:absoluteerror") .describe("Mean absolute error.") - .set_body([](Args const&) { return new MeanAbsoluteError(); }) + .set_body([](Args const& args) { return new MeanAbsoluteError{args}; }) .set_body_json([](Json const& config) { return new MeanAbsoluteError{config}; }); } // namespace xgboost::obj diff --git a/tests/cpp/objective/test_lambdarank_obj.cc b/tests/cpp/objective/test_lambdarank_obj.cc index 9d0aac8a8aa2..80dc12e21e17 100644 --- a/tests/cpp/objective/test_lambdarank_obj.cc +++ b/tests/cpp/objective/test_lambdarank_obj.cc @@ -348,7 +348,7 @@ TEST(LambdaRank, MAPGPair) { } void TestPairWiseGPair(Context const* ctx) { - std::unique_ptr obj{xgboost::ObjFunction::Create("rank:pairwise", ctx)}; + std::unique_ptr obj{xgboost::ObjFunction::Create("rank:pairwise", ctx, {})}; obj.reset(xgboost::ObjFunction::Create("rank:pairwise", ctx, {{"lambdarank_unbiased", "true"}})); } diff --git a/tests/cpp/objective/test_objective.cc b/tests/cpp/objective/test_objective.cc index f01d80478bf8..6be36c9a764c 100644 --- a/tests/cpp/objective/test_objective.cc +++ b/tests/cpp/objective/test_objective.cc @@ -15,8 +15,8 @@ TEST(Objective, UnknownFunction) { std::vector> args; tparam.UpdateAllowUnknown(args); - EXPECT_ANY_THROW(obj = xgboost::ObjFunction::Create("unknown_name", &tparam)); - EXPECT_NO_THROW(obj = xgboost::ObjFunction::Create("reg:squarederror", &tparam)); + EXPECT_ANY_THROW(obj = xgboost::ObjFunction::Create("unknown_name", &tparam, xgboost::Args{})); + EXPECT_NO_THROW(obj = xgboost::ObjFunction::Create("reg:squarederror", &tparam, xgboost::Args{})); if (obj) { delete obj; }