Skip to content

Commit

Permalink
change cli and README
Browse files Browse the repository at this point in the history
  • Loading branch information
RektPunk committed Sep 13, 2024
1 parent 227e086 commit 8204026
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ import pandas as pd
from rektgbm import RektDataset, RektGBM, RektOptimizer

# Prepare your datasets
X_train = pd.read_csv("train.csv")
X_test = pd.read_csv("test.csv")
y_train = X_train.pop("target")
train = pd.read_csv("train.csv")
test = pd.read_csv("test.csv")
y_train = train.pop("target")

dtrain = RektDataset(data=X_train, label=y_train)
dtest = RektDataset(data=X_test, reference=dtrain)
dtrain = RektDataset(data=train, label=y_train)
dtest = RektDataset(data=test, reference=dtrain)

# Initialize RektOptimizer to automatically detect task type, objective, and metric
rekt_optimizer = RektOptimizer()
Expand Down
4 changes: 4 additions & 0 deletions rektgbm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ def main(
raise ValueError(
f"The specified target column '{target}' does not exist in the training data."
)

if target in test_data.columns:
test_data.pop(target)

train_label = train_data.pop(target)
dtrain = RektDataset(data=train_data, label=train_label)
dtest = RektDataset(data=test_data, reference=dtrain)
Expand Down

0 comments on commit 8204026

Please sign in to comment.