forked from guoyang9/BPR-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_utils.py
113 lines (93 loc) · 3.83 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import numpy as np
import pandas as pd
import pickle
import scipy.sparse as sp
import random
import torch.utils.data as data
import config
def read_pickle(path):
    """Return the object unpickled from the file at ``path``.

    Only call this on trusted files: unpickling can execute arbitrary code.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)
def write_pickle(path, data):
    """Pickle ``data`` into the file at ``path``, replacing any existing contents."""
    out_file = open(path, 'wb')
    with out_file:
        pickle.dump(data, out_file)
def load_all():
    """Load the training pairs, user/item tables and test candidates in one go.

    All four pickles (``config.train_data``, ``config.user_data``,
    ``config.item_data`` and ``config.test_negative``) are read here, once, so
    that later epochs do not pay the unpickling cost again.

    Returns:
        train_data: list of ``[user_id, business_id]`` pairs taken from the
            records in ``config.train_data``.
        test_data: the records unpickled from ``config.test_negative``,
            returned unchanged.
        train_mat: list indexed by user id; ``train_mat[u]`` lists the
            business ids paired with user ``u`` in ``train_data`` (duplicates
            kept). Assumes user ids are integers in ``[0, user_num)``.
        user_num: number of entries in the ``config.user_data`` pickle.
        item_num: number of entries in the ``config.item_data`` pickle.
    """
    train_records = read_pickle(config.train_data)
    users = read_pickle(config.user_data)
    items = read_pickle(config.item_data)
    user_num = len(users)
    item_num = len(items)

    # `rec` rather than `data`, to avoid shadowing the torch.utils.data alias.
    train_data = [[rec['user_id'], rec['business_id']] for rec in train_records]

    # One list of interacted items per user. A list (not a set) is kept
    # because BPRData.ng_sample calls random.choice on it when num_ng == 0.
    train_mat = [[] for _ in range(user_num)]
    for user, item in train_data:
        train_mat[user].append(item)

    test_data = read_pickle(config.test_negative)
    return train_data, test_data, train_mat, user_num, item_num
class BPRData(data.Dataset):
    """BPR triples for training, or per-user candidate lists for testing.

    Training mode (``is_training`` truthy): ``features`` is a sequence of
    ``[user, pos_item]`` pairs and ``train_mat[u]`` holds the items user ``u``
    interacted with. ``ng_sample()`` must be called before indexing (call it
    again to resample); each item is ``(user, item_i, item_j)``.

    Test mode (``is_training`` falsy): ``features`` is a sequence of dicts with
    ``'user_id'``, ``'pos_business_id'`` and ``'neg_business_id'`` keys; item
    ``idx`` is ``(user_ids, item_ids, labels)`` as numpy arrays, with label 1
    for each positive candidate and 0 for each negative one.
    """

    def __init__(self, features, num_user,
                 num_item, train_mat=None, num_ng=0, is_training=None):
        """
        Args:
            features: training pairs or test records (see class docstring).
            num_user: number of users (stored; not otherwise used here).
            num_item: negatives are drawn from ``range(num_item)``.
            train_mat: per-user interacted items, indexed by user id; only
                used in training mode.
            num_ng: negatives drawn per positive pair; 0 means ``item_j`` is
                picked from the user's own training items instead.
            is_training: selects training vs. test behaviour.
        """
        super(BPRData, self).__init__()
        self.features = features
        self.num_user = num_user
        self.num_item = num_item
        self.train_mat = train_mat
        self.num_ng = num_ng
        self.is_training = is_training

        if not self.is_training:
            # Pre-build one (users, items, labels) row per test record:
            # positives first (label 1), then negatives (label 0).
            self.data = []
            for record in features:
                pos = list(record['pos_business_id'])
                neg = list(record['neg_business_id'])
                candidates = pos + neg
                self.data.append([
                    np.asarray([record['user_id']] * len(candidates)),
                    np.asarray(candidates),
                    np.asarray([1] * len(pos) + [0] * len(neg)),
                ])

    def ng_sample(self):
        """(Re)build ``self.features_fill``, the list of ``[u, i, j]`` triples.

        With ``num_ng > 0`` every pair yields ``num_ng`` triples whose ``j`` is
        drawn with ``np.random.randint(num_item)`` and redrawn while it is one
        of the user's training items. With ``num_ng == 0`` every pair yields a
        single triple whose ``j`` is ``random.choice`` of the user's training
        items.

        Raises:
            AssertionError: if called when not in training mode.
            ValueError: if ``num_ng > 0`` and a user's training items cover
                every id in ``range(num_item)``, so no negative exists and the
                rejection loop could never terminate.
        """
        assert self.is_training, 'no need to sampling when testing'
        self.features_fill = []
        # Per-call cache: user -> set of training items. Makes the rejection
        # test O(1) instead of a linear scan of the user's item list.
        positive_sets = {}
        for x in self.features:
            u, i = x[0], x[1]
            if self.num_ng > 0:
                positives = positive_sets.get(u)
                if positives is None:
                    positives = set(self.train_mat[u])
                    if (len(positives) >= self.num_item
                            and positives.issuperset(range(self.num_item))):
                        raise ValueError(
                            'user %r has interacted with all %d items; '
                            'cannot sample a negative' % (u, self.num_item))
                    positive_sets[u] = positives
                for _ in range(self.num_ng):
                    j = np.random.randint(self.num_item)
                    while j in positives:
                        j = np.random.randint(self.num_item)
                    self.features_fill.append([u, i, j])
            elif self.num_ng == 0:
                # NOTE(review): this picks item_j from the user's *positive*
                # training items rather than from unobserved ones; confirm
                # this is intended.
                j = random.choice(self.train_mat[u])
                self.features_fill.append([u, i, j])

    def __len__(self):
        """Number of triples (training) or number of test records (test)."""
        num_samples = self.num_ng if self.num_ng > 0 else 1
        return num_samples * len(self.features) if \
            self.is_training else len(self.features)

    def __getitem__(self, idx):
        """Return ``(user, item_i, item_j)`` or ``(user_ids, item_ids, labels)``."""
        if self.is_training:
            user, item_i, item_j = self.features_fill[idx]
            return user, item_i, item_j
        user_ids, item_ids, labels = self.data[idx]
        return user_ids, item_ids, labels