From e3c41b6906c09c5d69b808442c3110317d1413eb Mon Sep 17 00:00:00 2001 From: Melvin-klein Date: Tue, 8 Apr 2025 10:34:32 +0200 Subject: [PATCH 1/2] Add DiffPIR --- solvers/diffpir.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 solvers/diffpir.py diff --git a/solvers/diffpir.py b/solvers/diffpir.py new file mode 100644 index 0000000..ce02f31 --- /dev/null +++ b/solvers/diffpir.py @@ -0,0 +1,39 @@ +from benchopt import BaseSolver, safe_import_context + +with safe_import_context() as import_ctx: + import torch + from torch.utils.data import DataLoader + import deepinv as dinv + import numpy as np + + +class Solver(BaseSolver): + name = 'DiffPIR' + + parameters = {} + + sampling_strategy = 'run_once' + + requirements = [] + + def set_objective(self, train_dataset, physics): + batch_size = 2 + self.train_dataloader = DataLoader( + train_dataset, batch_size=batch_size, shuffle=False + ) + self.device = ( + dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu" + ) + self.physics = physics + + def run(self, n_iter): + denoiser = dinv.models.DRUNet(pretrained="download").to(self.device) + + self.model = dinv.sampling.DiffPIR( + model=denoiser, + data_fidelity=dinv.optim.data_fidelity.L2() + ) + self.model.eval() + + def get_result(self): + return dict(model=self.model, model_name="DiffPIR", device=self.device) From f9502bcd5e1aac9900d77e4d62768c52b236d9d3 Mon Sep 17 00:00:00 2001 From: Melvin-klein Date: Tue, 8 Apr 2025 10:44:30 +0200 Subject: [PATCH 2/2] Fix flake 8 --- solvers/diffpir.py | 1 - 1 file changed, 1 deletion(-) diff --git a/solvers/diffpir.py b/solvers/diffpir.py index ce02f31..95ca08b 100644 --- a/solvers/diffpir.py +++ b/solvers/diffpir.py @@ -4,7 +4,6 @@ import torch from torch.utils.data import DataLoader import deepinv as dinv - import numpy as np class Solver(BaseSolver):