diff --git a/deepmatch/utils.py b/deepmatch/utils.py index becdc9f..a482b1c 100644 --- a/deepmatch/utils.py +++ b/deepmatch/utils.py @@ -50,8 +50,8 @@ def l2_normalize(x, axis=-1): return Lambda(lambda x: tf.nn.l2_normalize(x, axis))(x) -def inner_product(x, y, temperature=1.0): - return Lambda(lambda x: tf.reduce_sum(tf.multiply(x[0], x[1])) / temperature)([x, y]) +def inner_product(x, y, temperature=1.0, axis=None, keepdims=False): + return Lambda(lambda x: tf.reduce_sum(tf.multiply(x[0], x[1]),axis=axis, keepdims=keepdims) / temperature)([x, y]) def recall_N(y_true, y_pred, N=50):