diff --git a/basicsr/archs/srvgg_arch.py b/basicsr/archs/srvgg_arch.py index d8fe5ceb4..f178a42f7 100644 --- a/basicsr/archs/srvgg_arch.py +++ b/basicsr/archs/srvgg_arch.py @@ -60,11 +60,11 @@ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale= def forward(self, x): out = x - for i in range(0, len(self.body)): - out = self.body[i](out) + for layer in self.body: + out = layer(out) out = self.upsampler(out) # add the nearest upsampled image, so that the network learns the residual - base = F.interpolate(x, scale_factor=self.upscale, mode='nearest') + base = F.interpolate(x, scale_factor=float(self.upscale), mode='nearest') out += base return out