|
| 1 | +from __future__ import absolute_import |
| 2 | +from __future__ import division |
| 3 | +from __future__ import print_function |
| 4 | + |
| 5 | +import sys |
| 6 | +import pickle |
| 7 | + |
| 8 | +import numpy as np |
| 9 | +from six.moves import xrange # python2/3 compatible |
| 10 | +import tensorflow as tf |
| 11 | +import string |
| 12 | +import scipy |
| 13 | +import scipy.sparse as sparse |
| 14 | +import os |
| 15 | + |
| 16 | +# import code of this project |
| 17 | +sys.path.insert(0, '../util/') |
| 18 | +from util import config_to_name |
| 19 | +sys.path.insert(0, '../model/') |
| 20 | +from embedding import fit_emb |
| 21 | +from embedding import evaluate_emb |
| 22 | +from embedding import dense_array_feeder |
| 23 | +from embedding import sparse_array_feeder |
| 24 | +from random_data import rand_data |
| 25 | + |
def embedding_experiment(config, dataset):
    """Run one embedding experiment: generate data, fit the model, evaluate it.

    Args:
        config: dict of model/training hyper-parameters (see the __main__
            block of this script for the expected keys).
        dataset: dataset name; currently unused here because the data is
            produced by rand_data().
    """
    np.random.seed(seed=27)

    ## Step 1: load data
    print('Generating a dataset ...')

    # rand_data() returns train/test splits with two fields each, but only
    # the 'scores' field is needed here. 'scores' is a sparse matrix: entry
    # (i, j) is the rating of movie j given by person i, or the count of
    # item j in basket i.
    generated = rand_data()
    train_scores = generated['trainset']['scores']
    test_scores = generated['testset']['scores']

    # one can always redefine zie.generate_batch(reviews, rind) to use other
    # formats of the train/test sets
    print('The training set has %d rows and %d columns, and the test set has %d rows' %
          (train_scores.shape[0], train_scores.shape[1], test_scores.shape[0]))

    # batch_feeder is a function that will be executed as
    # batch_feeder(train_scores[i]); its output is fed into tf placeholders.
    batch_feeder = sparse_array_feeder

    # Step 2: fit an embedding model.
    print('Training set has size: ', train_scores.shape)
    emb_model, logg = fit_emb(train_scores, batch_feeder, config)
    print('Training done!')

    # Step 3: evaluate held-out log-likelihood on the test split.
    print('Test set has size: ', test_scores.shape)
    test_llh = evaluate_emb(test_scores, batch_feeder, emb_model, config)
    print('Testing done!')

    # Step 4: inspect the learned embedding matrix.
    print('Check result...')
    emb_vec = emb_model['alpha']
    print('Embedding matrix has shape ', emb_vec.shape)
    # Save wherever you want

    print('Done!')
| 70 | + |
if __name__ == '__main__':

    dataset = 'random'
    dist = 'poisson'
    max_iter = 500
    nprint = 100

    # Experiment configuration; each key is explained inline.
    config = {
        # the dimensionality of the embedding vectors
        'K': 50,
        # the embedding distribution: 'poisson' or 'binomial' (N=3)
        'dist': dist,
        # ratio of negative samples: if one row has N0 zeros, only
        # (0.1 * N0) of them are sampled, which is equivalent to
        # downweighting zero-targets with weight 0.1
        'neg_ratio': 0.1,
        # number of optimization iterations
        'max_iter': max_iter,
        # interval (in iterations) at which to print the objective, training
        # log-likelihood, validation log-likelihood, and debug values
        'nprint': nprint,
        # weight for regularization terms of embedding vectors
        'ar_sigma2': 1,
        # uncomment the following line to use the base model
        #'model': 'base',
        # uncomment the following lines to use context selection;
        # only the prior 'fixed_bern' works for now
        'model': 'context_select',
        'prior': 'fixed_bern',
        'nsample': 30,
        'hidden_size': [30, 15],
        'histogram_size': 40,
        'nsample_test': 1000,
        'selsize': 10,
    }

    print('The configuration is: ')
    print(config)

    embedding_experiment(config, dataset)