Skip to content

Commit bcce17e

Browse files
authored
Remove text loading in basic walk through demo. (dmlc#7753)
1 parent c467e90 commit bcce17e

File tree

3 files changed

+43
-62
lines changed

3 files changed

+43
-62
lines changed

Diff for: demo/guide-python/basic_walkthrough.py

+35-57
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,65 @@
11
"""
22
Getting started with XGBoost
33
============================
4+
5+
This is a simple example of using the native XGBoost interface; there are other
6+
interfaces in the Python package like scikit-learn interface and Dask interface.
7+
8+
9+
See :doc:`/python/python_intro` and :doc:`/tutorials/index` for other references.
10+
411
"""
512
import numpy as np
6-
import scipy.sparse
713
import pickle
814
import xgboost as xgb
915
import os
1016

17+
from sklearn.datasets import load_svmlight_file
18+
1119
# Make sure the demo knows where to load the data.
1220
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
1321
XGBOOST_ROOT_DIR = os.path.dirname(os.path.dirname(CURRENT_DIR))
14-
DEMO_DIR = os.path.join(XGBOOST_ROOT_DIR, 'demo')
22+
DEMO_DIR = os.path.join(XGBOOST_ROOT_DIR, "demo")
1523

16-
# simple example
17-
# load file from text file, also binary buffer generated by xgboost
18-
dtrain = xgb.DMatrix(os.path.join(DEMO_DIR, 'data', 'agaricus.txt.train?indexing_mode=1'))
19-
dtest = xgb.DMatrix(os.path.join(DEMO_DIR, 'data', 'agaricus.txt.test?indexing_mode=1'))
24+
# X is a scipy csr matrix; XGBoost supports many other input types.
25+
X, y = load_svmlight_file(os.path.join(DEMO_DIR, "data", "agaricus.txt.train"))
26+
dtrain = xgb.DMatrix(X, y)
27+
# validation set
28+
X_test, y_test = load_svmlight_file(os.path.join(DEMO_DIR, "data", "agaricus.txt.test"))
29+
dtest = xgb.DMatrix(X_test, y_test)
2030

2131
# specify parameters via map; definitions are the same as the C++ version
22-
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
32+
param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
2333

2434
# specify validation sets to watch performance
25-
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
35+
watchlist = [(dtest, "eval"), (dtrain, "train")]
36+
# number of boosting rounds
2637
num_round = 2
27-
bst = xgb.train(param, dtrain, num_round, watchlist)
38+
bst = xgb.train(param, dtrain, num_boost_round=num_round, evals=watchlist)
2839

29-
# this is prediction
40+
# run prediction
3041
preds = bst.predict(dtest)
3142
labels = dtest.get_label()
32-
print('error=%f' %
33-
(sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) /
34-
float(len(preds))))
35-
bst.save_model('0001.model')
43+
print(
44+
"error=%f"
45+
% (
46+
sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i])
47+
/ float(len(preds))
48+
)
49+
)
50+
bst.save_model("model-0.json")
3651
# dump model
37-
bst.dump_model('dump.raw.txt')
52+
bst.dump_model("dump.raw.txt")
3853
# dump model with feature map
39-
bst.dump_model('dump.nice.txt', os.path.join(DEMO_DIR, 'data/featmap.txt'))
54+
bst.dump_model("dump.nice.txt", os.path.join(DEMO_DIR, "data/featmap.txt"))
4055

4156
# save dmatrix into binary buffer
42-
dtest.save_binary('dtest.buffer')
57+
dtest.save_binary("dtest.dmatrix")
4358
# save model
44-
bst.save_model('xgb.model')
59+
bst.save_model("model-1.json")
4560
# load model and data in
46-
bst2 = xgb.Booster(model_file='xgb.model')
47-
dtest2 = xgb.DMatrix('dtest.buffer')
61+
bst2 = xgb.Booster(model_file="model-1.json")
62+
dtest2 = xgb.DMatrix("dtest.dmatrix")
4863
preds2 = bst2.predict(dtest2)
4964
# assert they are the same
5065
assert np.sum(np.abs(preds2 - preds)) == 0
@@ -56,40 +71,3 @@
5671
preds3 = bst3.predict(dtest2)
5772
# assert they are the same
5873
assert np.sum(np.abs(preds3 - preds)) == 0
59-
60-
###
61-
# build dmatrix from scipy.sparse
62-
print('start running example of build DMatrix from scipy.sparse CSR Matrix')
63-
labels = []
64-
row = []
65-
col = []
66-
dat = []
67-
i = 0
68-
for l in open(os.path.join(DEMO_DIR, 'data', 'agaricus.txt.train')):
69-
arr = l.split()
70-
labels.append(int(arr[0]))
71-
for it in arr[1:]:
72-
k, v = it.split(':')
73-
row.append(i)
74-
col.append(int(k))
75-
dat.append(float(v))
76-
i += 1
77-
csr = scipy.sparse.csr_matrix((dat, (row, col)))
78-
dtrain = xgb.DMatrix(csr, label=labels)
79-
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
80-
bst = xgb.train(param, dtrain, num_round, watchlist)
81-
82-
print('start running example of build DMatrix from scipy.sparse CSC Matrix')
83-
# we can also construct from csc matrix
84-
csc = scipy.sparse.csc_matrix((dat, (row, col)))
85-
dtrain = xgb.DMatrix(csc, label=labels)
86-
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
87-
bst = xgb.train(param, dtrain, num_round, watchlist)
88-
89-
print('start running example of build DMatrix from numpy array')
90-
# NOTE: npymat is numpy array, we will convert it into scipy.sparse.csr_matrix
91-
# in internal implementation then convert to DMatrix
92-
npymat = csr.todense()
93-
dtrain = xgb.DMatrix(npymat, label=labels)
94-
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
95-
bst = xgb.train(param, dtrain, num_round, watchlist)

Diff for: doc/python/python_intro.rst

+1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ including:
4545
- XGBoost binary buffer file.
4646
- LIBSVM text format file
4747
- Comma-separated values (CSV) file
48+
- Arrow table.
4849

4950
(See :doc:`/tutorials/input_format` for detailed description of text input format.)
5051

Diff for: python-package/xgboost/core.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -565,12 +565,14 @@ def __init__(
565565
"""Parameters
566566
----------
567567
data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/
568-
dt.Frame/cudf.DataFrame/cupy.array/dlpack
568+
dt.Frame/cudf.DataFrame/cupy.array/dlpack/arrow.Table
569+
569570
Data source of DMatrix.
570-
When data is string or os.PathLike type, it represents the path
571-
libsvm format txt file, csv file (by specifying uri parameter
572-
'path_to_csv?format=csv'), or binary file that xgboost can read
573-
from.
571+
572+
When data is string or os.PathLike type, it represents the path to a libsvm
573+
format txt file, csv file (by specifying uri parameter
574+
'path_to_csv?format=csv'), or binary file that xgboost can read from.
575+
574576
label : array_like
575577
Label of the training data.
576578
weight : array_like

0 commit comments

Comments
 (0)