Skip to content

Commit

Permalink
🐛 Format last file
Browse files Browse the repository at this point in the history
  • Loading branch information
o-laurent committed Nov 19, 2024
1 parent 8d69b43 commit cf5e9e3
Showing 1 changed file with 8 additions and 24 deletions.
32 changes: 8 additions & 24 deletions torch_uncertainty/post_processing/abnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ def __init__(
self.weights = []
for _ in range(num_models):
weight = torch.ones([num_classes])
weight[torch.randperm(num_classes)[:num_rp_classes]] += (
random_prior - 1
)
weight[torch.randperm(num_classes)[:num_rp_classes]] += random_prior - 1
self.weights.append(weight)

def fit(self, dataset: Dataset) -> None:
Expand All @@ -104,9 +102,7 @@ def fit(self, dataset: Dataset) -> None:
ClassificationRoutine(
num_classes=self.num_classes,
model=mod,
loss=nn.CrossEntropyLoss(
weight=self.weights[i].to(device=self.device)
),
loss=nn.CrossEntropyLoss(weight=self.weights[i].to(device=self.device)),
optim_recipe=optim_abnn(mod, lr=self.base_lr),
eval_ood=True,
)
Expand All @@ -133,9 +129,7 @@ def fit(self, dataset: Dataset) -> None:
for baseline in baselines:
model = copy.deepcopy(source_model)
model.load_state_dict(baseline.model.state_dict())
final_models.extend(
[copy.deepcopy(model) for _ in range(self.num_samples)]
)
final_models.extend([copy.deepcopy(model) for _ in range(self.num_samples)])

self.final_model = deep_ensembles(final_models)

Expand All @@ -161,31 +155,21 @@ def _abnn_checks(
batch_size,
) -> None:
if random_prior < 0:
raise ValueError(
f"random_prior must be greater than 0. Got {random_prior}."
)
raise ValueError(f"random_prior must be greater than 0. Got {random_prior}.")
if batch_size < 1:
raise ValueError(
f"batch_size must be greater than 0. Got {batch_size}."
)
raise ValueError(f"batch_size must be greater than 0. Got {batch_size}.")
if max_epochs < 1:
raise ValueError(f"epoch must be greater than 0. Got {max_epochs}.")
if num_models < 1:
raise ValueError(
f"num_models must be greater than 0. Got {num_models}."
)
raise ValueError(f"num_models must be greater than 0. Got {num_models}.")
if num_samples < 1:
raise ValueError(
f"num_samples must be greater than 0. Got {num_samples}."
)
raise ValueError(f"num_samples must be greater than 0. Got {num_samples}.")
if alpha < 0:
raise ValueError(f"alpha must be greater than 0. Got {alpha}.")
if base_lr < 0:
raise ValueError(f"base_lr must be greater than 0. Got {base_lr}.")
if num_classes < 1:
raise ValueError(
f"num_classes must be greater than 0. Got {num_classes}."
)
raise ValueError(f"num_classes must be greater than 0. Got {num_classes}.")


def _replace_bn_layers(model: nn.Module, alpha: float) -> None:
Expand Down

0 comments on commit cf5e9e3

Please sign in to comment.