Skip to content

Commit

Permalink
Fix bugs in transform
Browse files Browse the repository at this point in the history
  • Loading branch information
hyp1231 committed Nov 29, 2023
1 parent 967763d commit 05aa5cb
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
12 changes: 6 additions & 6 deletions data/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,22 @@


def construct_transform(config):
    """Build the data transform selected by ``config['unisrec_transform']``.

    Args:
        config: mapping-like configuration object; the key
            ``'unisrec_transform'`` selects the transform
            (``None`` -> identity, ``'plm_emb'`` -> PLM-embedding transform).

    Returns:
        A callable ``transform(dataset, interaction)``.

    Raises:
        KeyError: if ``config['unisrec_transform']`` names an unknown transform.
    """
    if config['unisrec_transform'] is None:
        # No transform requested: warn and fall back to the identity transform.
        logger = getLogger()
        logger.warning('Equal transform')
        return Equal(config)
    else:
        # Registry of available named transforms.
        str2transform = {
            'plm_emb': PLMEmb
        }
        return str2transform[config['unisrec_transform']](config)


class Equal:
    """Identity transform: returns the interaction unchanged.

    Used as the fallback when ``config['unisrec_transform']`` is ``None``.
    """

    def __init__(self, config):
        # The identity transform needs no configuration; accept and ignore it
        # to keep the same constructor interface as other transforms.
        pass

    def __call__(self, dataset, interaction):
        """Return ``interaction`` unmodified (``dataset`` is unused)."""
        return interaction


Expand All @@ -31,13 +31,13 @@ def __init__(self, config):
self.item_drop_ratio = config['item_drop_ratio']
self.item_drop_coefficient = config['item_drop_coefficient']

def __call__(self, dataloader, interaction):
def __call__(self, dataset, interaction):
'''Sequence augmentation and PLM embedding fetching
'''
item_seq_len = interaction['item_length']
item_seq = interaction['item_id_list']

plm_embedding = dataloader.dataset.plm_embedding
plm_embedding = dataset.plm_embedding
item_emb_seq = plm_embedding(item_seq)
pos_item_id = interaction['item_id']
pos_item_emb = plm_embedding(pos_item_id)
Expand All @@ -59,7 +59,7 @@ def __call__(self, dataloader, interaction):
item_emb_seq_aug = plm_embedding(item_seq_aug)
else:
# Word drop
plm_embedding_aug = dataloader.dataset.plm_embedding_aug
plm_embedding_aug = dataset.plm_embedding_aug
full_item_emb_seq_aug = plm_embedding_aug(item_seq)

item_seq_aug = item_seq
Expand Down
2 changes: 1 addition & 1 deletion props/finetune.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ alias_of_item_id: [item_id_list]
load_col:
inter: [user_id, item_id_list, item_id]
train_neg_sample_args: ~
unisrec_transform: ~

topk: [10, 50]
metrics: [HIT, NDCG]
Expand Down
4 changes: 2 additions & 2 deletions props/pretrain.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ benchmark_filename: [train]
alias_of_item_id: [item_id_list]
load_col:
inter: [user_id, item_id_list, item_id]
train_neg_sample_args: ~
unisrec_transform: plm_emb

train_stage: pretrain
pretrain_epochs: 300
Expand Down

0 comments on commit 05aa5cb

Please sign in to comment.