Skip to content

Commit

Permalink
Merge branch 'main' into rr-task-cls-rename
Browse files Browse the repository at this point in the history
  • Loading branch information
rishabh-ranjan authored Jul 22, 2024
2 parents 8cabbf7 + 26aa7e8 commit 6ff7d5c
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 22 deletions.
7 changes: 1 addition & 6 deletions examples/gnn_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@
loss_fn = BCEWithLogitsLoss()
tune_metric = "roc_auc"
higher_is_better = True
multilabel = False
elif task.task_type == TaskType.REGRESSION:
out_channels = 1
loss_fn = L1Loss()
Expand All @@ -93,22 +92,18 @@
clamp_min, clamp_max = np.percentile(
train_table.df[task.target_col].to_numpy(), [2, 98]
)
multilabel = False
elif task.task_type == TaskType.MULTILABEL_CLASSIFICATION:
out_channels = task.num_labels
loss_fn = BCEWithLogitsLoss()
tune_metric = "multilabel_auprc_macro"
higher_is_better = True
multilabel = True
else:
raise ValueError(f"Task type {task.task_type} is unsupported")

loader_dict: Dict[str, NeighborLoader] = {}
for split in ["train", "val", "test"]:
table = task.get_table(split)
table_input = get_node_train_table_input(
table=table, task=task, multilabel=multilabel
)
table_input = get_node_train_table_input(table=table, task=task)
entity_table = table_input.nodes[0]
loader_dict[split] = NeighborLoader(
data,
Expand Down
9 changes: 2 additions & 7 deletions examples/hybrid_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
"--sample_size",
type=int,
default=50_000,
help="Subsample the specified number of training data to train lightgbm model.",
help="Subsample the specified number of training data to train LightGBM model.",
)
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--seed", type=int, default=42)
Expand Down Expand Up @@ -94,7 +94,6 @@
loss_fn = BCEWithLogitsLoss()
tune_metric = "roc_auc"
higher_is_better = True
multilabel = False
elif task.task_type == TaskType.REGRESSION:
out_channels = 1
loss_fn = L1Loss()
Expand All @@ -104,20 +103,16 @@
clamp_min, clamp_max = np.percentile(
task.get_table("train").df[task.target_col].to_numpy(), [2, 98]
)
multilabel = False
elif task.task_type == TaskType.MULTILABEL_CLASSIFICATION:
out_channels = task.num_labels
loss_fn = BCEWithLogitsLoss()
tune_metric = "multilabel_auprc_macro"
higher_is_better = True
multilabel = True

loader_dict: Dict[str, NeighborLoader] = {}
for split in ["train", "val", "test"]:
table = task.get_table(split)
table_input = get_node_train_table_input(
table=table, task=task, multilabel=multilabel
)
table_input = get_node_train_table_input(table=table, task=task)
entity_table = table_input.nodes[0]
loader_dict[split] = NeighborLoader(
data,
Expand Down
8 changes: 4 additions & 4 deletions examples/lightgbm_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
json.dump(col_to_stype_dict, f, indent=2, default=str)


# Prepare col_to_stype dictioanry mapping between column names and stypes
# Prepare col_to_stype dictionary mapping between column names and stypes
# for torch_frame Dataset initialization.
col_to_stype = {}
src_entity_table_col_to_stype = copy.deepcopy(col_to_stype_dict[task.src_entity_table])
Expand Down Expand Up @@ -318,7 +318,7 @@ def interleave_lists(list1, list2):
return evaluate_table_df


# Prepare val dataset for lightGBM model evalution
# Prepare val dataset for lightGBM model evaluation
val_df_pred_column_names = list(val_table.df.columns)
val_df_pred_column_names.remove(dst_entity)
val_df_pred = val_table.df[val_df_pred_column_names]
Expand All @@ -328,7 +328,7 @@ def interleave_lists(list1, list2):
val_df_pred = prepare_for_link_pred_eval(val_df_pred, val_past_table_df)
dfs["val_pred"] = val_df_pred

# Prepare test dataset for lightGBM model evalution
# Prepare test dataset for lightGBM model evaluation
test_df_column_names = list(test_table.df.columns)
test_df_column_names.remove(dst_entity)
test_df = test_table.df[test_df_column_names]
Expand Down Expand Up @@ -419,7 +419,7 @@ def adjust_past_dst_entities(values):
return metrics


# NOTE: train/val metrics will be artifically high since all true links are
# NOTE: train/val metrics will be artificially high since all true links are
# included in the candidate set
pred = model.predict(tf_test=tf_train).numpy()
lightgbm_output = dfs["train"]
Expand Down
10 changes: 10 additions & 0 deletions relbench/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,13 @@
from .task_base import BaseTask, TaskType
from .task_link import RecommendationTask
from .task_node import EntityTask

__all__ = [
"Database",
"Dataset",
"Table",
"BaseTask",
"TaskType",
"RecommendationTask",
"EntityTask",
]
7 changes: 3 additions & 4 deletions relbench/modeling/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __call__(self, batch: HeteroData) -> HeteroData:


class NodeTrainTableInput(NamedTuple):
r"""Trainining table input for node prediction.
r"""Training table input for node prediction.
- nodes is a Tensor of node indices.
- time is a Tensor of node timestamps.
Expand All @@ -147,7 +147,6 @@ class NodeTrainTableInput(NamedTuple):
def get_node_train_table_input(
table: Table,
task: EntityTask,
multilabel: bool = False,
) -> NodeTrainTableInput:
r"""Get the training table input for node prediction."""

Expand All @@ -161,7 +160,7 @@ def get_node_train_table_input(
transform: Optional[AttachTargetTransform] = None
if task.target_col in table.df:
target_type = float
if task.task_type == "multiclass_classification":
if task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
target_type = int
if task.task_type == TaskType.MULTILABEL_CLASSIFICATION:
target = torch.from_numpy(np.stack(table.df[task.target_col].values))
Expand All @@ -180,7 +179,7 @@ def get_node_train_table_input(


class LinkTrainTableInput(NamedTuple):
r"""Trainining table input for link prediction.
r"""Training table input for link prediction.
- src_nodes is a Tensor of source node indices.
- dst_nodes is PyTorch sparse tensor in csr format.
Expand Down
2 changes: 1 addition & 1 deletion relbench/modeling/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def forward(
num_sampled_nodes_dict: Optional[Dict[NodeType, List[int]]] = None,
num_sampled_edges_dict: Optional[Dict[EdgeType, List[int]]] = None,
) -> Dict[NodeType, Tensor]:
for i, (conv, norm_dict) in enumerate(zip(self.convs, self.norms)):
for _, (conv, norm_dict) in enumerate(zip(self.convs, self.norms)):
x_dict = conv(x_dict, edge_index_dict)
x_dict = {key: norm_dict[key](x) for key, x in x_dict.items()}
x_dict = {key: x.relu() for key, x in x_dict.items()}
Expand Down

0 comments on commit 6ff7d5c

Please sign in to comment.