-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
507 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
*.DS_Store | ||
wandb/ | ||
**.jpg | ||
**.png | ||
results/ | ||
.idea/ | ||
.vscode/ | ||
venv/ | ||
*.tar.gz | ||
*.zip | ||
*.pkl | ||
*.pyc | ||
*.mat | ||
*.npy | ||
*.pth |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import torch | ||
import numpy as np | ||
from torch.utils.data import Dataset | ||
from pathlib import Path | ||
from torchvision import io | ||
from torchvision.transforms import functional as tf | ||
|
||
VAL = .1 | ||
INPUT = ".jpg" | ||
TARGET = ".npy" | ||
|
||
def sort_key(filepath): | ||
# sorts by name | ||
return int(filepath.name.split("_")[0]) | ||
|
||
|
||
def read_image(path): | ||
return tf.convert_image_dtype(io.read_image(path), torch.float32) | ||
|
||
|
||
def save_raw(raw, path, extension="npy"): | ||
raw = (raw * 1024).astype(np.uint16) | ||
final_path = f"{path}.{extension}" | ||
np.save(final_path, raw) | ||
|
||
|
||
class Evaluation_Dataset(Dataset): | ||
def __init__(self, opt): | ||
self.opt = opt | ||
root_dir = opt.dataset_dir | ||
|
||
source = Path(root_dir) | ||
input_list = sorted(source.glob(f"*{INPUT}"), key=sort_key) | ||
|
||
self.patch_list = [ | ||
{"input": path.as_posix()} for path in input_list | ||
] | ||
|
||
|
||
def __getitem__(self, index): | ||
input = read_image(self.patch_list[index]["input"]) | ||
name = self.patch_list[index]["input"].split('/')[-1].split(".")[0] | ||
sample = {'input':input, 'name':name} | ||
return sample | ||
|
||
def __len__(self): | ||
return len(self.patch_list) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import torch | ||
import numpy as np | ||
from torchinfo import summary | ||
|
||
|
||
def get_gmacs_and_params(model, device, input_size=(1, 3, 6, 1060, 1900), print_detailed_breakdown=False): | ||
""" This function calculates the total MACs and Parameters of a given pytorch model. | ||
Args: | ||
model: A pytorch model object | ||
input_size: (batch_size, num images, channels, height, width) - input dimensions for a single NTIRE test scene | ||
Returns: | ||
total_mult_adds: The total number of GMacs for the given model and input size | ||
total_params: The total number of parameters in the model | ||
""" | ||
model_summary = summary(model, device=device,input_size=input_size, verbose=2 if print_detailed_breakdown else 0) | ||
return model_summary.total_mult_adds/10**9, model_summary.total_params | ||
|
||
def get_runtime(model, device, input_size=(1, 3, 6, 1060, 1900), num_reps=100): | ||
""" This function calculates the mean runtime of a given pytorch model. | ||
More info: https://deci.ai/resources/blog/measure-inference-time-deep-neural-networks/ | ||
Args: | ||
model: A pytorch model object | ||
input_size: (batch_size, num images, channels, height, width) - input dimensions for a single NTIRE test scene | ||
num_reps: The number of repetitions over which to calculate the mean runtime | ||
Returns: | ||
mean_runtime: The everage runtime of the model over num_reps iterations | ||
""" | ||
# Set measurement to device, in this case we set this to cuda | ||
#device = torch.device("cuda") | ||
model.to(device) | ||
# Define input, for this example we will use a random dummy input | ||
input = torch.randn(input_size, dtype=torch.float).to(device) | ||
# Define start and stop cuda events | ||
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) | ||
times=np.zeros((num_reps, 1)) | ||
# Perform warm-up runs (that are normally slower) | ||
#with torch.no_grad(): | ||
for _ in range(10): | ||
_ = model(input) | ||
# Measure actual runtime | ||
with torch.no_grad(): | ||
for it in range(num_reps): | ||
starter.record() | ||
_ = model(input) | ||
ender.record() | ||
# Await for GPU to finish the job and sync | ||
torch.cuda.synchronize() | ||
curr_time = starter.elapsed_time(ender) | ||
|
||
times[it] = curr_time / 1000 # Convert from miliseconds to seconds | ||
# Average all measured times | ||
mean_runtime = np.sum(times) / num_reps | ||
return mean_runtime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# -*- coding:utf-8 _*- | ||
import os | ||
import torch | ||
from options import BaseOptions | ||
from torch.utils.data import DataLoader | ||
from dataset import dataset | ||
import logging | ||
import models | ||
from tqdm import tqdm | ||
from utils import utils | ||
from evaluation.performance import get_gmacs_and_params, get_runtime | ||
|
||
|
||
def main(): | ||
# settings | ||
args = BaseOptions().parse() | ||
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s') | ||
|
||
# cuda and devices | ||
use_cuda = not args.no_cuda and torch.cuda.is_available() | ||
|
||
if use_cuda: | ||
device = torch.device('cuda') | ||
if args.gpu is not None: | ||
torch.cuda.set_device(args.gpu) | ||
device = torch.device(torch.cuda.current_device()) | ||
else: | ||
device = torch.device('cpu') | ||
|
||
# dataset and dataloader | ||
inference_dataset = dataset.Evaluation_Dataset(opt=args) | ||
inference_loader = DataLoader(inference_dataset, batch_size=1, shuffle=False, num_workers=args.workers) | ||
|
||
# model architectures | ||
model = models.create_model(args) | ||
model.to(device) | ||
weights = os.path.join(args.checkpoints, f'{args.trained_model}.pth') | ||
model = utils.load_from_checkpoint(weights, model, device) | ||
output_folder = os.path.join(args.output_folder, args.trained_model) | ||
os.makedirs(output_folder, exist_ok=True) | ||
|
||
n_samples = len(inference_loader) | ||
with torch.no_grad(): | ||
desc_phase = "Inference:" | ||
tqbar = tqdm(inference_loader, leave=False, total=n_samples, desc=desc_phase) | ||
for batch_idx, batch_data in enumerate(tqbar, 0): | ||
|
||
name = batch_data['name'][0] | ||
img_output_path = os.path.join(output_folder, name) | ||
tqbar.set_description(name, refresh=True) | ||
|
||
input = batch_data['input'].to(device) | ||
estimation = model(input).clamp(0, 1).squeeze(0).cpu().permute(1, 2, 0).numpy() | ||
|
||
dataset.save_raw(estimation, img_output_path) | ||
|
||
print("Running ops metrics") | ||
if args.trained_model == "p20": | ||
sample = (1, 3, 496,496) | ||
elif args.trained_model == "s7": | ||
sample = (1, 3, 504, 504) | ||
else: | ||
sample = (1, 3, 512, 512) | ||
|
||
with torch.no_grad(): | ||
total_macs, total_params = get_gmacs_and_params(model, device, input_size=sample) | ||
mean_runtime = get_runtime(model, device, input_size=sample) | ||
|
||
|
||
print("runtime per image [s] : " + str(mean_runtime)) | ||
print("number of operations [GMAcc] : " + str(total_macs)) | ||
print("number of parameters : " + str(total_params)) | ||
|
||
metrics_path = os.path.join(output_folder, "readme.txt") | ||
with open(metrics_path, 'w') as f: | ||
f.write(f"Runtime per image {sample}[s] : " + str(mean_runtime)) | ||
#f.write('\n') | ||
#f.write("number of operations [GMAcc] : " + str(total_macs)) | ||
#f.write('\n') | ||
#f.write("number of parameters : " + str(total_params)) | ||
f.write('\n') | ||
f.write("CPU[1] / GPU[0] : 0") | ||
f.write('\n') | ||
f.write("Extra Data [1] / No Extra Data [0] : 0") | ||
f.write('\n') | ||
f.write("Other description: We have a Pytorch implementation, and report single GPU runtime. The method was trained on the training dataset (- 10% for validation) for 300 epochs.") | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
#nice python inference_evaluation.py --gpu 0 --amp --dataset_dir /mnt/D/data/aim-reverse-isp/p20/test --trained_model p20 | ||
nice python inference_evaluation.py --gpu 0 --amp --dataset_dir /mnt/D/data/aim-reverse-isp/s7/test --trained_model s7 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import importlib | ||
import torch | ||
from torch import nn | ||
import logging | ||
ACTIVATIONS = ["relu", "leaky", "mish", "tanh", "silu", "sigmoid", "gelu"] | ||
MODELS = [ | ||
"base", "base_m", "base_cm", | ||
"medium_cm", "micro","base3bmix","base3bmix2", | ||
"base9_0cm", "base9_001","base9_001cm","base9_00x","base9_00xcm","unet9_00xcm","base9_001cm_mix","base9_001cm_mix2", | ||
"rgb2raw", "rfdn", "aimbase", "base2","base2mix", "base2bmix","base3","base3b","base3c", "base4", "upi", "base5", "base6", "base7", | ||
"base7_0","base7_01", "base7_02", "base7_03", "base8", "base9","base9_0","base9_00","base10","base10_1", "base7_3x3", "base7_3x3f", "base7_3x3cm", | ||
] | ||
|
||
def find_model_using_name(model_name): | ||
# Given the option --model [modelname], | ||
# the file "models/modelname_model.py" | ||
# will be imported. | ||
model_lib_name = "models." + model_name + "_model" | ||
modellib = importlib.import_module(model_lib_name) | ||
|
||
|
||
model = None | ||
target_model_name = model_name | ||
for name, cls in modellib.__dict__.items(): | ||
if name.lower() == target_model_name.lower() \ | ||
and issubclass(cls, torch.nn.Module): | ||
model = cls | ||
|
||
if model is None: | ||
logging.error("In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase." % (model_lib_name, target_model_name)) | ||
exit(0) | ||
|
||
return model | ||
|
||
def create_model(opt, name=None): | ||
if name is None: | ||
name = opt.model | ||
|
||
model = find_model_using_name(name) | ||
instance = model(opt) | ||
logging.info("# Model [%s] created" % (type(instance).__name__)) | ||
|
||
return instance | ||
|
||
|
||
|
||
def activation(act_type="relu", slope=0.2): | ||
if act_type == "leaky": | ||
return nn.LeakyReLU(negative_slope=slope) | ||
elif act_type == "relu": | ||
return nn.ReLU() | ||
elif act_type == "mish": | ||
return nn.Mish() | ||
elif act_type == "tanh": | ||
return nn.Tanh() | ||
elif act_type == "silu": | ||
return nn.SiLU() | ||
elif act_type == "gelu": | ||
return nn.GELU() | ||
elif act_type == "sigmoid": | ||
return nn.Sigmoid() | ||
else: | ||
raise ValueError(f"Unknown activation [{act_type}]") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import torch.nn as nn | ||
from models import activation | ||
|
||
|
||
class RB(nn.Module): | ||
"""Residual Block""" | ||
def __init__(self, bc, act="silu"): | ||
super().__init__() | ||
self.conv1 = nn.Conv2d(bc, bc * 2, kernel_size=3, padding=1, padding_mode="reflect", bias=True) | ||
self.act1 = activation(act) | ||
self.conv2 = nn.Conv2d(bc * 2, bc, kernel_size=3, padding=1, padding_mode="reflect", bias=True) | ||
self.act2 = activation(act) | ||
|
||
def forward(self, x): | ||
xx = self.act1(self.conv1(x)) | ||
xx = self.conv2(xx) + x | ||
return self.act2(xx) | ||
|
||
class CSTB(nn.Module): | ||
"""Color Shift Transformation Block""" | ||
def __init__(self, ic, bc, act="silu"): | ||
super().__init__() | ||
self.block = nn.Sequential( | ||
CB(ic, bc, ks=3, act=act), | ||
CB(bc, bc, ks=3, act=act), | ||
CB(bc, bc, ks=3, act=act), | ||
nn.Conv2d(bc , bc, kernel_size=1, bias=True) | ||
) | ||
|
||
def forward(self, x): | ||
return self.block(x) | ||
|
||
|
||
class CB(nn.Module): | ||
"""Convolutional Block""" | ||
def __init__(self, ic, oc, act="silu", ks=3): | ||
super().__init__() | ||
if ks == 1: | ||
self.conv = nn.Conv2d(ic, oc, kernel_size=ks, bias=True) | ||
else: | ||
padding = ks // 2 | ||
self.conv = nn.Conv2d(ic, oc, kernel_size=ks, padding=padding, padding_mode="reflect", bias=True) | ||
self.act = activation(act) | ||
|
||
def forward(self, x): | ||
return self.act(self.conv(x)) | ||
|
||
|
||
class TMB(nn.Module): | ||
"""Tone Mapping Block""" | ||
def __init__(self, ic, bc, act="silu"): | ||
super().__init__() | ||
self.block = nn.Sequential( | ||
CB(ic, bc, ks=1, act=act), | ||
CB(bc, bc, ks=1, act=act), | ||
CB(bc, bc, ks=1, act=act), | ||
nn.Conv2d(bc , bc, kernel_size=1, bias=True) | ||
) | ||
|
||
|
||
def forward(self, x): | ||
return self.block(x) | ||
|
||
class LSB(nn.Module): | ||
"""Lens Shading Block""" | ||
def __init__(self, ic, bc, act="silu"): | ||
super().__init__() | ||
self.block = nn.Sequential( | ||
CB(ic, bc, ks=3, act=act), | ||
CB(bc, bc, ks=3, act=act), | ||
CB(bc, bc, ks=3, act="sigmoid"), | ||
) | ||
|
||
def forward(self, x): | ||
return self.block(x) | ||
|
||
|
||
class Base(nn.Module): | ||
def __init__(self, opt=None): | ||
super().__init__() | ||
self.opt = opt | ||
|
||
ic = 3 | ||
bc = opt.base_channels | ||
|
||
self.p = CB(ic, bc, ks=1) | ||
self.tm = TMB(bc, bc) | ||
self.r1 = RB(bc) | ||
self.cstb = CSTB(bc, bc) | ||
self.r2 = RB(bc) | ||
self.lsb = LSB(bc, bc) | ||
self.cout = nn.Conv2d(bc, 4, kernel_size=2, padding=0, stride=2, bias=True) | ||
self.aout = activation("relu") | ||
|
||
|
||
def forward(self, x): | ||
|
||
|
||
xp = self.p(x) | ||
x_ = xp * self.tm(xp) | ||
x_ = self.r1(x_) | ||
x_ = x_ * self.cstb(xp) | ||
x_ = self.r2(x_) | ||
x_ = x_ * self.lsb(xp) | ||
out = self.aout(self.cout(x_)) | ||
return out |
Oops, something went wrong.