Expose objective functions to the Python interface. #12059
base: master
Changes from 12 commits
```diff
@@ -87,7 +87,7 @@
     is_pyarrow_available,
     py_str,
 )
-from .objective import Objective, TreeObjective, _grad_arrinf
+from .objective import Objective, _grad_arrinf, _stringify

 if TYPE_CHECKING:
     from pandas import DataFrame as PdDataFrame
```
```diff
@@ -2162,13 +2162,11 @@ def set_param(
         elif isinstance(params, str) and value is not None:
             params = [(params, value)]
         for key, val in cast(Iterable[Tuple[str, str]], params):
-            if isinstance(val, np.ndarray):
-                val = val.tolist()
-            elif hasattr(val, "__cuda_array_interface__") and hasattr(val, "tolist"):
-                val = val.tolist()
             if val is not None:
                 _check_call(
-                    _LIB.XGBoosterSetParam(self.handle, c_str(key), c_str(str(val)))
+                    _LIB.XGBoosterSetParam(
+                        self.handle, c_str(key), c_str(_stringify(val))
+                    )
                 )

     def update(
```
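For reference, `_stringify` itself is not shown in this diff. A minimal sketch of what such a helper could look like, assuming it simply folds in the NumPy/CUDA-array handling removed above (the actual implementation in `xgboost/objective.py` may differ):

```python
from typing import Any

import numpy as np


def _stringify(value: Any) -> str:
    """Turn a parameter value into the string form XGBoosterSetParam expects.

    Array-like values are converted to plain lists first so that, e.g., a
    NumPy array serializes as "[0.1, 0.2]" rather than an opaque repr.
    """
    if isinstance(value, np.ndarray):
        value = value.tolist()
    elif hasattr(value, "__cuda_array_interface__") and hasattr(value, "tolist"):
        # CUDA arrays (e.g. CuPy) expose the CUDA array interface.
        value = value.tolist()
    return str(value)
```

With a helper like this, `set_param` no longer needs per-type branches; every value goes through one conversion path before being handed to the C API.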
```diff
@@ -2280,21 +2278,14 @@ def train_one_iter(grad: NumpyOrCupy, hess: NumpyOrCupy) -> None:
         vgrad: Optional[ArrayLike]
         vhess: Optional[ArrayLike]

-        if isinstance(fobj, TreeObjective):
-            # full gradient for leaf values
+        if isinstance(fobj, Objective):
             vgrad, vhess = fobj(iteration, y_pred, dtrain)
             # Reduced gradient for split nodes
             split_grad = fobj.split_grad(iteration, vgrad, vhess)
+            # Switch the role of gradient if there's no split gradient but the tree
+            # objective is used.
             if split_grad is not None:
                 sgrad, shess = split_grad
             else:
                 sgrad, shess = vgrad, vhess
                 vgrad, vhess = None, None
-        elif isinstance(fobj, Objective):
-            sgrad, shess = fobj(iteration, y_pred, dtrain)
-            vgrad, vhess = None, None
         else:
             # Plain callable
             sgrad, shess = fobj(y_pred, dtrain)
```

Member: Just curious what the latency difference is between using the Python interface and the internal one. One possible simplification could be to always use the Python path, so there is only one code path.

Member (Author): The internal implementation assumes a split gradient must be available, and only does an extra leaf-value computation at the end of the iteration if there's an additional leaf-value gradient. That is easy to implement, since we only need one extra step. But it's not intuitive to users: in that mental model, the algorithm creates an extra "split gradient". So the interface instead assumes a value gradient is available, as in normal gradient boosting; the assumption is switched here.

Member (Author): I misread the original question. There's no difference in latency; it's just setting a parameter. Yes, I think a single code path would be super nice.
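To make the dispatch above concrete, here is a hypothetical user-side sketch. It assumes the `Objective` base class exposed by this PR is called as `fobj(iteration, y_pred, dtrain)` and offers an optional `split_grad` hook, exactly as in the `train_one_iter` diff; the class name `SquaredError` and the import path are illustrative only and may not match the final API:

```python
from typing import Optional, Tuple

import numpy as np
import xgboost as xgb
from xgboost.objective import Objective  # assumed export from this PR


class SquaredError(Objective):
    """Toy squared-error objective using the unified interface."""

    def __call__(
        self, iteration: int, y_pred: np.ndarray, dtrain: xgb.DMatrix
    ) -> Tuple[np.ndarray, np.ndarray]:
        # Value gradient, as in normal gradient boosting.
        y = dtrain.get_label()
        grad = y_pred - y
        hess = np.ones_like(grad)
        return grad, hess

    def split_grad(
        self, iteration: int, grad: np.ndarray, hess: np.ndarray
    ) -> Optional[Tuple[np.ndarray, np.ndarray]]:
        # Returning None makes train_one_iter reuse the value gradient for
        # split finding (the `else` branch in the diff above). A tree
        # objective would return a reduced (grad, hess) pair here instead.
        return None
```

Under this reading, a plain callable objective keeps its existing `fobj(y_pred, dtrain)` signature and is unaffected by the change, while `Objective` subclasses opt into split-specific gradients only when `split_grad` returns something.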