[Question] Large difference between builtin softmax and custom softmax objective #6219
Comments
Hey @qianyun210603, thanks for using LightGBM. The difference is due to the different init scores; you can find a detailed explanation in #5114 (comment). Please let us know if you have further doubts.
@jmoralez Thanks a lot! Could you kindly point me to where I can find the logic for setting init scores in the builtin multiclass objective? Is it also the mean of the label in this case?
Sure. The averages of the labels are computed here: `LightGBM/src/objective/multiclass_objective.hpp`, lines 53 to 84 (at commit 2ee3ec8).
And then the init scores are the logs of the averages (since those are the raw scores): `LightGBM/src/objective/multiclass_objective.hpp`, lines 155 to 157 (at commit 2ee3ec8).
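In symbols (my paraphrase of that C++ code, with $w_i$ the optional sample weights and $K$ the number of classes), the init score for class $k$ would be:

$$
s_k = \log\!\left(\frac{\sum_i w_i\,\mathbf{1}[y_i = k]}{\sum_i w_i}\right),
\qquad k = 0, \dots, K-1.
$$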
Many thanks!
Hi @jmoralez, sorry to trouble you again. Thanks for your previous instruction; I finally got the customised loss function aligned with the builtin one. However, other questions arose on which I hope to get your help: with a custom objective, is the output of `predict` the same as `predict(raw_score=True)`? My guess is that the init scores are not added back to the raw scores at prediction time. Also, if my guess is really the case, what's the reason that LightGBM doesn't add back the init score at prediction? Please kindly help me understand the above questions. Thanks a lot!
Since you're using the scikit-learn API, if you use a custom objective the output of `predict` and `predict(raw_score=True)` is the same: `LightGBM/python-package/lightgbm/sklearn.py`, lines 1233 to 1237 (at commit 5083df1).
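A minimal sketch of that behaviour (my own toy setup, not from the thread; `toy_softmax_objective` is a hypothetical stand-in for any callable objective):

```python
import lightgbm as lgb
import numpy as np
from sklearn.datasets import make_blobs

X, y = make_blobs(n_samples=300, centers=3, random_state=0)

def toy_softmax_objective(y_true, y_pred):
    # softmax cross-entropy gradient with a simple hessian
    prob = np.exp(y_pred - y_pred.max(axis=1, keepdims=True))
    prob /= prob.sum(axis=1, keepdims=True)
    grad = prob.copy()
    grad[np.arange(len(y_true)), y_true.astype(np.int32)] -= 1.0
    hess = prob * (1.0 - prob)
    return grad, hess

clf = lgb.LGBMClassifier(n_estimators=5, objective=toy_softmax_objective).fit(X, y)

# with a callable objective the wrapper cannot apply the inverse link, so
# predict_proba falls back to the same raw scores as raw_score=True
np.testing.assert_allclose(
    clf.predict_proba(X),
    clf.booster_.predict(X, raw_score=True),
)
```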
Exactly
I think it's because with the built-in objectives the boosting starts from a single number for each class, so that's the value at the root of the first tree, whereas when you provide a custom objective it starts from zero and then you can provide an init score for each sample. At inference time we may have a different number of samples, so we wouldn't know which value to add to each sample.
Thanks a lot! That's very clear.
Found myself writing up a Python implementation of calculating the multiclass init score (based on the code @jmoralez shared in #6219 (comment)); thought it would be useful to post the snippet here for others finding this from search.

```python
import lightgbm as lgb
import numpy as np
from sklearn.datasets import make_blobs

# generate multiclass dataset with 5 classes
X, y = make_blobs(n_samples=1_000, centers=5, random_state=773)

# fit a small multiclass classification model
clf = lgb.LGBMClassifier(n_estimators=3, num_leaves=4, seed=708)
clf.fit(X, y)

# for the builtin multiclass objective, LightGBM begins boosting from the
# log of each class's frequency in the labels
_, counts = np.unique(np.sort(y), return_counts=True)
init_score = np.log(counts / y.shape[0])
print(init_score)
```
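Following on from that snippet, here is a sketch (my own, not from the thread) of using that `init_score` to make a custom softmax objective start from the same point as the builtin one: pass it per sample via `fit(init_score=...)`, then add it back to the raw scores at prediction time. `sklearn_multiclass_custom_objective` is assumed to be the helper from `tests/python_package_test/utils.py` (or any equivalent grad/hess-returning softmax objective):

```python
# reuses X, y, and init_score from the snippet above
init_score_per_sample = np.tile(init_score, (X.shape[0], 1))

clf_custom = lgb.LGBMClassifier(
    n_estimators=3, num_leaves=4, seed=708,
    objective=sklearn_multiclass_custom_objective,  # assumed available
)
clf_custom.fit(X, y, init_score=init_score_per_sample)

# with a custom objective predict_proba returns raw scores; add the init
# score back manually, then apply softmax to recover probabilities
raw = clf_custom.predict_proba(X) + init_score
exp = np.exp(raw - raw.max(axis=1, keepdims=True))
prob = exp / exp.sum(axis=1, keepdims=True)
```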
Description
I'm using the sklearn interface to solve a 3-category classification problem. I tried to benchmark the custom softmax objective function `sklearn_multiclass_custom_objective` (copied from `tests/python_package_test/utils.py`) against the builtin objective, to verify its accuracy before customising it further for my own needs. However, I see a large difference in the predicted results on the original training data. I want to figure out whether this is expected, and whether it is possible to align the prediction results of the two.
Reproducible example
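The original script didn't survive this page, so here is a rough reconstruction of the comparison described above (my own sketch; `make_blobs` stands in for the shared `test_data.bin`, and the custom objective is modelled on `tests/python_package_test/utils.py`):

```python
import lightgbm as lgb
import numpy as np
from sklearn.datasets import make_blobs

def softmax(x):
    exp = np.exp(x - x.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)

def sklearn_multiclass_custom_objective(y_true, y_pred):
    # gradient and hessian of the softmax cross-entropy loss
    prob = softmax(y_pred)
    num_rows, num_class = prob.shape
    grad = prob.copy()
    grad[np.arange(num_rows), y_true.astype(np.int32)] -= 1.0
    factor = num_class / (num_class - 1)
    hess = factor * prob * (1.0 - prob)
    return grad, hess

X, y = make_blobs(n_samples=1_000, centers=3, random_state=0)
params = dict(n_estimators=100, num_leaves=31, seed=0)

clf_builtin = lgb.LGBMClassifier(objective="multiclass", **params).fit(X, y)
clf_custom = lgb.LGBMClassifier(
    objective=sklearn_multiclass_custom_objective, **params
).fit(X, y)

proba_builtin = clf_builtin.predict_proba(X)
proba_custom = softmax(clf_custom.predict_proba(X))  # raw scores -> probabilities

# without the init score, the two sets of predictions differ noticeably
print(np.abs(proba_builtin - proba_custom).max())
```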
The result shows a large difference between the two models' predictions on the training data.
The dataset I used (`test_data.bin`) is shared on MS OneDrive (link: https://1drv.ms/u/s!AnPL7Q5hAP8rlBAIOMZpK_Q5z3EL?e=T4cAev).
Environment info
LightGBM version or commit hash: 4.1.0
Command(s) you used to install LightGBM: `pip install lightgbm`
Additional Comments
Could you kindly help me understand/investigate this discrepancy?