From 708c78d020908cba93b404060e4f06f08b695918 Mon Sep 17 00:00:00 2001 From: Mathieu Turgeon-Pelchat Date: Tue, 8 Oct 2024 11:07:57 -0400 Subject: [PATCH] script_model fix for TTA --- utils/script_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/script_model.py b/utils/script_model.py index fd8a43ac..bee43287 100644 --- a/utils/script_model.py +++ b/utils/script_model.py @@ -32,8 +32,8 @@ def forward(self, input): shape = input.shape B, C = shape[0], shape[1] input = (self.max_val - self.min_val) * (input - self.min) / (self.max -self.min) + self.min_val - input = (input.view(B, C, -1) - self.mean) / self.std - input = input.view(shape) + input = (input.reshape(B, C, -1) - self.mean) / self.std + input = input.reshape(shape) output = self.model_scripted(input.to(self.device)) if self.from_logits: if self.num_classes == 1: