Skip to content

Commit

Permalink
Merge branch 'main' into joshr17-patch-1
Browse files: browse the repository at this point in the history
  • Loading branch information
rishabh-ranjan authored Jul 30, 2024
2 parents d2feae6 + 2495e12 commit 7371985
Showing 1 changed file with 0 additions and 49 deletions.
49 changes: 0 additions & 49 deletions examples/lightgbm_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@
# Dataset / task selection and the number of tuning trials.
parser.add_argument("--dataset", default="rel-stack", type=str)
parser.add_argument("--task", default="user-engage", type=str)
parser.add_argument("--num_trials", default=10, type=int)
# Paired on/off switches for feeding auto-regressive labels to LightGBM as
# hand-crafted input features. Both write to ``args.use_ar_label``.
# NOTE(review): keep the enabling flag first -- this relies on argparse
# applying the default of the first action registered for a shared dest, so
# the explicit default=False wins over store_false's implicit default.
parser.add_argument("--use_ar_label", default=False, action="store_true")
parser.add_argument("--no-use_ar_label", action="store_false", dest="use_ar_label")
parser.add_argument(
"--sample_size",
type=int,
Expand All @@ -54,46 +51,6 @@
val_table = task.get_table("val")
test_table = task.get_table("test")

# Names of the auto-regressive (AR) label columns added below. Stays empty
# when --use_ar_label is off, so later per-column loops become no-ops.
ar_label_cols = []

if args.use_ar_label:
    # Add AR labels to the train/val/test tables: column ``AR_i`` holds the
    # target value of the same entity ``i`` time steps earlier.
    whole_df = pd.concat([train_table.df, val_table.df, test_table.df], axis=0)
    num_ar_labels = max(train_table.df[train_table.time_col].nunique() - 2, 1)

    sorted_unique_times = np.sort(whole_df[train_table.time_col].unique())
    if len(sorted_unique_times) < 2:
        # Without at least two timestamps there is no step to shift by, and
        # the spacing check below would fail with an opaque IndexError.
        raise RuntimeError(
            "At least two distinct timestamps are required for AR labels "
            "to be used."
        )
    timedelta = sorted_unique_times[1:] - sorted_unique_times[:-1]
    # Compare every gap to the first one in both directions: a one-sided
    # check would miss the case where the first gap is the largest.
    if np.abs(timedelta / timedelta[0] - 1).max() > 0.1:
        raise RuntimeError(
            "Timestamps are not equally spaced, making it inappropriate for "
            "AR labels to be used."
        )

    # Replace timestamps by consecutive integer indices so that "i steps
    # earlier" becomes plain integer arithmetic.
    TIME_IDX_COL = "time_idx"
    time_df = pd.DataFrame(
        {
            task.time_col: sorted_unique_times,
            TIME_IDX_COL: np.arange(len(sorted_unique_times)),
        }
    )

    whole_df = whole_df.merge(time_df, how="left", on=task.time_col)
    whole_df.drop(task.time_col, axis=1, inplace=True)
    # Shift timestamp of whole_df iteratively and join it with
    # train/val/test_table.
    for i in range(1, num_ar_labels + 1):
        whole_df_shifted = whole_df.copy(deep=True)
        # Shift time index by i
        whole_df_shifted[TIME_IDX_COL] += i
        # Map time index back to datetime timestamp. The inner join drops
        # rows shifted past the last known timestamp.
        whole_df_shifted = whole_df_shifted.merge(
            time_df, how="inner", on=TIME_IDX_COL
        )
        whole_df_shifted.drop(TIME_IDX_COL, axis=1, inplace=True)
        ar_label = f"AR_{i}"
        ar_label_cols.append(ar_label)
        whole_df_shifted.rename(columns={task.target_col: ar_label}, inplace=True)

        # Left join keeps every original row; rows with no label ``i`` steps
        # back get a missing value in ``AR_i``.
        for table in [train_table, val_table, test_table]:
            table.df = table.df.merge(
                whole_df_shifted, how="left", on=[task.entity_col, task.time_col]
            )

dfs: Dict[str, pd.DataFrame] = {}
entity_table = dataset.get_db().table_dict[task.entity_table]
Expand All @@ -117,16 +74,10 @@

# Pick the torch_frame semantic type that matches the task's prediction
# target; unsupported task types are rejected up front.
if task.task_type == TaskType.BINARY_CLASSIFICATION:
    target_stype = torch_frame.categorical
elif task.task_type == TaskType.REGRESSION:
    target_stype = torch_frame.numerical
elif task.task_type == TaskType.MULTILABEL_CLASSIFICATION:
    target_stype = torch_frame.embedding
else:
    raise ValueError(f"Unsupported task type called {task.task_type}")

# The target column and every auto-regressive label column derived from it
# share the same semantic type (target first, then AR columns in order).
for label_col in [task.target_col, *ar_label_cols]:
    col_to_stype[label_col] = target_stype

Expand Down

0 comments on commit 7371985

Please sign in to comment.