Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 27 additions & 19 deletions include/xgboost/objective.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,15 @@ class RegTree;
struct Context;

/** @brief The interface of objective function */
class ObjFunction : public Configurable {
class ObjFunction {
protected:
Context const* ctx_{nullptr};

public:
static constexpr float DefaultBaseScore() { return 0.5f; }

public:
~ObjFunction() override = default;
/**
* @brief Configure the objective with the specified parameters.
*
* @param args arguments to the objective function.
*/
virtual void Configure(Args const& args) = 0;
virtual ~ObjFunction() = default;
/**
* @brief Get gradient over each of predictions, given existing information.
*
Expand Down Expand Up @@ -130,38 +124,52 @@ class ObjFunction : public Configurable {
MetaInfo const& /*info*/, float /*learning_rate*/,
HostDeviceVector<float> const& /*prediction*/,
bst_target_t /*group_idx*/, RegTree* /*p_tree*/) const {}
/**
* @brief Save configuration to JSON object.
* @param out pointer to output JSON object.
*/
virtual void SaveConfig(Json* out) const = 0;
/**
* @brief Create an objective function according to the name.
*
* @param name Name of the objective.
* @param ctx Pointer to the context.
*/
static ObjFunction* Create(const std::string& name, Context const* ctx);
static ObjFunction* Create(const std::string& name, Context const* ctx, Args const& args);
static ObjFunction* Create(Context const* ctx, Json const& config);
};

/*!
* \brief Registry entry for objective factory functions.
*/
struct ObjFunctionReg
: public dmlc::FunctionRegEntryBase<ObjFunctionReg,
std::function<ObjFunction* ()> > {
: public dmlc::FunctionRegEntryBase<ObjFunctionReg, std::function<ObjFunction*(Args const&)>> {
std::function<ObjFunction*(Json const&)> json_body;

inline ObjFunctionReg& set_body_json(std::function<ObjFunction*(Json const&)> body) {
json_body = std::move(body);
return *this;
}
};

/*!
* \brief Macro to register objective function.
*
* \code
* // example of registering an objective
* XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:squarederror")
* .describe("Linear regression objective")
* .set_body([]() {
* return new RegLossObj(LossType::kLinearSquare);
* XGBOOST_REGISTER_OBJECTIVE(MyObjective, "my:objective")
* .describe("A custom objective")
* .set_body([](Args const&) {
* return new MyObjective();
* })
* .set_body_json([](Json const& config) {
* return new MyObjective(config);
* });
* \endcode
*/
#define XGBOOST_REGISTER_OBJECTIVE(UniqueId, Name) \
static DMLC_ATTRIBUTE_UNUSED ::xgboost::ObjFunctionReg & \
__make_ ## ObjFunctionReg ## _ ## UniqueId ## __ = \
::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->__REGISTER__(Name)
#define XGBOOST_REGISTER_OBJECTIVE(UniqueId, Name) \
static DMLC_ATTRIBUTE_UNUSED ::xgboost::ObjFunctionReg& \
__make_##ObjFunctionReg##_##UniqueId##__ = \
::dmlc::Registry<::xgboost::ObjFunctionReg>::Get()->__REGISTER__(Name)
} // namespace xgboost
#endif // XGBOOST_OBJECTIVE_H_
24 changes: 11 additions & 13 deletions plugin/example/custom_obj.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ struct MyLogisticParam : public XGBoostParameter<MyLogisticParam> {
float scale_neg_weight;
// declare parameters
DMLC_DECLARE_PARAMETER(MyLogisticParam) {
DMLC_DECLARE_FIELD(scale_neg_weight).set_default(1.0f).set_lower_bound(0.0f)
DMLC_DECLARE_FIELD(scale_neg_weight)
.set_default(1.0f)
.set_lower_bound(0.0f)
.describe("Scale the weight of negative examples by this factor");
}
};
Expand All @@ -30,7 +32,8 @@ DMLC_REGISTER_PARAMETER(MyLogisticParam);
// Implement the interface.
class MyLogistic : public ObjFunction {
public:
void Configure(const Args& args) override { param_.UpdateAllowUnknown(args); }
explicit MyLogistic(Args const& args) { param_.UpdateAllowUnknown(args); }
explicit MyLogistic(Json const& in) { FromJson(in["my_logistic_param"], &param_); }

[[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; }

Expand All @@ -53,12 +56,10 @@ class MyLogistic : public ObjFunction {
out_gpair_h(i) = GradientPair(grad, hess);
}
}
[[nodiscard]] const char* DefaultEvalMetric() const override {
return "logloss";
}
void PredTransform(HostDeviceVector<float> *io_preds) const override {
[[nodiscard]] const char* DefaultEvalMetric() const override { return "logloss"; }
void PredTransform(HostDeviceVector<float>* io_preds) const override {
// transform margin value to probability.
std::vector<float> &preds = io_preds->HostVector();
std::vector<float>& preds = io_preds->HostVector();
for (auto& pred : preds) {
pred = 1.0f / (1.0f + std::exp(-pred));
}
Expand All @@ -77,18 +78,15 @@ class MyLogistic : public ObjFunction {
out["my_logistic_param"] = ToJson(param_);
}

void LoadConfig(Json const& in) override {
FromJson(in["my_logistic_param"], &param_);
}

private:
MyLogisticParam param_;
};

// Finally register the objective function.
// After it succeeds, you can try using xgboost with objective=mylogistic
XGBOOST_REGISTER_OBJECTIVE(MyLogistic, "mylogistic")
.describe("User defined logistic regression plugin")
.set_body([]() { return new MyLogistic(); });
.describe("User defined logistic regression plugin")
.set_body([](Args const& args) { return new MyLogistic{args}; })
.set_body_json([](Json const& config) { return new MyLogistic{config}; });

} // namespace xgboost::obj
19 changes: 10 additions & 9 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -587,9 +587,8 @@ class LearnerConfiguration : public Intercept {
auto const& objective_fn = learner_parameters.at("objective");
if (!obj_) {
CHECK_EQ(get<String const>(objective_fn["name"]), tparam_.objective);
obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_));
}
obj_->LoadConfig(objective_fn);
obj_.reset(ObjFunction::Create(&ctx_, objective_fn));
learner_model_param_.task = obj_->Task();

tparam_.booster = CanonicalizeBoosterName(get<String>(gradient_booster["name"]));
Expand Down Expand Up @@ -832,6 +831,7 @@ class LearnerConfiguration : public Intercept {
}

void ConfigureObjective(LearnerTrainParam const& old, Args* p_args) {
auto preserve_loaded_objective = cfg_.empty() && obj_ != nullptr;
// Once binary IO is gone, NONE of this configuration is useful.
if (cfg_.find("num_class") != cfg_.cend() && cfg_.at("num_class") != "0" &&
tparam_.objective != "multi:softprob") {
Expand All @@ -847,17 +847,19 @@ class LearnerConfiguration : public Intercept {
// Rename one of them once binary IO is gone.
cfg_["max_delta_step"] = kMaxDeltaStepDefaultValue;
}
if (obj_ == nullptr || tparam_.objective != old.objective) {
obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_));
}

bool has_nc{cfg_.find("num_class") != cfg_.cend()};
// Inject num_class into configuration.
// FIXME(jiamingy): Remove the duplicated parameter in softmax
cfg_["num_class"] = std::to_string(mparam_.num_class);
auto& args = *p_args;
args = {cfg_.cbegin(), cfg_.cend()}; // renew
obj_->Configure(args);
if (preserve_loaded_objective && tparam_.objective == old.objective) {
if (!has_nc) {
cfg_.erase("num_class");
}
return;
}
obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_, args));
if (!has_nc) {
cfg_.erase("num_class");
}
Expand Down Expand Up @@ -911,8 +913,7 @@ class LearnerIO : public LearnerConfiguration {

std::string name = get<String>(objective_fn["name"]);
tparam_.UpdateAllowUnknown(Args{{"objective", name}});
obj_.reset(ObjFunction::Create(name, &ctx_));
obj_->LoadConfig(objective_fn);
obj_.reset(ObjFunction::Create(&ctx_, objective_fn));

auto const& gradient_booster = learner.at("gradient_booster");
name = get<String>(gradient_booster["name"]);
Expand Down
84 changes: 38 additions & 46 deletions src/objective/aft_obj.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@ DMLC_REGISTRY_FILE_TAG(aft_obj_gpu);

class AFTObj : public ObjFunction {
public:
void Configure(Args const& args) override {
param_.UpdateAllowUnknown(args);
}
explicit AFTObj(Args const& args) { param_.UpdateAllowUnknown(args); }
explicit AFTObj(Json const& in) { FromJson(in["aft_loss_param"], &param_); }

ObjInfo Task() const override { return ObjInfo::kSurvival; }

Expand All @@ -42,27 +41,24 @@ class AFTObj : public ObjFunction {
linalg::Matrix<GradientPair>* out_gpair, size_t ndata, DeviceOrd device,
bool is_null_weight, float aft_loss_distribution_scale) {
common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx,
common::Span<GradientPair> _out_gpair,
common::Span<const bst_float> _preds,
common::Span<const bst_float> _labels_lower_bound,
common::Span<const bst_float> _labels_upper_bound,
common::Span<const bst_float> _weights) {
const double pred = static_cast<double>(_preds[_idx]);
const double label_lower_bound = static_cast<double>(_labels_lower_bound[_idx]);
const double label_upper_bound = static_cast<double>(_labels_upper_bound[_idx]);
const float grad = static_cast<float>(
AFTLoss<Distribution>::Gradient(label_lower_bound, label_upper_bound,
pred, aft_loss_distribution_scale));
const float hess = static_cast<float>(
AFTLoss<Distribution>::Hessian(label_lower_bound, label_upper_bound,
pred, aft_loss_distribution_scale));
const bst_float w = is_null_weight ? 1.0f : _weights[_idx];
_out_gpair[_idx] = GradientPair(grad * w, hess * w);
},
common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(), device).Eval(
out_gpair->Data(), &preds, &info.labels_lower_bound_, &info.labels_upper_bound_,
&info.weights_);
[=] XGBOOST_DEVICE(size_t _idx, common::Span<GradientPair> _out_gpair,
common::Span<const bst_float> _preds,
common::Span<const bst_float> _labels_lower_bound,
common::Span<const bst_float> _labels_upper_bound,
common::Span<const bst_float> _weights) {
const double pred = static_cast<double>(_preds[_idx]);
const double label_lower_bound = static_cast<double>(_labels_lower_bound[_idx]);
const double label_upper_bound = static_cast<double>(_labels_upper_bound[_idx]);
const float grad = static_cast<float>(AFTLoss<Distribution>::Gradient(
label_lower_bound, label_upper_bound, pred, aft_loss_distribution_scale));
const float hess = static_cast<float>(AFTLoss<Distribution>::Hessian(
label_lower_bound, label_upper_bound, pred, aft_loss_distribution_scale));
const bst_float w = is_null_weight ? 1.0f : _weights[_idx];
_out_gpair[_idx] = GradientPair(grad * w, hess * w);
},
common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(), device)
.Eval(out_gpair->Data(), &preds, &info.labels_lower_bound_, &info.labels_upper_bound_,
&info.weights_);
}

void GetGradient(const HostDeviceVector<bst_float>& preds, const MetaInfo& info, int /*iter*/,
Expand All @@ -77,28 +73,28 @@ class AFTObj : public ObjFunction {
const bool is_null_weight = info.weights_.Size() == 0;
if (!is_null_weight) {
CHECK_EQ(info.weights_.Size(), ndata)
<< "Number of weights should be equal to number of data points.";
<< "Number of weights should be equal to number of data points.";
}

switch (param_.aft_loss_distribution) {
case common::ProbabilityDistributionType::kNormal:
GetGradientImpl<common::NormalDistribution>(preds, info, out_gpair, ndata, device,
is_null_weight, aft_loss_distribution_scale);
break;
case common::ProbabilityDistributionType::kLogistic:
GetGradientImpl<common::LogisticDistribution>(preds, info, out_gpair, ndata, device,
case common::ProbabilityDistributionType::kNormal:
GetGradientImpl<common::NormalDistribution>(preds, info, out_gpair, ndata, device,
is_null_weight, aft_loss_distribution_scale);
break;
case common::ProbabilityDistributionType::kExtreme:
GetGradientImpl<common::ExtremeDistribution>(preds, info, out_gpair, ndata, device,
is_null_weight, aft_loss_distribution_scale);
break;
default:
LOG(FATAL) << "Unrecognized distribution";
break;
case common::ProbabilityDistributionType::kLogistic:
GetGradientImpl<common::LogisticDistribution>(preds, info, out_gpair, ndata, device,
is_null_weight, aft_loss_distribution_scale);
break;
case common::ProbabilityDistributionType::kExtreme:
GetGradientImpl<common::ExtremeDistribution>(preds, info, out_gpair, ndata, device,
is_null_weight, aft_loss_distribution_scale);
break;
default:
LOG(FATAL) << "Unrecognized distribution";
}
}

void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
void PredTransform(HostDeviceVector<bst_float>* io_preds) const override {
// Trees give us a prediction in log scale, so exponentiate
common::Transform<>::Init(
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
Expand All @@ -120,19 +116,14 @@ class AFTObj : public ObjFunction {
});
}

const char* DefaultEvalMetric() const override {
return "aft-nloglik";
}
const char* DefaultEvalMetric() const override { return "aft-nloglik"; }

void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["name"] = String("survival:aft");
out["aft_loss_param"] = ToJson(param_);
}

void LoadConfig(Json const& in) override {
FromJson(in["aft_loss_param"], &param_);
}
Json DefaultMetricConfig() const override {
Json config{Object{}};
config["name"] = String{this->DefaultEvalMetric()};
Expand All @@ -147,7 +138,8 @@ class AFTObj : public ObjFunction {
// register the objective functions
XGBOOST_REGISTER_OBJECTIVE(AFTObj, "survival:aft")
.describe("AFT loss function")
.set_body([]() { return new AFTObj(); });
.set_body([](Args const& args) { return new AFTObj{args}; })
.set_body_json([](Json const& config) { return new AFTObj{config}; });

} // namespace obj
} // namespace xgboost
9 changes: 4 additions & 5 deletions src/objective/hinge.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@ DMLC_REGISTRY_FILE_TAG(hinge_obj_gpu);

class HingeObj : public FitIntercept {
public:
HingeObj() = default;

void Configure(Args const &) override {}
explicit HingeObj(Args const &) {}
explicit HingeObj(Json const &) {}
ObjInfo Task() const override { return ObjInfo::kRegression; }

[[nodiscard]] bst_target_t Targets(MetaInfo const &info) const override {
Expand Down Expand Up @@ -93,12 +92,12 @@ class HingeObj : public FitIntercept {
auto &out = *p_out;
out["name"] = String("binary:hinge");
}
void LoadConfig(Json const &) override {}
};

// register the objective functions
XGBOOST_REGISTER_OBJECTIVE(HingeObj, "binary:hinge")
.describe("Hinge loss. Expects labels to be in [0,1f]")
.set_body([]() { return new HingeObj(); });
.set_body([](Args const &args) { return new HingeObj(args); })
.set_body_json([](Json const &config) { return new HingeObj(config); });

} // namespace xgboost::obj
6 changes: 2 additions & 4 deletions src/objective/init_estimation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
*/
#include "init_estimation.h"

#include <memory> // unique_ptr
#include <memory> // unique_ptr

#include "../common/stats.h" // Mean
#include "../tree/fit_stump.h" // FitStump
Expand All @@ -26,9 +26,7 @@ void FitIntercept::InitEstimation(MetaInfo const& info, linalg::Vector<float>* b
Json config{Object{}};
this->SaveConfig(&config);

std::unique_ptr<ObjFunction> new_obj{
ObjFunction::Create(get<String const>(config["name"]), this->ctx_)};
new_obj->LoadConfig(config);
std::unique_ptr<ObjFunction> new_obj{ObjFunction::Create(this->ctx_, config)};
new_obj->GetGradient(dummy_predt, info, 0, &gpair);

bst_target_t n_targets = this->Targets(info);
Expand Down
Loading
Loading