-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
80 lines (67 loc) · 3.25 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import tensorflow as tf
from models.helpers import get_id_feature
from utils.eval_metrics import create_evaluation_metrics
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
def create_train_op(loss, hparams):
    """Build the training operation that minimizes *loss*.

    :param loss: scalar loss tensor to minimize
    :param hparams: hyper-parameters providing ``learning_rate`` and ``optimizer``
    :return: the training update operation
    """
    # optimize_loss bundles optimizer construction, gradient clipping and the
    # global-step increment into a single op, so we can return it directly.
    return tf.contrib.layers.optimize_loss(
        loss=loss,
        global_step=tf.contrib.framework.get_global_step(),  # batches seen so far
        learning_rate=hparams.learning_rate,
        clip_gradients=10.0,  # cap gradient norm to stabilize training
        optimizer=hparams.optimizer)
def create_model_fn(hparams, model_impl):
    """Create an Estimator-style ``model_fn`` around a model implementation.

    :param hparams: hyper-parameters used to configure the model
    :param model_impl: callable implementing the model; any implementation with
        the same signature ``(hparams, mode, context, context_len, utterance,
        utterance_len, targets) -> (probs, loss)`` can be injected
    :return: a ``model_fn(features, targets, mode)`` returning the tuple
        ``(probabilities, loss, train_op)`` expected by tf.contrib.learn
    """
    def model_fn(features, targets, mode):
        """Dispatch to the model implementation for TRAIN / INFER / EVAL."""
        context, context_len = get_id_feature(features, "context", "context_len")
        utterance, utterance_len = get_id_feature(features, "utterance", "utterance_len")

        # Single call site for all three modes (the original duplicated this
        # per branch); at inference time there are no labels, so pass None.
        is_infer = mode == tf.contrib.learn.ModeKeys.INFER
        probs, loss = model_impl(
            hparams,
            mode,
            context,
            context_len,
            utterance,
            utterance_len,
            None if is_infer else targets)

        if mode == tf.contrib.learn.ModeKeys.TRAIN:
            return probs, loss, create_train_op(loss, hparams)

        if is_infer:
            # Without targets the loss returned by model_impl is meaningless;
            # report 0.0 and no train op, matching the original contract.
            return probs, 0.0, None

        if mode == tf.contrib.learn.ModeKeys.EVAL:
            # Each eval example is a group of 10 rows: 1 positive utterance
            # followed by 9 negative distractors. Split along the batch axis
            # and re-concatenate so each row holds the 10 response scores.
            split_probs = tf.split(probs, num_or_size_splits=10, axis=0)
            shaped_probs = tf.concat(split_probs, axis=1)  # shape (?, 10)

            # Summaries comparing the correct response (index 0) against the
            # first distractor (index 1).
            tf.summary.histogram("eval_correct_probs_hist", split_probs[0])
            tf.summary.scalar("eval_correct_probs_average", tf.reduce_mean(split_probs[0]))
            tf.summary.histogram("eval_incorrect_probs_hist", split_probs[1])
            tf.summary.scalar("eval_incorrect_probs_average", tf.reduce_mean(split_probs[1]))
            return shaped_probs, loss, None

    return model_fn