diff --git a/pytorch_to_caffe.py b/pytorch_to_caffe.py index d58aec1..94cd6a1 100755 --- a/pytorch_to_caffe.py +++ b/pytorch_to_caffe.py @@ -191,6 +191,19 @@ def _avg_pool2d(raw,input, kernel_size, stride = None, padding = 0, ceil_mode = _pool('ave',raw,input, x, kernel_size, stride, padding,ceil_mode) return x +def _adaptive_avg_pool2d(raw, input, output_size): + _output_size = _list_with_default(output_size, input.size()) + x = raw(input, _output_size) + if isinstance(_output_size, int): + out_dim = _output_size + else: + out_dim = _output_size[0] + tmp = max(input.shape[2], input.shape[3]) + stride = tmp //out_dim + kernel_size = tmp - (out_dim - 1) * stride + _pool('ave', raw, input, x, kernel_size, stride, 0, False) + return x + def _max(raw,*args): x=raw(*args) if len(args)==1: