diff --git a/torch_uncertainty/models/wrappers/deep_ensembles.py b/torch_uncertainty/models/wrappers/deep_ensembles.py index 2a333e07..da240c08 100644 --- a/torch_uncertainty/models/wrappers/deep_ensembles.py +++ b/torch_uncertainty/models/wrappers/deep_ensembles.py @@ -107,7 +107,7 @@ def deep_ensembles( if reset_model_parameters: for model in models: - for layer in model.children(): + for layer in model.modules(): if hasattr(layer, "reset_parameters"): layer.reset_parameters()