Skip to content

Commit 1c911da

Browse files
Sklearn api x (#405)
* changed signature of automl.predict and automl.predict_proba to X * XGBoostEstimator * changed signature of Prophet predict to X * changed signature of ARIMA predict to X * changed signature of TS_SKLearn_Regressor predict to X
1 parent a6d70ef commit 1c911da

File tree

2 files changed

+52
-56
lines changed

2 files changed

+52
-56
lines changed

flaml/automl.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -714,13 +714,11 @@ def time_to_find_best_model(self) -> float:
714714
"""Time taken to find best model in seconds."""
715715
return self.__dict__.get("_time_taken_best_iter")
716716

717-
def predict(
718-
self, X_test: Union[np.array, pd.DataFrame, List[str], List[List[str]]]
719-
):
717+
def predict(self, X: Union[np.array, pd.DataFrame, List[str], List[List[str]]]):
720718
"""Predict label from features.
721719
722720
Args:
723-
X_test: A numpy array of featurized instances, shape n * m,
721+
X: A numpy array of featurized instances, shape n * m,
724722
or for 'ts_forecast' task:
725723
a pandas dataframe with the first column containing
726724
timestamp values (datetime type) or an integer n for
@@ -748,8 +746,8 @@ def predict(
748746
"No estimator is trained. Please run fit with enough budget."
749747
)
750748
return None
751-
X_test = self._preprocess(X_test)
752-
y_pred = estimator.predict(X_test)
749+
X = self._preprocess(X)
750+
y_pred = estimator.predict(X)
753751
if (
754752
isinstance(y_pred, np.ndarray)
755753
and y_pred.ndim > 1
@@ -763,12 +761,12 @@ def predict(
763761
else:
764762
return y_pred
765763

766-
def predict_proba(self, X_test):
764+
def predict_proba(self, X):
767765
"""Predict the probability of each class from features, only works for
768766
classification problems.
769767
770768
Args:
771-
X_test: A numpy array of featurized instances, shape n * m.
769+
X: A numpy array of featurized instances, shape n * m.
772770
773771
Returns:
774772
A numpy array of shape n * c. c is the # classes. Each element at
@@ -780,8 +778,8 @@ def predict_proba(self, X_test):
780778
"No estimator is trained. Please run fit with enough budget."
781779
)
782780
return None
783-
X_test = self._preprocess(X_test)
784-
proba = self._trained_estimator.predict_proba(X_test)
781+
X = self._preprocess(X)
782+
proba = self._trained_estimator.predict_proba(X)
785783
return proba
786784

787785
def _preprocess(self, X):

flaml/model.py

Lines changed: 44 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -197,32 +197,32 @@ def fit(self, X_train, y_train, budget=None, **kwargs):
197197
train_time = self._fit(X_train, y_train, **kwargs)
198198
return train_time
199199

200-
def predict(self, X_test):
200+
def predict(self, X):
201201
"""Predict label from features.
202202
203203
Args:
204-
X_test: A numpy array or a dataframe of featurized instances, shape n*m.
204+
X: A numpy array or a dataframe of featurized instances, shape n*m.
205205
206206
Returns:
207207
A numpy array of shape n*1.
208208
Each element is the label for a instance.
209209
"""
210210
if self._model is not None:
211-
X_test = self._preprocess(X_test)
212-
return self._model.predict(X_test)
211+
X = self._preprocess(X)
212+
return self._model.predict(X)
213213
else:
214214
logger.warning(
215215
"Estimator is not fit yet. Please run fit() before predict()."
216216
)
217-
return np.ones(X_test.shape[0])
217+
return np.ones(X.shape[0])
218218

219-
def predict_proba(self, X_test):
219+
def predict_proba(self, X):
220220
"""Predict the probability of each class from features.
221221
222222
Only works for classification problems
223223
224224
Args:
225-
X_test: A numpy array of featurized instances, shape n*m.
225+
X: A numpy array of featurized instances, shape n*m.
226226
227227
Returns:
228228
A numpy array of shape n*c. c is the # classes.
@@ -231,8 +231,8 @@ class j.
231231
"""
232232
assert self._task in CLASSIFICATION, "predict_proba() only for classification."
233233

234-
X_test = self._preprocess(X_test)
235-
return self._model.predict_proba(X_test)
234+
X = self._preprocess(X)
235+
return self._model.predict_proba(X)
236236

237237
def cleanup(self):
238238
del self._model
@@ -708,18 +708,18 @@ def _init_model_for_predict(self, X_test):
708708
)
709709
return test_dataset, training_args
710710

711-
def predict_proba(self, X_test):
711+
def predict_proba(self, X):
712712
assert (
713713
self._task in CLASSIFICATION
714714
), "predict_proba() only for classification tasks."
715715

716-
test_dataset, _ = self._init_model_for_predict(X_test)
716+
test_dataset, _ = self._init_model_for_predict(X)
717717
predictions = self._trainer.predict(test_dataset)
718718
self._trainer = None
719719
return predictions.predictions
720720

721-
def predict(self, X_test):
722-
test_dataset, training_args = self._init_model_for_predict(X_test)
721+
def predict(self, X):
722+
test_dataset, training_args = self._init_model_for_predict(X)
723723
if self._task not in NLG_TASKS:
724724
predictions = self._trainer.predict(test_dataset)
725725
else:
@@ -1108,12 +1108,12 @@ def fit(self, X_train, y_train, budget=None, **kwargs):
11081108
train_time = time.time() - start_time
11091109
return train_time
11101110

1111-
def predict(self, X_test):
1111+
def predict(self, X):
11121112
import xgboost as xgb
11131113

1114-
if not issparse(X_test):
1115-
X_test = self._preprocess(X_test)
1116-
dtest = xgb.DMatrix(X_test)
1114+
if not issparse(X):
1115+
X = self._preprocess(X)
1116+
dtest = xgb.DMatrix(X)
11171117
return super().predict(dtest)
11181118

11191119
@classmethod
@@ -1598,22 +1598,22 @@ def fit(self, X_train, y_train, budget=None, **kwargs):
15981598
self._model = model
15991599
return train_time
16001600

1601-
def predict(self, X_test):
1602-
if isinstance(X_test, int):
1601+
def predict(self, X):
1602+
if isinstance(X, int):
16031603
raise ValueError(
16041604
"predict() with steps is only supported for arima/sarimax."
16051605
" For Prophet, pass a dataframe with the first column containing"
16061606
" the timestamp values."
16071607
)
16081608
if self._model is not None:
1609-
X_test = self._preprocess(X_test)
1610-
forecast = self._model.predict(X_test)
1609+
X = self._preprocess(X)
1610+
forecast = self._model.predict(X)
16111611
return forecast["yhat"]
16121612
else:
16131613
logger.warning(
16141614
"Estimator is not fit yet. Please run fit() before predict()."
16151615
)
1616-
return np.ones(X_test.shape[0])
1616+
return np.ones(X.shape[0])
16171617

16181618

16191619
class ARIMA(Prophet):
@@ -1678,30 +1678,30 @@ def fit(self, X_train, y_train, budget=None, **kwargs):
16781678
self._model = model
16791679
return train_time
16801680

1681-
def predict(self, X_test):
1681+
def predict(self, X):
16821682
if self._model is not None:
1683-
if isinstance(X_test, int):
1684-
forecast = self._model.forecast(steps=X_test)
1685-
elif isinstance(X_test, DataFrame):
1686-
start = X_test[TS_TIMESTAMP_COL].iloc[0]
1687-
end = X_test[TS_TIMESTAMP_COL].iloc[-1]
1688-
if len(X_test.columns) > 1:
1689-
X_test = self._preprocess(X_test.drop(columns=TS_TIMESTAMP_COL))
1690-
regressors = list(X_test)
1691-
print(start, end, X_test.shape)
1683+
if isinstance(X, int):
1684+
forecast = self._model.forecast(steps=X)
1685+
elif isinstance(X, DataFrame):
1686+
start = X[TS_TIMESTAMP_COL].iloc[0]
1687+
end = X[TS_TIMESTAMP_COL].iloc[-1]
1688+
if len(X.columns) > 1:
1689+
X = self._preprocess(X.drop(columns=TS_TIMESTAMP_COL))
1690+
regressors = list(X)
1691+
print(start, end, X.shape)
16921692
forecast = self._model.predict(
1693-
start=start, end=end, exog=X_test[regressors]
1693+
start=start, end=end, exog=X[regressors]
16941694
)
16951695
else:
16961696
forecast = self._model.predict(start=start, end=end)
16971697
else:
16981698
raise ValueError(
1699-
"X_test needs to be either a pandas Dataframe with dates as the first column"
1699+
"X needs to be either a pandas Dataframe with dates as the first column"
17001700
" or an int number of periods for predict()."
17011701
)
17021702
return forecast
17031703
else:
1704-
return np.ones(X_test if isinstance(X_test, int) else X_test.shape[0])
1704+
return np.ones(X if isinstance(X, int) else X.shape[0])
17051705

17061706

17071707
class SARIMAX(ARIMA):
@@ -1873,42 +1873,40 @@ def fit(self, X_train, y_train, budget=None, **kwargs):
18731873
train_time = time.time() - current_time
18741874
return train_time
18751875

1876-
def predict(self, X_test):
1876+
def predict(self, X):
18771877
if self._model is not None:
1878-
X_test = self.transform_X(X_test)
1879-
X_test = self._preprocess(X_test)
1878+
X = self.transform_X(X)
1879+
X = self._preprocess(X)
18801880
if isinstance(self._model, list):
18811881
assert len(self._model) == len(
1882-
X_test
1883-
), "Model is optimized for horizon, length of X_test must be equal to `period`."
1882+
X
1883+
), "Model is optimized for horizon, length of X must be equal to `period`."
18841884
preds = []
18851885
for i in range(1, len(self._model) + 1):
18861886
(
18871887
X_pred,
18881888
_,
18891889
) = self.hcrystaball_model._transform_data_to_tsmodel_input_format(
1890-
X_test.iloc[:i, :]
1890+
X.iloc[:i, :]
18911891
)
18921892
preds.append(self._model[i - 1].predict(X_pred)[-1])
18931893
forecast = DataFrame(
18941894
data=np.asarray(preds).reshape(-1, 1),
18951895
columns=[self.hcrystaball_model.name],
1896-
index=X_test.index,
1896+
index=X.index,
18971897
)
18981898
else:
18991899
(
19001900
X_pred,
19011901
_,
1902-
) = self.hcrystaball_model._transform_data_to_tsmodel_input_format(
1903-
X_test
1904-
)
1902+
) = self.hcrystaball_model._transform_data_to_tsmodel_input_format(X)
19051903
forecast = self._model.predict(X_pred)
19061904
return forecast
19071905
else:
19081906
logger.warning(
19091907
"Estimator is not fit yet. Please run fit() before predict()."
19101908
)
1911-
return np.ones(X_test.shape[0])
1909+
return np.ones(X.shape[0])
19121910

19131911

19141912
class LGBM_TS_Regressor(TS_SKLearn_Regressor):

0 commit comments

Comments
 (0)