From b2039a20f7d4ab369504aaea4467346b49bdb0a5 Mon Sep 17 00:00:00 2001 From: lizhenyang Date: Wed, 11 Jan 2023 19:47:35 +0800 Subject: [PATCH] enhance inner_product function update parameters with axis and keepdims --- deepmatch/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmatch/utils.py b/deepmatch/utils.py index becdc9f..a482b1c 100644 --- a/deepmatch/utils.py +++ b/deepmatch/utils.py @@ -50,8 +50,8 @@ def l2_normalize(x, axis=-1): return Lambda(lambda x: tf.nn.l2_normalize(x, axis))(x) -def inner_product(x, y, temperature=1.0): - return Lambda(lambda x: tf.reduce_sum(tf.multiply(x[0], x[1])) / temperature)([x, y]) +def inner_product(x, y, temperature=1.0, axis=None, keepdims=False): + return Lambda(lambda x: tf.reduce_sum(tf.multiply(x[0], x[1]),axis=axis, keepdims=keepdims) / temperature)([x, y]) def recall_N(y_true, y_pred, N=50):