From cf5e9e3ccfcee2724645ef4f5c7e3cca8d34779f Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 19 Nov 2024 10:03:55 +0100 Subject: [PATCH] :bug: Format last file --- torch_uncertainty/post_processing/abnn.py | 32 ++++++----------------- 1 file changed, 8 insertions(+), 24 deletions(-) diff --git a/torch_uncertainty/post_processing/abnn.py b/torch_uncertainty/post_processing/abnn.py index 9dd4e79d..0ec24375 100644 --- a/torch_uncertainty/post_processing/abnn.py +++ b/torch_uncertainty/post_processing/abnn.py @@ -85,9 +85,7 @@ def __init__( self.weights = [] for _ in range(num_models): weight = torch.ones([num_classes]) - weight[torch.randperm(num_classes)[:num_rp_classes]] += ( random_prior - 1 ) + weight[torch.randperm(num_classes)[:num_rp_classes]] += random_prior - 1 self.weights.append(weight) def fit(self, dataset: Dataset) -> None: @@ -104,9 +102,7 @@ def fit(self, dataset: Dataset) -> None: ClassificationRoutine( num_classes=self.num_classes, model=mod, - loss=nn.CrossEntropyLoss( weight=self.weights[i].to(device=self.device) ), + loss=nn.CrossEntropyLoss(weight=self.weights[i].to(device=self.device)), optim_recipe=optim_abnn(mod, lr=self.base_lr), eval_ood=True, ) @@ -133,9 +129,7 @@ def fit(self, dataset: Dataset) -> None: for baseline in baselines: model = copy.deepcopy(source_model) model.load_state_dict(baseline.model.state_dict()) - final_models.extend( [copy.deepcopy(model) for _ in range(self.num_samples)] ) + final_models.extend([copy.deepcopy(model) for _ in range(self.num_samples)]) self.final_model = deep_ensembles(final_models) @@ -161,31 +155,21 @@ def _abnn_checks( batch_size, ) -> None: if random_prior < 0: - raise ValueError( f"random_prior must be greater than 0. Got {random_prior}." ) + raise ValueError(f"random_prior must be greater than 0. Got {random_prior}.") if batch_size < 1: - raise ValueError( f"batch_size must be greater than 0. Got {batch_size}." ) + raise ValueError(f"batch_size must be greater than 0. Got {batch_size}.") if max_epochs < 1: raise ValueError(f"epoch must be greater than 0. Got {max_epochs}.") if num_models < 1: - raise ValueError( f"num_models must be greater than 0. Got {num_models}." ) + raise ValueError(f"num_models must be greater than 0. Got {num_models}.") if num_samples < 1: - raise ValueError( f"num_samples must be greater than 0. Got {num_samples}." ) + raise ValueError(f"num_samples must be greater than 0. Got {num_samples}.") if alpha < 0: raise ValueError(f"alpha must be greater than 0. Got {alpha}.") if base_lr < 0: raise ValueError(f"base_lr must be greater than 0. Got {base_lr}.") if num_classes < 1: - raise ValueError( f"num_classes must be greater than 0. Got {num_classes}." ) + raise ValueError(f"num_classes must be greater than 0. Got {num_classes}.") def _replace_bn_layers(model: nn.Module, alpha: float) -> None: