Skip to content
This repository has been archived by the owner on Jun 27, 2022. It is now read-only.

Patch BilinearResize2D #28

Merged
merged 4 commits into from
Apr 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/python/ndarray/contrib.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ In the rest of this document, we list routines provided by the `ndarray.contrib`
.. autosummary::
:nosignatures:

BilinearResize2D
CTCLoss
DeformableConvolution
DeformablePSROIPooling
Expand Down
1 change: 1 addition & 0 deletions docs/api/python/symbol/contrib.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ In the rest of this document, we list routines provided by the `symbol.contrib`
.. autosummary::
:nosignatures:

BilinearResize2D
CTCLoss
DeformableConvolution
DeformablePSROIPooling
Expand Down
180 changes: 180 additions & 0 deletions src/operator/contrib/bilinear_resize-inl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file bilinear_resize-inl.h
* \brief bilinear resize operator
* \author Hang Zhang
*/
#ifndef MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_INL_H_
#define MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_INL_H_

#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <mxnet/ndarray.h>
#include <map>
#include <vector>
#include <string>
#include <utility>
/* contrib
#include "../ndarray/ndarray_function.h"
#include "./operator_common.h"
#include "./mxnet_op.h"
#include "./mshadow_op.h"
*/
#include "../../ndarray/ndarray_function.h"
#include "../operator_common.h"
#include "../mxnet_op.h"
#include "../mshadow_op.h"
#include "../tensor/init_op.h"

namespace mxnet {
namespace op {

struct BilinearSampleParam : public dmlc::Parameter<BilinearSampleParam> {
int height;
int width;
DMLC_DECLARE_PARAMETER(BilinearSampleParam) {
DMLC_DECLARE_FIELD(height).set_range(1, 1000)
.describe("output height (required)");
DMLC_DECLARE_FIELD(width).set_range(1, 1000)
.describe("output width (required)");
}
};

static inline bool IsWriting(const OpReqType ort) {
return ort == kWriteTo || ort == kWriteInplace;
}

template<typename xpu, typename DType, typename AccReal>
void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<cpu> *s,
const std::vector<TBlob> &input,
const std::vector<TBlob> &output);

template<typename xpu, typename DType, typename AccReal>
void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
const std::vector<TBlob> &input,
const std::vector<TBlob> &output);

#if MXNET_USE_CUDA
template<typename xpu, typename DType, typename AccReal>
void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<gpu> *s,
const std::vector<TBlob> &input,
const std::vector<TBlob> &output);

template<typename xpu, typename DType, typename AccReal>
void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<gpu> *s,
const std::vector<TBlob> &input,
const std::vector<TBlob> &output);
#endif // MXNET_USE_CUDA

template <typename xpu>
inline void BilinearSampleOpForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
SpatialUpSamplingBilinearUpdateOutput<xpu, DType, AccReal>(s, inputs, outputs);
});
}


template <typename xpu>
inline void BilinearSampleOpBackward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
if (IsWriting(req[0])) {
// zero grad before backwarding
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Fill<false>(s, outputs[0], kWriteTo, 0);
})
}
MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
SpatialUpSamplingBilinearUpdateGradInput<xpu, DType, AccReal>(s, inputs, outputs);
});
}


static bool BilinearSampleOpInferShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) {
using namespace mshadow;
CHECK_EQ(in_shape->size(), 1U) << "Input:[data]";
CHECK_EQ(out_shape->size(), 1U) << "Output:[data]";
const BilinearSampleParam& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
TShape dshape(in_shape->at(0));
if (dshape.ndim() == 0) return false;
dshape[2] = param.height;
dshape[3] = param.width;
out_shape->clear();
out_shape->push_back(dshape);
return true;
}

static bool BilinearSampleOpInferType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_type,
std::vector<int> *out_type) {
using namespace mshadow;
CHECK_EQ(in_type->size(), 1U);
int dtype = (*in_type)[0];
CHECK_NE(dtype, -1) << "First input must have specified type";
// For float16 input type beta, gamma, mean, and average are stored in float32.
// For other input types, these parameters have the same type as input
// NOTE: This requirement is from cuDNN (v. 4 and 5)
int dtype_param = 0;
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
dtype_param = mshadow::DataType<AccRealX>::kFlag; });
out_type->clear();
out_type->push_back(dtype_param);
return true;
}

static inline bool BilinearSampleOpStorageType(const nnvm::NodeAttrs &attrs,
const int dev_mask,
DispatchMode *dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1);
CHECK_EQ(out_attrs->size(), 1);
*dispatch_mode = DispatchMode::kFCompute;
for (int& v : *in_attrs) {
if (v == - 1) v = kDefaultStorage;
}
for (size_t i = 0; i < out_attrs->size(); i++) {
(*out_attrs)[i] = kDefaultStorage;
}
return true;
}


} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_INL_H_

199 changes: 199 additions & 0 deletions src/operator/contrib/bilinear_resize.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file bilinear_resize.cc
* \brief bilinear resize operator
* \author Hang Zhang
*/
#include "bilinear_resize-inl.h"
// #include "elemwise_op_common.h"
#include "../elemwise_op_common.h"

namespace mxnet {
namespace op {

using namespace mshadow;

template<typename xpu, typename DType, typename AccReal>
void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<cpu> *s,
const std::vector<TBlob> &input,
const std::vector<TBlob> &output) {
Tensor<xpu, 4, DType> itensor = input[0].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> otensor = output[0].get<xpu, 4, DType>(s);
int nbatch = otensor.size(0);
int channels = otensor.size(1);
int outputHeight = otensor.size(2);
int outputWidth = otensor.size(3);
int inputHeight = itensor.size(2);
int inputWidth = itensor.size(3);

DType *idata = itensor.dptr_;
DType *odata = otensor.dptr_;
channels = nbatch * channels;
// special case: just copy
if (inputHeight == outputHeight && inputWidth == outputWidth) {
for (int h2 = 0; h2 < outputHeight; ++h2) {
const int h1 = h2;
for (int w2 = 0; w2 < outputWidth; ++w2) {
const int w1 = w2;
const DType* pos1 = &idata[h1 * inputWidth + w1];
DType* pos2 = &odata[h2 * outputWidth + w2];
for (int c = 0; c < channels; ++c) {
pos2[0] = pos1[0];
pos1 += inputWidth * inputHeight;
pos2 += outputWidth * outputHeight;
}
}
}
return;
}
const float rheight =(outputHeight > 1) ? static_cast<float>(inputHeight - 1)/
(outputHeight - 1) : 0.f;
const float rwidth = (outputWidth > 1) ? static_cast<float>(inputWidth - 1) /
(outputWidth - 1) : 0.f;
for (int h2 = 0; h2 < outputHeight; ++h2) {
const float h1r = rheight * h2;
const int h1 = h1r;
const int h1p = (h1 < inputHeight - 1) ? 1 : 0;
const DType h1lambda = h1r - h1;
const DType h0lambda = (DType)1. - h1lambda;
for (int w2 = 0; w2 < outputWidth; ++w2) {
const float w1r = rwidth * w2;
const int w1 = w1r;
const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
const DType w1lambda = w1r - w1;
const DType w0lambda = (DType)1. - w1lambda;
const DType* pos1 = &idata[h1 * inputWidth + w1];
DType* pos2 = &odata[h2 * outputWidth + w2];
for (int c = 0; c < channels; ++c) {
pos2[0] = h0lambda * (w0lambda * pos1[0]+ w1lambda * pos1[w1p])
+ h1lambda * (w0lambda * pos1[h1p * inputWidth]
+ w1lambda * pos1[h1p * inputWidth + w1p]);
pos1 += inputWidth * inputHeight;
pos2 += outputWidth * outputHeight;
}
}
}
}


template<typename xpu, typename DType, typename AccReal>
void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
const std::vector<TBlob> &input,
const std::vector<TBlob> &output) {
Tensor<xpu, 4, DType> gradOutput = input[0].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> gradInput = output[0].get<xpu, 4, DType>(s);

int nbatch = gradInput.size(0);
int channels = gradInput.size(1);
int outputHeight = gradOutput.size(2);
int outputWidth = gradOutput.size(3);
int inputHeight = gradInput.size(2);
int inputWidth = gradInput.size(3);

DType *data1 = gradInput.dptr_;
DType *data2 = gradOutput.dptr_;
channels = nbatch * channels;

// special case: same-size matching grids
if (inputHeight == outputHeight && inputWidth == outputWidth) {
for (int h2 = 0; h2 < outputHeight; ++h2) {
const int h1 = h2;
for (int w2 = 0; w2 < outputWidth; ++w2) {
const int w1 = w2;
DType* pos1 = &data1[h1 * inputWidth + w1];
const DType* pos2 = &data2[h2 * outputWidth + w2];
for (int c = 0; c < channels; ++c) {
pos1[0] += pos2[0];
pos1 += inputWidth * inputHeight;
pos2 += outputWidth * outputHeight;
}
}
}
return;
}
const float rheight =(outputHeight > 1) ? static_cast<float>(inputHeight - 1)/
(outputHeight - 1) : 0.f;
const float rwidth = (outputWidth > 1) ? static_cast<float>(inputWidth - 1)/
(outputWidth - 1) : 0.f;
for (int h2 = 0; h2 < outputHeight; ++h2) {
const float h1r = rheight * h2;
const int h1 = h1r;
const int h1p = (h1 < inputHeight - 1) ? 1 : 0;
const DType h1lambda = h1r - h1;
const DType h0lambda = (DType)1. - h1lambda;
for (int w2 = 0; w2 < outputWidth; ++w2) {
const float w1r = rwidth * w2;
const int w1 = w1r;
const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
const DType w1lambda = w1r - w1;
const DType w0lambda = (DType)1. - w1lambda;
DType* pos1 = &data1[h1 * inputWidth + w1];
const DType* pos2 = &data2[h2 * outputWidth + w2];
for (int c = 0; c < channels; ++c) {
pos1[0] += h0lambda * w0lambda * pos2[0];
pos1[w1p] += h0lambda * w1lambda * pos2[0];
pos1[h1p * inputWidth] += h1lambda * w0lambda * pos2[0];
pos1[h1p * inputWidth + w1p] += h1lambda * w1lambda * pos2[0];
pos1 += inputWidth * inputHeight;
pos2 += outputWidth * outputHeight;
}
}
}
}


DMLC_REGISTER_PARAMETER(BilinearSampleParam);

NNVM_REGISTER_OP(_contrib_BilinearResize2D)
.describe(R"code(
Perform 2D resizing (upsampling or downsampling) for 4D input using bilinear interpolation.

Expected input is a 4 dimensional NDArray (NCHW) and the output
with the shape of (N x C x height x width).
The key idea of bilinear interpolation is to perform linear interpolation
first in one direction, and then again in the other direction. See the wikipedia of
`Bilinear interpolation <https://en.wikipedia.org/wiki/Bilinear_interpolation>`_
for more details.
)code" ADD_FILELINE)
.set_attr_parser(ParamParser<BilinearSampleParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FInferShape>("FInferShape", BilinearSampleOpInferShape)
.set_attr<nnvm::FInferType>("FInferType", BilinearSampleOpInferType)
.set_attr<FInferStorageType>("FInferStorageType", BilinearSampleOpStorageType)
.set_attr<FCompute>("FCompute<cpu>", BilinearSampleOpForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
ElemwiseGradUseNone{"_backward_contrib_BilinearResize2D"})
.add_argument("data", "NDArray-or-Symbol", "Input data")
.add_arguments(BilinearSampleParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_contrib_BilinearResize2D)
.set_attr_parser(ParamParser<BilinearSampleParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", BilinearSampleOpStorageType)
.set_attr<FCompute>("FCompute<cpu>", BilinearSampleOpBackward<cpu>);


} // namespace op
} // namespace mxnet

Loading