from benchopt import BaseObjective, safe_import_context

# Protect the import with `safe_import_context()`. This allows:
# - skipping import to speed up autocompletion in CLI.
# - getting requirements info when all dependencies are not installed.
with safe_import_context() as import_ctx:
    import torch
    from torch.utils.data import DataLoader

    import deepinv as dinv


# The benchmark objective must be named `Objective` and
# inherit from `BaseObjective` for `benchopt` to work properly.
class Objective(BaseObjective):

    # Name to select the objective in the CLI and to display the results.
    name = "Inverse Problems"

    # URL of the main repo for this benchmark.
    url = "https://github.com/benchopt/benchmark_inverse_problems"

    # List of parameters for the objective. The benchmark will consider
    # the cross product for each key in the dictionary.
    # Any parameter 'p' defined here is available as 'self.p'.
    parameters = {}
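    # For example (hypothetical, this benchmark defines none), setting
    #     parameters = {'noise_level': [0.01, 0.03]}
    # would run the objective once per value and expose it as
    # `self.noise_level`.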

    # List of packages needed to run the benchmark.
    # They are installed with conda; to use pip, use 'pip:packagename'. To
    # install from a specific conda channel, use 'channelname:packagename'.
    # Packages that are not necessary for the whole benchmark but only for
    # some solvers or datasets should be declared in the Dataset or Solver
    # (see simulated.py and python-gd.py).
    # Example syntax: requirements = ['numpy', 'pip:jax', 'pytorch:pytorch']
    requirements = ["pytorch", "numpy", "deepinv"]

    # Minimal version of benchopt required to run this benchmark.
    # Bump it up if the benchmark depends on a new feature of benchopt.
    min_benchopt_version = "1.5"

    def set_data(self, train_dataset, test_dataset, physics,
                 dataset_name, task_name):
        # The keyword arguments of this function are the keys of the
        # dictionary returned by `Dataset.get_data`. This defines the
        # benchmark's API to pass data. This is customizable for each
        # benchmark.
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.physics = physics
        self.dataset_name = dataset_name
        self.task_name = task_name
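
    # A minimal sketch of a `Dataset.get_data` matching the keys expected by
    # `set_data` above. It is hypothetical: the tensor shapes and the
    # `Denoising` physics are illustrative assumptions, not this benchmark's
    # actual datasets (see the benchmark's dataset files for the real ones).
    #
    #     from benchopt import BaseDataset
    #     import torch
    #     from torch.utils.data import TensorDataset
    #     import deepinv as dinv
    #
    #     class Dataset(BaseDataset):
    #         name = "simulated-example"
    #
    #         def get_data(self):
    #             x = torch.rand(8, 1, 32, 32)  # 8 grayscale 32x32 images
    #             physics = dinv.physics.Denoising(
    #                 dinv.physics.GaussianNoise(sigma=0.03)
    #             )
    #             y = physics(x)  # noisy measurements
    #             dataset = TensorDataset(x, y)
    #             return dict(
    #                 train_dataset=dataset,
    #                 test_dataset=dataset,
    #                 physics=physics,
    #                 dataset_name=self.name,
    #                 task_name="denoising",
    #             )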

    def evaluate_result(self, model, model_name, device):
        # The keyword arguments of this function are the keys of the
        # dictionary returned by `Solver.get_result`. This defines the
        # benchmark's API to pass the solver's result. This is customizable
        # for each benchmark.
        batch_size = 2
        test_dataloader = DataLoader(
            self.test_dataset, batch_size=batch_size, shuffle=False
        )

        if isinstance(model, dinv.models.DeepImagePrior):
            # `DeepImagePrior` is fitted to one measurement at a time, so
            # reconstruct each sample individually and average the metrics.
            psnr = []
            ssim = []
            for x, y in test_dataloader:
                x, y = x.to(device), y.to(device)
                x_hat = torch.cat([
                    model(y_i[None], self.physics) for y_i in y
                ])
                psnr.append(dinv.metric.PSNR()(x_hat, x))
                ssim.append(dinv.metric.SSIM()(x_hat, x))
            psnr = torch.mean(torch.cat(psnr)).item()
            ssim = torch.mean(torch.cat(ssim)).item()
            results = dict(PSNR=psnr, SSIM=ssim)
        else:
            results = dinv.test(
                model,
                test_dataloader,
                self.physics,
                metrics=[dinv.metric.PSNR(), dinv.metric.SSIM()],
                device=device,
            )

        # This method can return many metrics in a dictionary. One of these
        # metrics needs to be `value` for convergence detection purposes.
        return dict(
            value=results["PSNR"],
            ssim=results["SSIM"],
        )
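
    # A hypothetical sketch of the matching `Solver.get_result`: its keys
    # are exactly the keyword arguments of `evaluate_result` above (the
    # attribute names are illustrative assumptions):
    #
    #     def get_result(self):
    #         return dict(
    #             model=self.model,
    #             model_name=self.name,
    #             device=self.device,
    #         )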

    def get_one_result(self):
        # Return one solution. The return value should be an object
        # compatible with `self.evaluate_result`. This is mainly for
        # testing purposes.
        model = dinv.optim.DPIR(sigma=0.03, device="cpu")
        return dict(model=model, model_name="TestSolver", device="cpu")

    def get_objective(self):
        # Define the information to pass to each solver to run the
        # benchmark. The outputs of this function are the keyword arguments
        # for `Solver.set_objective`. This defines the benchmark's API for
        # passing the objective to the solvers. It is customizable for each
        # benchmark.
        return dict(train_dataset=self.train_dataset, physics=self.physics)
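

# A hypothetical sketch of how a solver consumes this objective: the keys of
# the dictionary returned by `get_objective` become the keyword arguments of
# `Solver.set_objective`. The solver below is illustrative only, not part of
# this benchmark:
#
#     from benchopt import BaseSolver
#
#     class Solver(BaseSolver):
#         name = "example-solver"
#
#         def set_objective(self, train_dataset, physics):
#             self.train_dataset = train_dataset
#             self.physics = physics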