diff --git a/lume_model/torch/module.py b/lume_model/torch/module.py index 1d2b02c..5943dc1 100644 --- a/lume_model/torch/module.py +++ b/lume_model/torch/module.py @@ -31,6 +31,8 @@ def __init__( self._feature_order = feature_order self._output_order = output_order self.register_module("base_model", self._model.model) + if not model.model.training: # PyTorchModel defines train/eval mode + self.eval() @property def feature_order(self):