-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_manager.py
143 lines (129 loc) · 4.49 KB
/
data_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@file data_manager.py
@author libo
@date 2017.12.4
This module processes the input rating data, splits it into a train set and a test set, and generates the Tag Genome
DataFrame (removing the movie items that are not included in it).
"""
import pickle
import numpy as np
import random
import pandas as pd
class Data_Factory():
    """
    @brief: helpers for reading rating records, splitting them into train and
            test sets, pickling data sets, and building the Tag Genome
            DataFrame.
    """

    @staticmethod
    def _shuffle_split(ratings, ratio):
        """
        @brief: shuffle a copy of the ratings and cut it into (train, test)
        @param ratings: sequence of rating records
        @param ratio: fraction of the records that goes to the test set
        @return: (train_ratings, test_ratings) as two lists
        """
        # Shuffle a copy so the caller's sequence is left untouched.
        shuffled = list(ratings)
        test_n = int(len(shuffled) * ratio)
        random.shuffle(shuffled)
        return shuffled[test_n:], shuffled[:test_n]

    def read_rating(self, path):
        """
        @brief: process input data
        @param path: text file with one record per line, fields separated
                     by "::"; the first three fields are read as
                     (int, int, float)
        @return: list of (user, movie, rating) tuples
        """
        user_movie_ratings = []
        with open(path, 'r') as f:
            for line in f:
                # Skip empty lines (e.g. a trailing newline at end of file).
                if not line.strip():
                    continue
                record = line.split("::")
                user_movie_ratings.append(
                    (int(record[0]), int(record[1]), float(record[2])))
        return user_movie_ratings

    def generate_train_test_file(self, R, ratio):
        """
        @brief: split input data into train set and test set, printing the
                size and the number of distinct users/movies of each set
        @param R: sequence of (user, movie, rating) records; not modified
        @param ratio: fraction of the records that goes to the test set
        @return: (train_ratings, test_ratings)
        """
        train_ratings, test_ratings = self._shuffle_split(R, ratio)

        print("train size is %d" % len(train_ratings))
        print("user number in train set is %d"
              % len({record[0] for record in train_ratings}))
        print("movie number in train set is %d"
              % len({record[1] for record in train_ratings}))

        print("test size is %d" % len(test_ratings))
        print("user number in test set is %d"
              % len({record[0] for record in test_ratings}))
        print("movie number in test set is %d"
              % len({record[1] for record in test_ratings}))
        return train_ratings, test_ratings

    def save(self, data, path):
        """
        @brief: save data set
        @param data: any picklable object
        @param path: destination file; it is overwritten if it exists
        """
        # The context manager guarantees the file is flushed and closed.
        with open(path, 'wb') as f:
            pickle.dump(data, f)
        return

    def load(self, path):
        """
        @brief: load data set
        @param path: file previously written by save()
        @return: the unpickled object
        """
        # NOTE: pickle can execute arbitrary code while loading; only use
        # this on files from a trusted source.
        with open(path, 'rb') as f:
            data = pickle.load(f)
        return data

    def generate_genome(self,
                        genome_path='./data/ml-20m/genome-scores.csv',
                        mid_map_path='./data/ml-1m/mid_map',
                        movies_path='./data/ml-1m/movies.dat'):
        """
        @brief: generate Tag Genome DataFrame
        @param genome_path: CSV file with 'movieId' and 'tagId' columns plus
                            the score column(s)
        @param mid_map_path: tab-separated file; the first field of a line is
                             a movie id from the movies file and the last
                             field is the corresponding id in the genome
                             file. A line whose last field is not an integer
                             marks that movie as excluded.
        @param movies_path: "::"-separated latin1 file whose first field is
                            the movie id
        @return: DataFrame with one row per movie that has a complete set of
                 genome scores, indexed by the movie id used in the movies
                 file (index name 'ori_mid')
        """
        df_genome = pd.read_csv(genome_path, sep=',')
        # One row per movie, one column per (score column, tagId) pair.
        df_genome = df_genome.set_index(['movieId', 'tagId']).unstack(level=-1)

        exclude = set()
        mid_map = {}
        with open(mid_map_path, 'r') as f:
            for line in f:
                stripped = line.rstrip()
                if not stripped:
                    continue
                fields = stripped.split('\t')
                source_mid = int(fields[0])
                try:
                    mid_map[source_mid] = int(fields[-1])
                except ValueError:
                    # No usable corresponding id: leave this movie out.
                    exclude.add(source_mid)

        mid = []      # ids used to look rows up in the genome table
        ori_mid = []  # matching ids as they appear in the movies file
        with open(movies_path, 'rb') as f:
            for raw in f:
                if not raw.strip():
                    continue
                cur_mid = int(raw.decode('latin1').split("::")[0])
                if cur_mid in exclude:
                    continue
                ori_mid.append(cur_mid)
                mid.append(mid_map.get(cur_mid, cur_mid))

        # reindex() yields all-NaN rows for ids missing from the genome table
        # (recent pandas raises KeyError for .loc with missing labels); those
        # rows are then removed by dropna().
        df_train = df_genome.reindex(mid)
        df_train.index = pd.Index(ori_mid, name='ori_mid')
        df_train = df_train.dropna(axis=0, how='any')
        return df_train

    def generate_train_valid_test_file_with_remove(self, data, ratio, df_train):
        """
        @brief: remove data items whose movie is not included in Tag Genome
                Data, then split the rest into train set and test set
        @param data: sequence of (user, movie, rating) records; not modified
        @param ratio: fraction of the kept records that goes to the test set
        @param df_train: DataFrame whose index holds the movie ids to keep
        @return: (train_ratings, test_ratings)
        """
        known_movies = df_train.index
        kept = [record for record in data if record[1] in known_movies]
        return self._shuffle_split(kept, ratio)
if __name__ == '__main__':
    # Script entry point: read the ml-1m ratings, keep only the ratings whose
    # movie has Tag Genome data, split them 75/25 into train/test sets and
    # pickle both sets.
    import os  # only needed here, to create the output directory

    a = Data_Factory()
    R = a.read_rating('./data/ml-1m/ratings.dat')
    print("original ratings' size is %d" % len(R))
    df_train = a.generate_genome()

    # Drop every rating whose movie id is absent from the genome index.
    genome_movies = df_train.index
    R_remove = [rating for rating in R if rating[1] in genome_movies]
    user_ids = {rating[0] for rating in R_remove}
    movie_ids = {rating[1] for rating in R_remove}
    print("users' size is %d" % len(user_ids))
    print("items' size is %d" % len(movie_ids))

    train, test = a.generate_train_test_file(R_remove, 0.25)

    # save() opens the target for writing and fails if its directory is
    # missing, so create the directory first.
    if not os.path.isdir('./data/ml-1m/0.25'):
        os.makedirs('./data/ml-1m/0.25')
    a.save(train, './data/ml-1m/0.25/train.dat')
    a.save(test, './data/ml-1m/0.25/test.dat')